Is it possible to add a loop condition in Airflow? #21726
How about an operator for "clear"?
According to the "Directed Acyclic Graph" (DAG) concept, we will never get an explicit loop in the flow graph of our DAG or in the webserver UI.
BAD DAG: [do a task] >> is_result_ok? >> no >> continue [do a task] (which is not a DAG!)
GOOD DAG:
I believe we CANNOT add ">> do a task" downstream, because that task already occurs upstream.
So, a solution might be:
Is this feasible? It will be tricky:
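Airflow refuses to load a DAG that contains a cycle, which is why the "BAD DAG" above cannot be declared. The sketch below illustrates that check with Python's standard-library `graphlib` (not Airflow's own code); the task names are taken from the example above, and `is_acyclic` is a hypothetical helper.

```python
from graphlib import CycleError, TopologicalSorter


def is_acyclic(edges: dict[str, set[str]]) -> bool:
    """Return True if the task graph has no cycle.

    `edges` maps each task to the set of tasks directly upstream of it,
    which is the predecessor format graphlib.TopologicalSorter expects.
    """
    try:
        # static_order() raises CycleError if any cycle exists
        list(TopologicalSorter(edges).static_order())
        return True
    except CycleError:
        return False


# "GOOD DAG": do_a_task -> is_result_ok -> (continue | done), nothing points back
good = {
    "is_result_ok": {"do_a_task"},
    "continue": {"is_result_ok"},
    "done": {"is_result_ok"},
}

# "BAD DAG": the "continue" branch feeds back into do_a_task, closing a loop
bad = {
    "do_a_task": {"continue"},
    "is_result_ok": {"do_a_task"},
    "continue": {"is_result_ok"},
}

print(is_acyclic(good))  # True
print(is_acyclic(bad))   # False
```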
Look at Branch Operators, for example here: https://www.astronomer.io/guides/airflow-branch-operator/
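A BranchPythonOperator runs a `python_callable` that returns the task_id (or list of task_ids) to follow; the other directly downstream tasks are skipped. Branching expresses the "is_result_ok?" decision but does not by itself loop back. Below is a minimal sketch of such a decision function as plain Python, so it can be tested outside Airflow; the function name, the task ids, and the 0.9 threshold are all hypothetical.

```python
def choose_next_task(score: float, threshold: float = 0.9) -> str:
    """Decision logic suitable as a BranchPythonOperator python_callable.

    Returns the task_id of the branch to follow. The task ids used here
    ("report_results", "retrain_model") are hypothetical examples.
    """
    if score >= threshold:
        return "report_results"  # result is good enough: finish
    return "retrain_model"       # result not good enough: take the "retry" branch


print(choose_next_task(0.95))  # report_results
print(choose_next_task(0.5))   # retrain_model
```

In a DAG this would be wired as `BranchPythonOperator(task_id="is_result_ok", python_callable=choose_next_task, op_kwargs={...})`, with both hypothetical task ids as its direct downstream tasks.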
To handle the "loop" condition, I have a simple example DAG if you're interested. I imagined there could be an XCom object at the "DAG level" rather than the "task-instance level", which would be more "global". I know Variables can handle this, but they are still cumbersome when you just want to temporarily record a value specific to one DAG run.
Here is the example DAG, if you're interested. :) @potiuk I believe it does not break the "acyclic" rule, although it is still very tricky. The DAG below can handle the loop condition, but it still has some problems.
I think Airflow would be more universal with a feature for managing loop workflows like this (and similar cases in daily use) that avoids using TriggerDagRunOperator to run another DAG instance. Many thanks!

```python
from airflow.decorators import task
from airflow.models import DAG, XCom
from airflow.operators.python import BranchPythonOperator, get_current_context
from airflow.utils.dates import days_ago
from airflow.utils.edgemodifier import Label

DAG_NAME = "clear_loop_example"
default_parameters = {}


def dummy_evaluate(result):
    # Placeholder check: True means "done", False means "run another loop"
    return False


def do_evaluate(evaluate_func, result, success_task_id, retry_task_id):
    # Branch callable: push the condition to XCom and return the task_id to follow
    ctx = get_current_context()
    cond = evaluate_func(result)
    ctx["ti"].xcom_push(key="condition", value=cond)
    return success_task_id if cond else retry_task_id


@task
def init_loop():
    # Runs only once (it is never cleared); its return value is the loop counter
    return 1


@task
def update_loop():
    # Increment the counter by overwriting init_loop's return_value XCom
    ctx = get_current_context()
    loop_num = ctx["ti"].xcom_pull(task_ids="init_loop", key="return_value")
    print("loop_num: ", loop_num)
    XCom.set(
        key="return_value",
        value=loop_num + 1,
        task_id="init_loop",
        dag_id=DAG_NAME,
        execution_date=ctx["execution_date"],
    )


@task
def print_loop_num(loop_num):
    # Re-pulls init_loop's (possibly updated) XCom every time it is re-run
    print(loop_num)
    return loop_num


@task
def opts_in_loop():
    pass


@task
def report():
    pass


def get_xcom_from_evaluate():
    # Read the condition pushed by "evaluate". Fail on purpose when it is not met,
    # so that on_failure_callback clears the upstream tasks and the loop runs again.
    ctx = get_current_context()
    cond = ctx["ti"].xcom_pull(task_ids="evaluate", key="condition")
    if cond:
        return True
    raise ValueError("Loop condition not met; clearing upstream tasks for the next loop")


@task
def success():
    pass


def clear_upstream_task(context):
    # Clear every task of this DAG run except init_loop, so the loop body runs again
    execution_date = context["execution_date"]
    context["dag"].clear(
        exclude_task_ids=["init_loop"],
        start_date=execution_date,
        end_date=execution_date,
    )
    print("successfully cleared task instances")
    return True


with DAG(
    dag_id=DAG_NAME,
    default_args=default_parameters,
    schedule_interval=None,
    start_date=days_ago(2),
) as dag:
    some_evaluate = BranchPythonOperator(
        task_id="evaluate",
        python_callable=do_evaluate,
        op_args=[dummy_evaluate, "blabla", "success", "update_loop"],
    )
    clean = task(
        get_xcom_from_evaluate,
        task_id="run_next_loop",
        trigger_rule="none_failed",
        on_failure_callback=clear_upstream_task,
    )()
    loop_num = init_loop()
    print_loop_num(loop_num) >> opts_in_loop() >> some_evaluate
    some_evaluate >> Label("success") >> success() >> clean
    some_evaluate >> Label("failed") >> update_loop() >> clean
    clean >> report()
```
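The loop in this DAG comes from task state rather than from a graph edge: init_loop runs once, and every other task is cleared and re-run each time run_next_loop fails. As a rough model (plain Python, not Airflow code), the simulation below replays that control flow; `simulate_clear_loop`, the `is_done` stop condition, and the `max_loops` guard are hypothetical, and the real DAG has no such guard.

```python
from typing import Callable, Dict, List


def simulate_clear_loop(is_done: Callable[[int], bool], max_loops: int = 100) -> List[str]:
    """Model the clear-based loop DAG in plain Python.

    Returns the names of the tasks that would run, in order.
    """
    xcom: Dict[str, int] = {}  # stands in for the XCom table
    ran: List[str] = []

    # init_loop runs once; it is excluded from dag.clear(), so its XCom survives
    xcom["init_loop"] = 1
    ran.append("init_loop")

    for _ in range(max_loops):  # safety bound only; the real DAG has none
        loop_num = xcom["init_loop"]
        ran += [f"print_loop_num({loop_num})", "opts_in_loop", "evaluate"]

        if is_done(loop_num):
            # evaluate branched to "success": run_next_loop succeeds, report runs
            ran += ["success", "run_next_loop", "report"]
            return ran

        # evaluate branched to "update_loop", which overwrites init_loop's XCom
        xcom["init_loop"] = loop_num + 1
        ran += ["update_loop", "run_next_loop(failed)"]
        # on_failure_callback: every task except init_loop is cleared and re-run
        ran.append("clear")

    raise RuntimeError(f"loop did not finish within {max_loops} iterations")


print(simulate_clear_loop(lambda n: n >= 2))
```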
For future reference, for anyone who wants to implement a looping condition in Airflow, here's a possible implementation:

```python
import abc
from typing import Any, Generic, Mapping, Optional, TypeVar

from airflow.utils.operator_helpers import KeywordParameters

T = TypeVar("T")


class AbstractLoop(abc.ABC, Generic[T]):
    """
    Abstract class to execute as a python_callable using a PythonOperator.

    Runs the `run` method until the `condition` method returns False.

    Example:
        PythonOperator(
            task_id="task_looper",
            python_callable=Loop(  # an instance of a concrete AbstractLoop subclass
                ...args
            ),
            op_kwargs={
                ...op_kwargs
            },
        )
    """

    def __call__(self, **context) -> Optional[T]:
        # Pass each method only the context keys it declares in its signature
        condition_kwargs = self._determine_kwargs(self.condition, context)
        run_kwargs = self._determine_kwargs(self.run, context)
        result = None
        while self.condition(previous=result, **condition_kwargs):
            result = self.run(previous=result, **run_kwargs)
        return result

    @abc.abstractmethod
    def condition(self, previous: Optional[T], **context) -> bool:
        pass

    @abc.abstractmethod
    def run(self, previous: Optional[T], **context) -> T:
        pass

    def _determine_kwargs(self, fn, context: Mapping[str, Any]) -> Mapping[str, Any]:
        return KeywordParameters.determine(fn, (), context).unpacking()
```

Here is an example of how this can be used. My specific use case is to backfill a dbt snapshot given a set of daily captures:

```python
import datetime as dt
from typing import Dict, Optional, Union

from airflow.models import XCom
from airflow.models.dagrun import DagRun
from airflow.operators.python import get_current_context
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.utils import timezone
from airflow.utils.types import DagRunType


def _to_date(value: Union[str, dt.date, None]) -> Optional[dt.date]:
    # dag_run.conf and XCom values may hold "YYYY-MM-DD" strings; normalise to dates
    if value is None:
        return None
    if isinstance(value, dt.datetime):
        return value.date()
    if isinstance(value, dt.date):
        return value
    return dt.datetime.strptime(value, "%Y-%m-%d").date()


class DateIntervalTriggerDagRunLoop(AbstractLoop[dt.date]):
    """
    Triggers a DagRun for a given dag_id for each date in a given interval.

    Args:
        trigger_dag_id (str): Identifier of the Dag to trigger.
        min_date_conf_key (str): DagRun configuration key for the lower bound of the date interval.
            Defaults to "min_date".
        max_date_conf_key (str): DagRun configuration key for the upper bound of the date interval.
            Defaults to "max_date".
        date_task_id (str, optional): Task id of a task within the Dag to trigger that pushes the date
            of the current iteration to XCom.
            Optional only if dag_run.conf[min_date_conf_key] is not None.
        date_xcom_key (str, optional): XCom key used to push the date of the current iteration.
            Optional only if dag_run.conf[min_date_conf_key] is not None.

    PythonOperator op_kwargs:
        conf (Dict[str, str], optional): Configuration to pass to the DagRuns to trigger.
    """

    def __init__(
        self,
        trigger_dag_id: str,
        min_date_conf_key: Optional[str] = "min_date",
        max_date_conf_key: Optional[str] = "max_date",
        date_task_id: Optional[str] = None,
        date_xcom_key: Optional[str] = None,
    ):
        self.trigger_dag_id = trigger_dag_id
        self.min_date_conf_key = min_date_conf_key
        self.max_date_conf_key = max_date_conf_key
        self.date_task_id = date_task_id
        self.date_xcom_key = date_xcom_key

    def condition(self, previous: Optional[dt.date], dag_run: DagRun) -> bool:
        # If previous is None, this is the first iteration. Keep looping.
        if previous is None:
            return True
        # Else, this is a subsequent iteration. Determine the upper bound,
        # defaulting to dag_run.execution_date to avoid looping into the future.
        run_conf = dag_run.conf or {}
        max_date = (
            dag_run.execution_date.date()
            if run_conf.get(self.max_date_conf_key) is None
            else _to_date(run_conf[self.max_date_conf_key])
        )
        # The next iteration runs for previous + 1 day, so keep looping only
        # while that date is still within the (inclusive) upper bound.
        return previous < max_date

    def run(
        self,
        previous: Optional[dt.date],
        dag_run: DagRun,
        conf: Dict[str, str],
    ) -> Optional[dt.date]:
        # Get the date from the previous iteration if possible, else from the minimum
        # date as specified in dag_run.conf.
        date = (
            _to_date((dag_run.conf or {}).get(self.min_date_conf_key))
            if previous is None
            else previous + dt.timedelta(days=1)
        )
        # Prepare the conf for the DagRun to trigger
        trigger_dag_conf = {
            **conf,
            "date": None if date is None else date.strftime("%Y-%m-%d"),
        }
        # Create a trigger_run_id - implementation from airflow/operators/trigger_dagrun.py
        trigger_run_id = DagRun.generate_run_id(DagRunType.MANUAL, timezone.utcnow())
        # Create a unique task_id from trigger_run_id
        task_id = f"{self.trigger_dag_id}_{abs(hash(trigger_run_id)):x}"
        # Trigger the DagRun and wait for completion. execute() needs the full task
        # context, which get_current_context() provides inside a PythonOperator.
        TriggerDagRunOperator(
            task_id=task_id,
            trigger_dag_id=self.trigger_dag_id,
            trigger_run_id=trigger_run_id,
            conf=trigger_dag_conf,
            failed_states=["failed"],
            poke_interval=30,
            wait_for_completion=True,
        ).execute(get_current_context())
        # If the date is None (no previous run AND no minimum date specified), read the
        # iteration date from the XCom pushed by a task of the triggered DagRun.
        if date is None:
            assert self.date_task_id
            assert self.date_xcom_key
            date = _to_date(
                XCom.get_one(
                    run_id=trigger_run_id,
                    dag_id=self.trigger_dag_id,
                    task_id=self.date_task_id,
                    key=self.date_xcom_key,
                )
            )
        return date
```

```python
from airflow.operators.python import PythonOperator

PythonOperator(
    task_id="snapshot_backfill_my_table",
    python_callable=DateIntervalTriggerDagRunLoop(
        trigger_dag_id="snapshot-backfill",
        date_task_id="determine_latest_snapshot",
        date_xcom_key="date",
        min_date_conf_key="min_date",
        max_date_conf_key="max_date",
    ),
    op_kwargs={
        "conf": {
            "table_id": "{{dag_run.conf.project}}.{{dag_run.conf.client_id}}.stg_my_table",
            "snapshot_table_id": "{{dag_run.conf.project}}.{{dag_run.conf.client_id}}.snapshot_my_table",
        }
    },
)
```

For @appassionate's case, a similar class could be created where the loop condition is a threshold (or set of thresholds) on the model's performance.
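As a sketch of that suggestion, here is the same condition/run pattern with a performance threshold as the stop condition. To keep it runnable without Airflow it is a standard-library re-implementation, not the code above: Airflow's `KeywordParameters` is swapped for a small `inspect.signature` filter that passes each method only the context keys it names. `ThresholdLoop`, the `train` callable, `target`, `max_iterations`, and the scores are all hypothetical.

```python
import abc
import inspect
from typing import Any, Callable, Dict, Generic, Mapping, Optional, Tuple, TypeVar

T = TypeVar("T")


class AbstractLoop(abc.ABC, Generic[T]):
    """Stdlib-only sketch: call run() while condition() returns True."""

    def __call__(self, **context: Any) -> Optional[T]:
        result: Optional[T] = None
        while self.condition(previous=result, **self._select(self.condition, context)):
            result = self.run(previous=result, **self._select(self.run, context))
        return result

    @abc.abstractmethod
    def condition(self, previous: Optional[T], **context: Any) -> bool: ...

    @abc.abstractmethod
    def run(self, previous: Optional[T], **context: Any) -> T: ...

    @staticmethod
    def _select(fn: Callable[..., Any], context: Mapping[str, Any]) -> Dict[str, Any]:
        # Keep only the context keys fn names explicitly, or all of them if fn
        # takes **kwargs (a simplified stand-in for Airflow's KeywordParameters)
        params = inspect.signature(fn).parameters
        takes_all = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
        return {
            key: value
            for key, value in context.items()
            if key != "previous" and (takes_all or key in params)
        }


Attempt = Tuple[int, float]  # (iteration number, score)


class ThresholdLoop(AbstractLoop[Attempt]):
    """Hypothetical loop: retrain until the score reaches `target` or the budget runs out."""

    def __init__(self, train: Callable[[int], float], max_iterations: int = 10):
        self.train = train  # returns the model score for a given iteration number
        self.max_iterations = max_iterations

    def condition(self, previous: Optional[Attempt], target: float) -> bool:
        if previous is None:
            return True  # nothing has run yet: always do the first iteration
        iteration, score = previous
        # Keep looping while the score is below target and the budget is not spent
        return score < target and iteration < self.max_iterations

    def run(self, previous: Optional[Attempt]) -> Attempt:
        iteration = 1 if previous is None else previous[0] + 1
        return iteration, self.train(iteration)


scores = [0.62, 0.78, 0.91, 0.95]  # made-up scores, for illustration only
loop = ThresholdLoop(train=lambda i: scores[i - 1])
print(loop(target=0.9))  # (3, 0.91)
```

Returning the iteration number together with the score keeps the loop stateless, so the same callable instance can be reused across task runs.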