From 96bfe01184f19da638608e0b3f889afae16fbcb5 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Sat, 19 Oct 2024 13:30:11 +0100 Subject: [PATCH 1/6] Don't use optional fields for dependency graph --- .../validation_pipeline/validations/cycle.rs | 81 ++++++++++++++----- .../class/mutually_recursive_types.baml | 20 +++++ .../class/recursive_type.baml | 4 + 3 files changed, 83 insertions(+), 22 deletions(-) create mode 100644 engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml create mode 100644 engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index 1c7f33306..eb75d9575 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -1,40 +1,54 @@ use std::collections::HashSet; +use either::Either; use internal_baml_diagnostics::DatamodelError; -use internal_baml_schema_ast::ast::{TypeExpId, WithIdentifier, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; use crate::validate::validation_pipeline::context::Context; +/// Validates if there's a cycle in any dependency graph. pub(super) fn validate(ctx: &mut Context<'_>) { - // Validates if there's a cycle in any dependency graph. - let mut deps_list = ctx + // We're only going to consider type dependencies that can actually cause + // infinite recursion. Unions and optionals can stop the recursion at any + // point, so they don't have to be part of the "dependency" graph because + // technically an optional field doesn't "depend" on anything, it can just + // be null. + let mut required_deps = ctx .db .walk_classes() - .map(|f| { - ( - f.id, - f.dependencies() - .into_iter() - .filter(|f| match ctx.db.find_type_by_str(f) { - Some(either::Either::Left(_cls)) => true, - // Don't worry about enum dependencies, they can't form cycles. - Some(either::Either::Right(_enm)) => false, - None => { - panic!("Unknown class `{}`", f); - } - }) - .collect::>(), - ) + .map(|cls| { + let expr_block = &ctx.db.ast()[cls.class_id()]; + + // TODO: There's already a hash set that returns "dependencies" in + // the DB, it shoudn't be necessary to traverse all the fields here + // again, we need to refactor .dependencies() or add a new method + // that returns not only the dependency name but also field arity. + // The arity could be computed at the same time as the dependencies + // hash set. Code is here: + // + // baml-lib/parser-database/src/types/mod.rs + // fn visit_class() + let mut deps = HashSet::new(); + + for field in &expr_block.fields { + if let Some(field_type) = &field.expr { + insert_deps(field_type, ctx, &mut deps); + } + } + + (cls.id, deps) }) .collect::>(); + // println!("{:?}", required_deps); + // Now we can check for cycles using topological sort. let mut stack: Vec<(TypeExpId, Vec)> = Vec::new(); // This stack now also keeps track of the path let mut visited = HashSet::new(); let mut in_stack = HashSet::new(); // Find all items with 0 dependencies - for (id, deps) in &deps_list { + for (id, deps) in &required_deps { if deps.is_empty() { stack.push((*id, vec![*id])); } @@ -70,7 +84,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { in_stack.insert(current); visited.insert(current); - deps_list.iter_mut().for_each(|(id, deps)| { + required_deps.iter_mut().for_each(|(id, deps)| { if deps.remove(&name) { // If this item has now 0 dependencies, add it to the stack if deps.is_empty() { @@ -85,8 +99,8 @@ pub(super) fn validate(ctx: &mut Context<'_>) { } // If there are still items left in deps_list after the above steps, there's a cycle - if visited.len() != deps_list.len() { - for (id, _) in &deps_list { + if visited.len() != required_deps.len() { + for (id, _) in &required_deps { if !visited.contains(id) { let cls = &ctx.db.ast()[*id]; ctx.push_error(DatamodelError::new_validation_error( @@ -97,3 +111,26 @@ pub(super) fn validate(ctx: &mut Context<'_>) { } } } + +/// Inserts all the required dependencies of a field into the given set. +/// +/// Recursively deals with unions of unions. Can be implemented iteratively with +/// a while loop and a stack/queue if this ends up being slow / inefficient. +fn insert_deps(field: &FieldType, ctx: &Context<'_>, deps: &mut HashSet) { + match field { + FieldType::Symbol(arity, ident, _) if arity.is_required() => { + let name = ident.name(); + if let Some(Either::Left(_cls_dep)) = ctx.db.find_type_by_str(&name) { + deps.insert(name.to_string()); + } + } + + FieldType::Union(arity, field_types, _, _) if arity.is_required() => { + for f in field_types { + insert_deps(f, ctx, deps); + } + } + + _ => {} + } +} diff --git a/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml b/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml new file mode 100644 index 000000000..949b5fd85 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml @@ -0,0 +1,20 @@ +// class Tree { +// data int +// children Forest +// } + +// class Forest { +// trees Tree[] +// } + +class One { + t Two +} + +class Two { + t Three +} + +class Three { + o One +} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml b/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml new file mode 100644 index 000000000..9142bdc3a --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml @@ -0,0 +1,4 @@ +class Node { + data int + next Node? +} \ No newline at end of file From 97f5f93b0c442f226a497673d6df2a817dc4bc2b Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Sun, 20 Oct 2024 04:48:08 +0100 Subject: [PATCH 2/6] Implement Tarjan's strongly connected components algorithm --- .../validation_pipeline/validations/cycle.rs | 363 +++++++++++++----- .../class/mutually_recursive_types.baml | 4 + engine/baml-lib/parser-database/src/lib.rs | 35 +- engine/baml-lib/schema-ast/src/ast.rs | 7 + 4 files changed, 288 insertions(+), 121 deletions(-) diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index eb75d9575..331db99f4 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -1,136 +1,291 @@ -use std::collections::HashSet; +use std::{ + cmp, + collections::{HashMap, HashSet, VecDeque}, +}; use either::Either; use internal_baml_diagnostics::DatamodelError; -use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithIdentifier, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithName, WithSpan}; use crate::validate::validation_pipeline::context::Context; -/// Validates if there's a cycle in any dependency graph. +/// Validates if there's a cycle in the dependency graph. pub(super) fn validate(ctx: &mut Context<'_>) { - // We're only going to consider type dependencies that can actually cause - // infinite recursion. Unions and optionals can stop the recursion at any - // point, so they don't have to be part of the "dependency" graph because - // technically an optional field doesn't "depend" on anything, it can just - // be null. - let mut required_deps = ctx - .db - .walk_classes() - .map(|cls| { - let expr_block = &ctx.db.ast()[cls.class_id()]; - - // TODO: There's already a hash set that returns "dependencies" in - // the DB, it shoudn't be necessary to traverse all the fields here - // again, we need to refactor .dependencies() or add a new method - // that returns not only the dependency name but also field arity. - // The arity could be computed at the same time as the dependencies - // hash set. Code is here: - // - // baml-lib/parser-database/src/types/mod.rs - // fn visit_class() - let mut deps = HashSet::new(); - - for field in &expr_block.fields { - if let Some(field_type) = &field.expr { - insert_deps(field_type, ctx, &mut deps); - } - } - - (cls.id, deps) - }) - .collect::>(); + // First, build a graph of all the "required" dependencies represented as an + // adjacency list. We're only going to consider type dependencies that can + // actually cause infinite recursion. Unions and optionals can stop the + // recursion at any point, so they don't have to be part of the "dependency" + // graph because technically an optional field doesn't "depend" on anything, + // it can just be null. + let dependency_graph = HashMap::from_iter(ctx.db.walk_classes().map(|class| { + let expr_block = &ctx.db.ast()[class.id]; - // println!("{:?}", required_deps); + // TODO: There's already a hash set that returns "dependencies" in + // the DB, it shoudn't be necessary to traverse all the fields here + // again and build yet another graph, we need to refactor + // .dependencies() or add a new method that returns not only the + // dependency name but also field arity. The arity could be computed at + // the same time as the dependencies hash set. Code is here: + // + // baml-lib/parser-database/src/types/mod.rs + // fn visit_class() + let mut dependencies = HashSet::new(); - // Now we can check for cycles using topological sort. - let mut stack: Vec<(TypeExpId, Vec)> = Vec::new(); // This stack now also keeps track of the path - let mut visited = HashSet::new(); - let mut in_stack = HashSet::new(); - - // Find all items with 0 dependencies - for (id, deps) in &required_deps { - if deps.is_empty() { - stack.push((*id, vec![*id])); + for field in &expr_block.fields { + if let Some(field_type) = &field.expr { + insert_required_deps(field_type, ctx, &mut dependencies); + } } - } - while let Some((current, path)) = stack.pop() { - let name = ctx.db.ast()[current].name().to_string(); - let span = ctx.db.ast()[current].span(); - - if in_stack.contains(¤t) { - let cycle_start_index = match path.iter().position(|&x| x == current) { - Some(index) => index, - None => { - ctx.push_error(DatamodelError::new_validation_error( - "Cycle start index not found in the path.", - span.clone(), - )); - return; - } - }; - let cycle = path[cycle_start_index..] - .iter() - .map(|&x| ctx.db.ast()[x].name()) - .collect::>() - .join(" -> "); - ctx.push_error(DatamodelError::new_validation_error( - &format!("These classes form a dependency cycle: {}", cycle), - span.clone(), - )); - return; - } + (class.id, dependencies) + })); - in_stack.insert(current); - visited.insert(current); + for component in Tarjan::components(&dependency_graph) { + let cycle = component + .iter() + .map(|id| ctx.db.ast()[*id].name().to_string()) + .collect::>() + .join(" -> "); - required_deps.iter_mut().for_each(|(id, deps)| { - if deps.remove(&name) { - // If this item has now 0 dependencies, add it to the stack - if deps.is_empty() { - let mut new_path = path.clone(); - new_path.push(*id); - stack.push((*id, new_path)); - } - } - }); - - in_stack.remove(¤t); - } - - // If there are still items left in deps_list after the above steps, there's a cycle - if visited.len() != required_deps.len() { - for (id, _) in &required_deps { - if !visited.contains(id) { - let cls = &ctx.db.ast()[*id]; - ctx.push_error(DatamodelError::new_validation_error( - &format!("These classes form a dependency cycle: {}", cls.name()), - cls.identifier().span().clone(), - )); - } - } + ctx.push_error(DatamodelError::new_validation_error( + &format!("These classes form a dependency cycle: {}", cycle), + ctx.db.ast()[component[0]].span().clone(), + )); } } /// Inserts all the required dependencies of a field into the given set. /// /// Recursively deals with unions of unions. Can be implemented iteratively with -/// a while loop and a stack/queue if this ends up being slow / inefficient. -fn insert_deps(field: &FieldType, ctx: &Context<'_>, deps: &mut HashSet) { +/// a while loop and a stack/queue if this ends up being slow / inefficient or +/// it reaches stack overflows with large inputs. +fn insert_required_deps(field: &FieldType, ctx: &Context<'_>, deps: &mut HashSet) { match field { FieldType::Symbol(arity, ident, _) if arity.is_required() => { - let name = ident.name(); - if let Some(Either::Left(_cls_dep)) = ctx.db.find_type_by_str(&name) { - deps.insert(name.to_string()); + if let Some(Either::Left(class)) = ctx.db.find_type_by_str(ident.name()) { + deps.insert(class.id); } } FieldType::Union(arity, field_types, _, _) if arity.is_required() => { for f in field_types { - insert_deps(f, ctx, deps); + insert_required_deps(f, ctx, deps); } } _ => {} } } + +/// Dependency graph represented as an adjacency list. +type Graph = HashMap>; + +/// State of each node for Tarjan's algorithm. +#[derive(Clone, Copy)] +struct NodeState { + /// Node unique index. + index: usize, + /// Low link value. + /// + /// Represents the smallest index of any node on the stack known to be + /// reachable from `self` through `self`'s DFS subtree. + low_link: usize, + /// Whether the node is on the stack. + on_stack: bool, +} + +/// Tarjan's strongly connected components algorithm implementation. +/// +/// This algorithm finds and returns all the cycles in a graph. Read more about +/// it [here](https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm). +/// +/// This struct is simply bookkeeping for the algorithm, it can be implemented +/// with just function calls but the recursive one would need 6 parameters which +/// is pretty ugly. +struct Tarjan<'g> { + /// Ref to the depdenency graph. + graph: &'g Graph, + /// Node number counter. + index: usize, + /// Nodes are placed on a stack in the order in which they are visited. + stack: Vec, + /// State of each node. + state: HashMap, + /// Strongly connected components. + components: Vec>, +} + +impl<'g> Tarjan<'g> { + /// Unvisited node marker. + /// + /// Technically we should use [`Option`] and [`None`] for + /// [`NodeState::index`] and [`NodeState::low_link`] but that would require + /// some ugly and repetitive [`Option::unwrap`] calls. [`usize::MAX`] won't + /// be reached as an index anyway, the algorithm will stack overflow much + /// sooner than that :/ + const UNVISITED: usize = usize::MAX; + + /// Public entry point for the algorithm. + /// + /// Loops through all the nodes in the graph and visits them if they haven't + /// been visited already. When the algorithm is done, [`Self::components`] + /// will contain all the cycles in the graph. + pub fn components(graph: &'g Graph) -> Vec> { + let mut tarjans = Self { + graph, + index: 0, + stack: Vec::new(), + state: HashMap::from_iter(graph.keys().map(|&node| { + let state = NodeState { + index: Self::UNVISITED, + low_link: Self::UNVISITED, + on_stack: false, + }; + + (node, state) + })), + components: Vec::new(), + }; + + for node in tarjans.graph.keys() { + if tarjans.state[node].index == Self::UNVISITED { + tarjans.strong_connect(*node); + } + } + + tarjans.components + } + + /// Recursive DFS. + /// + /// This is where the "algorithm" runs. Again, could be implemented + /// iteratively if needed at some point. + fn strong_connect(&mut self, node_id: TypeExpId) { + // Initialize node state. This node has not yet been visited so we don't + // have to grab the state from the hash map. And if we did, then we'd + // have to fight the borrow checker by taking mut refs and unique refs + // over and over again as needed (which requires hashing the same entry + // many times and is not as readable). + let mut node = NodeState { + index: self.index, + low_link: self.index, + on_stack: true, + }; + + // Increment index and push node to stack. + self.index += 1; + self.stack.push(node_id); + + // Visit neighbors to find strongly connected components. + for successor_id in &self.graph[&node_id] { + // Grab owned state to circumvent borrow checker. + let mut successor = *&self.state[successor_id]; + if successor.index == Self::UNVISITED { + // Make sure state is updated before the recursive call. + self.state.insert(node_id, node); + self.strong_connect(*successor_id); + // Grab updated state after recursive call. + successor = *&self.state[successor_id]; + node.low_link = cmp::min(node.low_link, successor.low_link); + } else if successor.on_stack { + node.low_link = cmp::min(node.low_link, successor.index); + } + } + + // Update state in case we haven't already. We store this in a hash map + // so we have to run the hashing algorithm every time we update the + // state. Keep it to a minimum :) + self.state.insert(node_id, node); + + if node.low_link == node.index { + let mut component = Vec::new(); + + while let Some(successor_id) = self.stack.pop() { + // This should not fail since all nodes should be stored in + // the state hash map. + if let Some(successor) = self.state.get_mut(&successor_id) { + successor.on_stack = false; + } + + component.push(successor_id); + + if successor_id == node_id { + break; + } + } + + // Path should be shown as parent -> child. + // TODO: The start node is random because hash maps. A simple fix + // is to consider that the cycle starts at the node with the + // smallest ID. + component.reverse(); + + self.components.push(component); + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::{HashMap, HashSet}; + + use internal_baml_schema_ast::ast::TypeExpId; + + use super::Tarjan; + + fn type_exp_ids(ids: &[u32]) -> impl Iterator + '_ { + ids.iter().copied().map(TypeExpId::from) + } + + fn graph(from: &[(u32, &[u32])]) -> HashMap> { + HashMap::from_iter(from.iter().map(|(node, successors)| { + (TypeExpId::from(*node), type_exp_ids(&successors).collect()) + })) + } + + fn expected_components(components: &[&[u32]]) -> Vec> { + components + .iter() + .map(|ids| type_exp_ids(ids).collect()) + .collect() + } + + /// Ignores the graph cycle path. + /// + /// The graph is stored in a HashMap so Tarjan's algorithm will not always + /// follow the same path due to random state. We can't use Vecs to compare + /// determinstically so we'll just ignore the cycle path and compare the + /// nodes that form the cycle. + /// + /// TODO: Implement the fix mentioned in the implementation and this won't + /// be necessary. + fn ignore_path(components: Vec>) -> Vec> { + components + .into_iter() + .map(|v| v.into_iter().collect()) + .collect() + } + + fn assert_eq_components(actual: Vec>, expected: Vec>) { + assert_eq!(ignore_path(actual), ignore_path(expected)); + } + + #[test] + fn find_cycles() { + let graph = graph(&[ + (0, &[1]), + (1, &[2]), + (2, &[0]), + (3, &[1, 2, 4]), + (4, &[5, 3]), + (5, &[2, 6]), + (6, &[5]), + (7, &[4, 6, 7]), + ]); + + assert_eq_components( + Tarjan::components(&graph), + expected_components(&[&[0, 1, 2], &[5, 6], &[3, 4], &[7]]), + ); + } +} diff --git a/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml b/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml index 949b5fd85..5a2b69f87 100644 --- a/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml +++ b/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml @@ -17,4 +17,8 @@ class Two { class Three { o One +} + +class Four { + f Four } \ No newline at end of file diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 5fe20ef7a..d4287206b 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -190,24 +190,25 @@ impl ParserDatabase { } } - if max_loops == 0 && !deps.is_empty() { - let circular_deps = deps - .iter() - .map(|(k, _)| self.ast[*k].name()) - .collect::>() - .join(" -> "); + // TODO: Is this code necessary? Dependency cycles are already checked in the validation. + // if max_loops == 0 && !deps.is_empty() { + // let circular_deps = deps + // .iter() + // .map(|(k, _)| self.ast[*k].name()) + // .collect::>() + // .join(" -> "); - deps.iter().for_each(|(k, _)| { - diag.push_error(DatamodelError::new_validation_error( - &format!( - "Circular dependency detected for class `{}`.\n{}", - self.ast[*k].name(), - circular_deps - ), - self.ast[*k].identifier().span().clone(), - )); - }); - } + // deps.iter().for_each(|(k, _)| { + // diag.push_error(DatamodelError::new_validation_error( + // &format!( + // "Circular dependency detected for class `{}`.\n{}", + // self.ast[*k].name(), + // circular_deps + // ), + // self.ast[*k].identifier().span().clone(), + // )); + // }); + // } // Additionally ensure the same thing for functions, but since we've already handled classes, // this should be trivial. diff --git a/engine/baml-lib/schema-ast/src/ast.rs b/engine/baml-lib/schema-ast/src/ast.rs index df322df8f..8238d5902 100644 --- a/engine/baml-lib/schema-ast/src/ast.rs +++ b/engine/baml-lib/schema-ast/src/ast.rs @@ -84,6 +84,13 @@ impl SchemaAst { /// An opaque identifier for an enum in a schema AST. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct TypeExpId(u32); + +impl From for TypeExpId { + fn from(id: u32) -> Self { + TypeExpId(id) + } +} + impl std::ops::Index for SchemaAst { type Output = TypeExpressionBlock; From 12c8b92bfd329b8e675894e7e1e1b166acf17c83 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Sun, 20 Oct 2024 17:39:27 +0100 Subject: [PATCH 3/6] Skip strongly connected components with no cycles --- .../validation_pipeline/validations/cycle.rs | 86 ++++++++++++------- 1 file changed, 54 insertions(+), 32 deletions(-) diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index 331db99f4..e9da3c75b 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -1,6 +1,6 @@ use std::{ cmp, - collections::{HashMap, HashSet, VecDeque}, + collections::{HashMap, HashSet}, }; use either::Either; @@ -47,6 +47,8 @@ pub(super) fn validate(ctx: &mut Context<'_>) { .collect::>() .join(" -> "); + // TODO: We can push an error for every sinlge class here (that's what + // Rust does), for now it's an error for every cycle found. ctx.push_error(DatamodelError::new_validation_error( &format!("These classes form a dependency cycle: {}", cycle), ctx.db.ast()[component[0]].span().clone(), @@ -197,30 +199,56 @@ impl<'g> Tarjan<'g> { // state. Keep it to a minimum :) self.state.insert(node_id, node); + // Root node of a strongly connected component. if node.low_link == node.index { let mut component = Vec::new(); - while let Some(successor_id) = self.stack.pop() { + while let Some(parent_id) = self.stack.pop() { // This should not fail since all nodes should be stored in // the state hash map. - if let Some(successor) = self.state.get_mut(&successor_id) { - successor.on_stack = false; + if let Some(parent) = self.state.get_mut(&parent_id) { + parent.on_stack = false; } - component.push(successor_id); + component.push(parent_id); - if successor_id == node_id { + if parent_id == node_id { break; } } - // Path should be shown as parent -> child. - // TODO: The start node is random because hash maps. A simple fix - // is to consider that the cycle starts at the node with the - // smallest ID. + // Path should be shown as parent -> child not child -> parent. component.reverse(); - self.components.push(component); + // Find index of minimum element in the component. + // + // The cycle path is not computed deterministacally because the + // graph is stored in a hash map, so random state will cause the + // traversal algorithm to start at different nodes each time. + // + // Therefore, to avoid reporting errors to the user differently + // every time, we'll use a simple deterministic way to determine + // the start node of a cycle. + // + // Basically, the start node will always be the smallest type ID in + // the cycle. That gets rid of the random state. + let min_index = component + .iter() + .enumerate() + .min_by(|(_, a), (_, b)| a.cmp(b)) + .map(|(i, _)| i); + + // We have a cycle if the component contains more than one node or + // it contains a single node that points to itself. Otherwise it's + // just a normal node with no cycles whatsoever, so we'll skip it. + if component.len() > 1 + || (component.len() == 1 && self.graph[&node_id].contains(&node_id)) + { + if let Some(index) = min_index { + component.rotate_left(index); + self.components.push(component); + } + } } } } @@ -250,26 +278,6 @@ mod tests { .collect() } - /// Ignores the graph cycle path. - /// - /// The graph is stored in a HashMap so Tarjan's algorithm will not always - /// follow the same path due to random state. We can't use Vecs to compare - /// determinstically so we'll just ignore the cycle path and compare the - /// nodes that form the cycle. - /// - /// TODO: Implement the fix mentioned in the implementation and this won't - /// be necessary. - fn ignore_path(components: Vec>) -> Vec> { - components - .into_iter() - .map(|v| v.into_iter().collect()) - .collect() - } - - fn assert_eq_components(actual: Vec>, expected: Vec>) { - assert_eq!(ignore_path(actual), ignore_path(expected)); - } - #[test] fn find_cycles() { let graph = graph(&[ @@ -283,9 +291,23 @@ mod tests { (7, &[4, 6, 7]), ]); - assert_eq_components( + assert_eq!( Tarjan::components(&graph), expected_components(&[&[0, 1, 2], &[5, 6], &[3, 4], &[7]]), ); } + + #[test] + fn no_cycles_found() { + let graph = graph(&[ + (0, &[1]), + (1, &[2, 3]), + (2, &[4]), + (3, &[5]), + (4, &[]), + (5, &[]), + ]); + + assert_eq!(Tarjan::components(&graph), expected_components(&[])); + } } From f6ea7561bed243203a990f26f59728c72452f9e5 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Sun, 20 Oct 2024 21:13:53 +0100 Subject: [PATCH 4/6] Fix dependency graph construction --- .../validation_pipeline/validations/cycle.rs | 49 +++++++++- .../class/dependency_cycle.baml | 89 ++++++++++++++++--- .../class/mutually_recursive_types.baml | 24 ----- .../class/recursive_type.baml | 32 +++++++ 4 files changed, 155 insertions(+), 39 deletions(-) delete mode 100644 engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index e9da3c75b..6ebc6f003 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -33,7 +33,7 @@ pub(super) fn validate(ctx: &mut Context<'_>) { for field in &expr_block.fields { if let Some(field_type) = &field.expr { - insert_required_deps(field_type, ctx, &mut dependencies); + insert_required_deps(class.id, field_type, ctx, &mut dependencies); } } @@ -61,7 +61,12 @@ pub(super) fn validate(ctx: &mut Context<'_>) { /// Recursively deals with unions of unions. Can be implemented iteratively with /// a while loop and a stack/queue if this ends up being slow / inefficient or /// it reaches stack overflows with large inputs. -fn insert_required_deps(field: &FieldType, ctx: &Context<'_>, deps: &mut HashSet) { +fn insert_required_deps( + id: TypeExpId, + field: &FieldType, + ctx: &Context<'_>, + deps: &mut HashSet, +) { match field { FieldType::Symbol(arity, ident, _) if arity.is_required() => { if let Some(Either::Left(class)) = ctx.db.find_type_by_str(ident.name()) { @@ -70,9 +75,40 @@ fn insert_required_deps(field: &FieldType, ctx: &Context<'_>, deps: &mut HashSet } FieldType::Union(arity, field_types, _, _) if arity.is_required() => { + // All the dependencies of union. + let mut union_deps = HashSet::new(); + + // All the dependencies of a single field in the union. This is + // reused on every iteration of the loop below to avoid allocating + // a new hash set every time. + let mut nested_deps = HashSet::new(); + for f in field_types { - insert_required_deps(f, ctx, deps); + insert_required_deps(id, f, ctx, &mut nested_deps); + + // No nested deps found on this component, this makes the + // union finite. + if nested_deps.is_empty() { + return; // Finite union, no need to go deeper. + } + + // Add the nested deps to the overall union deps and clear the + // iteration hash set. + union_deps.extend(nested_deps.drain()); } + + // A union does not depend on itself if the field can take other + // values. However, if it only depends on itself, it means we have + // something like this: + // + // class Example { + // field: Example | Example | Example + // } + if union_deps.len() > 1 { + union_deps.remove(&id); + } + + deps.extend(union_deps); } _ => {} @@ -155,6 +191,11 @@ impl<'g> Tarjan<'g> { } } + // Sort components by the first element in each cycle (which is already + // sorted as well). This should get rid of all the randomness caused by + // hash maps and hash sets. + tarjans.components.sort_by(|a, b| a[0].cmp(&b[0])); + tarjans.components } @@ -293,7 +334,7 @@ mod tests { assert_eq!( Tarjan::components(&graph), - expected_components(&[&[0, 1, 2], &[5, 6], &[3, 4], &[7]]), + expected_components(&[&[0, 1, 2], &[3, 4], &[5, 6], &[7]]), ); } diff --git a/engine/baml-lib/baml/tests/validation_files/class/dependency_cycle.baml b/engine/baml-lib/baml/tests/validation_files/class/dependency_cycle.baml index f8c630914..3560aa338 100644 --- a/engine/baml-lib/baml/tests/validation_files/class/dependency_cycle.baml +++ b/engine/baml-lib/baml/tests/validation_files/class/dependency_cycle.baml @@ -1,3 +1,4 @@ +// Basic Mutual recursion between two classes. class InterfaceTwo { interface InterfaceOne } @@ -6,25 +7,91 @@ class InterfaceOne { interface InterfaceTwo } +// Infinite recursion on the same class. class InterfaceThree { interface InterfaceThree } -// error: Error validating: These classes form a dependency cycle: InterfaceTwo -// --> class/dependency_cycle.baml:1 +// Long cycle. +class One { + p Two +} + +class Two { + p Three +} + +class Three { + p Four +} + +class Four { + p Five +} + +class Five { + p One +} + +// Second independend long cycle. +class A { + p B +} + +class B { + p C +} + +class C { + p D +} + +class D { + p A +} + +// Union that depends on itself. +class Union { + u Union | Union | (Union | Union) +} + +// error: Error validating: These classes form a dependency cycle: InterfaceTwo -> InterfaceOne +// --> class/dependency_cycle.baml:2 // | +// 1 | // Basic Mutual recursion between two classes. +// 2 | class InterfaceTwo { +// 3 | interface InterfaceOne +// 4 | } // | -// 1 | class InterfaceTwo { +// error: Error validating: These classes form a dependency cycle: InterfaceThree +// --> class/dependency_cycle.baml:11 // | -// error: Error validating: These classes form a dependency cycle: InterfaceOne -// --> class/dependency_cycle.baml:5 +// 10 | // Infinite recursion on the same class. +// 11 | class InterfaceThree { +// 12 | interface InterfaceThree +// 13 | } // | -// 4 | -// 5 | class InterfaceOne { +// error: Error validating: These classes form a dependency cycle: One -> Two -> Three -> Four -> Five +// --> class/dependency_cycle.baml:16 // | -// error: Error validating: These classes form a dependency cycle: InterfaceThree -// --> class/dependency_cycle.baml:9 +// 15 | // Long cycle. +// 16 | class One { +// 17 | p Two +// 18 | } +// | +// error: Error validating: These classes form a dependency cycle: A -> B -> C -> D +// --> class/dependency_cycle.baml:37 +// | +// 36 | // Second independend long cycle. +// 37 | class A { +// 38 | p B +// 39 | } +// | +// error: Error validating: These classes form a dependency cycle: Union +// --> class/dependency_cycle.baml:54 // | -// 8 | -// 9 | class InterfaceThree { +// 53 | // Union that depends on itself. +// 54 | class Union { +// 55 | u Union | Union | (Union | Union) +// 56 | } // | diff --git a/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml b/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml deleted file mode 100644 index 5a2b69f87..000000000 --- a/engine/baml-lib/baml/tests/validation_files/class/mutually_recursive_types.baml +++ /dev/null @@ -1,24 +0,0 @@ -// class Tree { -// data int -// children Forest -// } - -// class Forest { -// trees Tree[] -// } - -class One { - t Two -} - -class Two { - t Three -} - -class Three { - o One -} - -class Four { - f Four -} \ No newline at end of file diff --git a/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml b/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml index 9142bdc3a..e78022689 100644 --- a/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml +++ b/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml @@ -1,4 +1,36 @@ +// Trivial recursive type. class Node { data int next Node? +} + +// Mutually recursive types. +class Tree { + data int + children Forest +} + +class Forest { + trees Tree[] +} + +// Unions. +class BasicUnion { + data int | BasicUnion +} + +class OptionalUnion { + data OptionalUnion? | OptionalUnion +} + +class FullyOptionalUnion { + data (FullyOptionalUnion | FullyOptionalUnion)? +} + +class NestedUnion { + data NestedUnion | BasicUnion +} + +class UnionOfUnions { + data (Node | (BasicUnion | Tree)) | (NestedUnion | UnionOfUnions) | Forest } \ No newline at end of file From bce5354198fa167e3a21b865fcd7536f01846384 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Mon, 21 Oct 2024 15:06:44 +0100 Subject: [PATCH 5/6] Add `NOTE` for `finalize_dependencies()` cycle detection --- .../validation_pipeline/validations/cycle.rs | 2 +- .../validation_files/class/recursive_type.baml | 2 +- engine/baml-lib/parser-database/src/lib.rs | 13 ++++++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index 6ebc6f003..36de299e3 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -75,7 +75,7 @@ fn insert_required_deps( } FieldType::Union(arity, field_types, _, _) if arity.is_required() => { - // All the dependencies of union. + // All the dependencies of the union. let mut union_deps = HashSet::new(); // All the dependencies of a single field in the union. This is diff --git a/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml b/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml index e78022689..c035d539a 100644 --- a/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml +++ b/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml @@ -33,4 +33,4 @@ class NestedUnion { class UnionOfUnions { data (Node | (BasicUnion | Tree)) | (NestedUnion | UnionOfUnions) | Forest -} \ No newline at end of file +} diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index d4287206b..8904e3c7a 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -190,7 +190,18 @@ impl ParserDatabase { } } - // TODO: Is this code necessary? Dependency cycles are already checked in the validation. + // NOTE: Class dependency cycles are already checked at + // baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs + // + // The algorithm at `cycle.rs` takes arity and recursive types into + // account unlike the topological sort performed here. The code below + // does not need to run since cycles would have already been detected + // at the validation stage, which runs before this function. Check + // baml-lib/baml-core/src/lib.rs + // + // The code above this comment seems modifies the AST by extending + // .class_dependencies so that one still needs to run. + // if max_loops == 0 && !deps.is_empty() { // let circular_deps = deps // .iter() From d5355200935072dd160f396ac12d70402b49fb00 Mon Sep 17 00:00:00 2001 From: Antonio Sarosi Date: Tue, 22 Oct 2024 23:40:58 +0100 Subject: [PATCH 6/6] Get rid of redundant dependency cycle check --- .../validation_pipeline/validations/cycle.rs | 6 +- ...cursive_type.baml => recursive_types.baml} | 0 engine/baml-lib/parser-database/src/lib.rs | 88 +------------------ 3 files changed, 5 insertions(+), 89 deletions(-) rename engine/baml-lib/baml/tests/validation_files/class/{recursive_type.baml => recursive_types.baml} (100%) diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs index 36de299e3..b7ae79b40 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs @@ -9,7 +9,7 @@ use internal_baml_schema_ast::ast::{FieldType, TypeExpId, WithName, WithSpan}; use crate::validate::validation_pipeline::context::Context; -/// Validates if there's a cycle in the dependency graph. +/// Validates if the dependency graph contains one or more infinite cycles. pub(super) fn validate(ctx: &mut Context<'_>) { // First, build a graph of all the "required" dependencies represented as an // adjacency list. We're only going to consider type dependencies that can @@ -87,9 +87,9 @@ fn insert_required_deps( insert_required_deps(id, f, ctx, &mut nested_deps); // No nested deps found on this component, this makes the - // union finite. + // union finite, so no need to go deeper. if nested_deps.is_empty() { - return; // Finite union, no need to go deeper. + return; } // Add the nested deps to the overall union deps and clear the diff --git a/engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml b/engine/baml-lib/baml/tests/validation_files/class/recursive_types.baml similarity index 100% rename from engine/baml-lib/baml/tests/validation_files/class/recursive_type.baml rename to engine/baml-lib/baml/tests/validation_files/class/recursive_types.baml diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 8904e3c7a..cf52b3e13 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -129,97 +129,13 @@ impl ParserDatabase { } fn finalize_dependencies(&mut self, diag: &mut Diagnostics) { - let mut deps = self - .types - .class_dependencies - .iter() - .map(|f| { - ( - *f.0, - f.1.iter() - .fold((0, 0, 0), |prev, i| match self.find_type_by_str(i) { - Some(Either::Left(_)) => (prev.0 + 1, prev.1 + 1, prev.2), - Some(Either::Right(_)) => (prev.0 + 1, prev.1, prev.2 + 1), - _ => prev, - }), - ) - }) - .collect::>(); - - // Can only process deps which have 0 class dependencies. - let mut max_loops = 100; - while !deps.is_empty() && max_loops > 0 { - max_loops -= 1; - // Remove all the ones which have 0 class dependencies. - let removed = deps - .iter() - .filter(|(_, v)| v.1 == 0) - .map(|(k, _)| *k) - .collect::>(); - deps.retain(|(_, v)| v.1 > 0); - for cls in removed { - let child_deps = self - .types - .class_dependencies - .get(&cls) - // These must exist by definition so safe to unwrap. - .unwrap() - .iter() - .filter_map(|f| match self.find_type_by_str(f) { - Some(Either::Left(walker)) => { - Some(walker.dependencies().iter().cloned().collect::>()) - } - Some(Either::Right(walker)) => Some(vec![walker.name().to_string()]), - _ => panic!("Unknown class `{}`", f), - }) - .flatten() - .collect::>(); - let name = self.ast[cls].name(); - deps.iter_mut() - .filter(|(k, _)| self.types.class_dependencies[k].contains(name)) - .for_each(|(_, v)| { - v.1 -= 1; - }); - - // Get the dependencies of all my dependencies. - self.types - .class_dependencies - .get_mut(&cls) - .unwrap() - .extend(child_deps); - } - } - // NOTE: Class dependency cycles are already checked at // baml-lib/baml-core/src/validate/validation_pipeline/validations/cycle.rs // - // The algorithm at `cycle.rs` takes arity and recursive types into - // account unlike the topological sort performed here. The code below - // does not need to run since cycles would have already been detected - // at the validation stage, which runs before this function. Check + // The validation pipeline runs before this code. Check // baml-lib/baml-core/src/lib.rs // - // The code above this comment seems modifies the AST by extending - // .class_dependencies so that one still needs to run. - - // if max_loops == 0 && !deps.is_empty() { - // let circular_deps = deps - // .iter() - // .map(|(k, _)| self.ast[*k].name()) - // .collect::>() - // .join(" -> "); - - // deps.iter().for_each(|(k, _)| { - // diag.push_error(DatamodelError::new_validation_error( - // &format!( - // "Circular dependency detected for class `{}`.\n{}", - // self.ast[*k].name(), - // circular_deps - // ), - // self.ast[*k].identifier().span().clone(), - // )); - // }); - // } + // So we won't check cycles again. // Additionally ensure the same thing for functions, but since we've already handled classes, // this should be trivial.