Skip to content
This repository has been archived by the owner on Apr 8, 2024. It is now read-only.

Commit

Permalink
Add plus-n operator like model+2 and 1+model (#294)
Browse files Browse the repository at this point in the history
* Add plus-n operator like model+2 and 1+model

* Remove some unnecessary casting

* Refactor selector graph operators for easier handling

* Add test for graphs with many levels
  • Loading branch information
chamini2 authored May 10, 2022
1 parent 1376b27 commit fb8b208
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 40 deletions.
25 changes: 25 additions & 0 deletions integration_tests/features/flow_run.feature
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,31 @@ Feature: `flow run` command
Then the following models are calculated:
| intermediate_model_1 | intermediate_model_2 | intermediate_model_3 |

Scenario: fal flow run command with plus operator with number selectors
Given the project 001_flow_run_with_selectors
When the data is seeded

When the following command is invoked:
"""
fal flow run --profiles-dir $profilesDir --project-dir $baseDir --select 2+intermediate_model_3 --threads 1
"""
Then the following models are calculated:
| intermediate_model_1 | intermediate_model_2 | intermediate_model_3 |

When the following command is invoked:
"""
fal flow run --profiles-dir $profilesDir --project-dir $baseDir --select intermediate_model_1+1 --threads 1
"""
Then the following models are calculated:
| intermediate_model_1 | intermediate_model_2 |

When the following command is invoked:
"""
fal flow run --profiles-dir $profilesDir --project-dir $baseDir --select intermediate_model_1+0 --threads 1
"""
Then the following models are calculated:
| intermediate_model_1 |

Scenario: fal flow run command with selectors
Given the project 001_flow_run_with_selectors
When the data is seeded
Expand Down
116 changes: 80 additions & 36 deletions src/fal/cli/selectors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import re
from typing import List
from dataclasses import dataclass
import re
from typing import List, Optional, Union
from fal.node_graph import NodeGraph
from faldbt.project import CompileArgs, FalDbt
from dbt.task.compile import CompileTask
from enum import Enum
from functools import reduce
import faldbt.lib as lib


class ExecutionPlan:
Expand Down Expand Up @@ -67,32 +66,43 @@ def _filter_node_ids(
) -> List[str]:
"""Filter list of unique_ids according to a selector."""
output = []
selector_plans = list(
map(
lambda selector: SelectorPlan(selector, unique_ids, fal_dbt),
selected_nodes,
)
selector_plans = map(
lambda selector: SelectorPlan(selector, unique_ids, fal_dbt),
selected_nodes,
)

for selector_plan in selector_plans:
for id in selector_plan.unique_ids:
output.append(id)

if selector_plan.children:
children = list(nodeGraph.get_descendants(id))
if selector_plan.children_levels is None:
children = nodeGraph.get_descendants(id)
else:
children = nodeGraph.get_successors(
id, selector_plan.children_levels
)
output.extend(children)

if selector_plan.parents:
parents = list(nodeGraph.get_ancestors(id))
if selector_plan.parents_levels is None:
parents = nodeGraph.get_ancestors(id)
else:
parents = nodeGraph.get_predecessors(
id, selector_plan.parents_levels
)
output.extend(parents)

if selector_plan.children_with_parents:
ids = _get_children_with_parents(id, nodeGraph)
output.extend(ids)

return output


def _get_children_with_parents(node_id, nodeGraph) -> List[str]:
def _get_children_with_parents(node_id: str, nodeGraph: NodeGraph) -> List[str]:
children = nodeGraph.get_descendants(node_id)
output = reduce(
lambda a, b: a + list(nodeGraph.get_ancestors(b)), children, children
)
output = reduce(lambda l, ch: l + nodeGraph.get_ancestors(ch), children, children)

output = list(set(output))

Expand All @@ -117,6 +127,7 @@ class SelectType(Enum):
COMPLEX = 3


@dataclass(init=False)
class SelectorPlan:
"""
Represents a single selector, for example in the command
Expand All @@ -128,24 +139,41 @@ class SelectorPlan:

unique_ids: List[str]
children: bool
children_levels: Optional[int]
children_with_parents: bool
parents: bool
parents_levels: Optional[int]
type: SelectType
raw: str

def __init__(self, selector: str, unique_ids: List[str], fal_dbt: FalDbt):
self.children = _needs_children(selector)
self.parents = _need_parents(selector)
self.children_with_parents = _needs_children_with_parents(selector)
self.raw = selector
self.children_with_parents = OP_CHILDREN_WITH_PARENTS.match(selector)
selector = OP_CHILDREN_WITH_PARENTS.rest(selector)

self.parents = OP_PARENTS.match(selector)
self.parents_levels = OP_PARENTS.depth(selector)
selector = OP_PARENTS.rest(selector)

self.children = OP_CHILDREN.match(selector)
self.children_levels = OP_CHILDREN.depth(selector)
selector = OP_CHILDREN.rest(selector)

self.type = _to_select_type(selector)
node_name = _remove_graph_selectors(selector)

if self.type == SelectType.MODEL:
self.unique_ids = [f"model.{fal_dbt.project_name}.{node_name}"]
self.unique_ids = [f"model.{fal_dbt.project_name}.{selector}"]
elif self.type == SelectType.SCRIPT:
self.unique_ids = _expand_script(node_name, unique_ids)
self.unique_ids = _expand_script(selector, unique_ids)
elif self.type == SelectType.COMPLEX:
self.unique_ids = unique_ids_from_complex_selector(selector, fal_dbt)

def __post_init__(self):
if self.children and self.children_with_parents:
raise RuntimeError(
f'Invalid node spec {self.raw} - "@" prefix and "+" suffix are incompatible'
)


def unique_ids_from_complex_selector(select, fal_dbt: FalDbt) -> List[str]:
args = CompileArgs(None, [select], [select], tuple(), fal_dbt._state, None)
Expand All @@ -156,12 +184,11 @@ def unique_ids_from_complex_selector(select, fal_dbt: FalDbt) -> List[str]:
return list(graph.queued)


def _to_select_type(select: str) -> SelectType:
if ":" in select:
def _to_select_type(selector: str) -> SelectType:
if ":" in selector:
return SelectType.COMPLEX
else:
node_name = _remove_graph_selectors(select)
if _is_script_node(node_name):
if _is_script_node(selector):
return SelectType.SCRIPT
else:
return SelectType.MODEL
Expand All @@ -171,24 +198,41 @@ def _is_script_node(node_name: str) -> bool:
return node_name.endswith(".py")


def _remove_graph_selectors(selector: str) -> str:
selector = selector.replace("+", "")
return selector.replace("@", "")
class SelectorGraphOp:
_regex: re.Pattern

def __init__(self, regex: re.Pattern):
self._regex = regex
assert (
"rest" in regex.groupindex
), 'rest must be in regex. Use `re.compile("something(?P<rest>.*)")`'

def _select(self, selector: str, group: Union[str, int]) -> Optional[str]:
match = self._regex.match(selector)
if match:
return match.group(group)

def match(self, selector: str) -> bool:
return self._select(selector, 0) is not None

def _needs_children(selector: str) -> bool:
children_operation_regex = re.compile(".*\\+$")
return bool(children_operation_regex.match(selector))
def rest(self, selector: str) -> str:
rest = self._select(selector, "rest")
if rest is not None:
return rest
return selector


def _needs_children_with_parents(selector: str) -> bool:
children_operation_regex = re.compile("^\\@.*")
return bool(children_operation_regex.match(selector))
class SelectorGraphOpDepth(SelectorGraphOp):
def depth(self, selector: str) -> Optional[int]:
depth = self._select(selector, "depth")
if depth:
return int(depth)


def _need_parents(selector: str) -> bool:
parent_operation_regex = re.compile("^\\+.*")
return bool(parent_operation_regex.match(selector))
# Graph operators from their regex Patterns
OP_CHILDREN_WITH_PARENTS = SelectorGraphOp(re.compile("^\\@(?P<rest>.*)"))
OP_PARENTS = SelectorGraphOpDepth(re.compile("^(?P<depth>\\d*)\\+(?P<rest>.*)"))
OP_CHILDREN = SelectorGraphOpDepth(re.compile("(?P<rest>.*)\\+(?P<depth>\\d*)$"))


def _is_before_scipt(id: str) -> bool:
Expand Down
26 changes: 22 additions & 4 deletions src/fal/node_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,32 @@ def __init__(self, graph: nx.DiGraph, node_lookup: Dict[str, FalFlowNode]):
self.graph = graph
self.node_lookup = node_lookup

def get_successors(self, id: str) -> List[str]:
return list(self.graph.successors(id))
def get_successors(self, id: str, levels: int) -> List[str]:
assert levels >= 0
if levels == 0:
return []
else:
current: List[str] = list(self.graph.successors(id))
return reduce(
lambda acc, id: acc + self.get_successors(id, levels - 1),
current,
current,
)

def get_descendants(self, id: str) -> List[str]:
return list(nx.descendants(self.graph, id))

def get_predecessors(self, id: str) -> List[str]:
return list(self.graph.predecessors(id))
def get_predecessors(self, id: str, levels: int) -> List[str]:
assert levels >= 0
if levels == 0:
return []
else:
current: List[str] = list(self.graph.predecessors(id))
return reduce(
lambda acc, id: acc + self.get_predecessors(id, levels - 1),
current,
current,
)

def get_ancestors(self, id: str) -> List[str]:
return list(nx.ancestors(self.graph, id))
Expand Down
47 changes: 47 additions & 0 deletions tests/graph/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,53 @@ def test_create_plan_start_model_upstream():
assert execution_plan.after_scripts == []


def test_create_plan_large_graph_model_levels():
def _model(s: str) -> str:
return f"model.test_project.{s}"

def _after_script_for_model(model: str) -> str:
return f"script.{model}.AFTER.script.py"

digraph = nx.DiGraph()

for n in range(100):
modeln_name = f"model{n}"
modeln = _model(modeln_name)

digraph.add_edge(modeln, _after_script_for_model(modeln_name))

modelnext = f"model.test_project.model{n+1}"
digraph.add_edge(modeln, modelnext)

for m in range(10): # Avoid cycles with these ranges
modelm = _model(f"model{n}_{m}")
digraph.add_edge(modeln, modelm)

graph = NodeGraph(digraph, {})

parsed = Namespace(select=["model0+70"])

execution_plan = ExecutionPlan.create_plan_from_graph(
parsed, graph, MagicMock(project_name=PROJECT_NAME)
)

assert execution_plan.before_scripts == []
assert_contains_only(
execution_plan.dbt_models,
# model0, model1, ..., model70
[_model(f"model{n}") for n in range(71)]
# model0_0, model0_1, ..., model0_9, model1_0, ..., model69_0, ..., model69_9
# not the children of model70, it ends in model70
+ [_model(f"model{n}_{m}") for m in range(10) for n in range(70)],
)
assert_contains_only(
execution_plan.after_scripts,
# after script for model0, model1, ..., model69
# not model70 because that is one level too far
[_after_script_for_model(f"model{n}") for n in range(70)],
)


def test_create_plan_start_model_upstream_and_downstream():
parsed = Namespace(select=["+modelA+"])
graph = _create_test_graph()
Expand Down

0 comments on commit fb8b208

Please sign in to comment.