Skip to content

Commit

Permalink
AIP-72: Fix recursion bug with XComArg (#45112)
Browse files Browse the repository at this point in the history
It fixes the following bug:

```python
{"timestamp":"2024-12-20T10:38:56.890735","logger":"task","error_detail":
[{"exc_type":"RecursionError","exc_value":"maximum recursion depth exceeded in comparison","syntax_error":null,"is_cause":false,"frames":
[
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/execution_time/task_runner.py","lineno":382,"name":"main"},
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/execution_time/task_runner.py","lineno":317,"name":"run"},
	{"filename":"/opt/airflow/airflow/models/baseoperator.py","lineno":378,"name":"wrapper"},
	{"filename":"/opt/airflow/providers/src/airflow/providers/standard/operators/python.py","lineno":182,"name":"execute"},
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/definitions/baseoperator.py","lineno":660,"name":"__setattr__"},
	{"filename":"/opt/airflow/task_sdk/src/airflow/sdk/definitions/baseoperator.py","lineno":1126,"name":"_set_xcomargs_dependency"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":132,"name":"apply_upstream_relationship"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":118,"name":"iter_xcom_references"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":121,"name":"iter_xcom_references"},
	{"filename":"/opt/airflow/airflow/models/xcom_arg.py","lineno":118,"name":"iter_xcom_references"},
	...
```

To reproduce, just run `tutorial_dag` or the following minimal DAG:

```python
# Minimal DAG used to reproduce the RecursionError described in this commit:
# a single PythonOperator task whose callable pushes one XCom value.
import pendulum

from airflow.models.dag import DAG
from airflow.providers.standard.operators.python import PythonOperator

# schedule=None and catchup=False: no schedule is configured and catchup is
# disabled, so the DAG is presumably only run when triggered manually.
with DAG(
    "sdk_tutorial_dag",
    schedule=None,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    catchup=False,
    tags=["example"],
) as dag:
    # Reuse this module's docstring as the DAG documentation (None if the
    # module has no docstring).
    dag.doc_md = __doc__

    def extract(**kwargs):
        """Push a hard-coded JSON string to XCom under the key ``order_data``."""
        # "ti" is read from the keyword arguments passed to the callable --
        # presumably the task instance from the task context.
        ti = kwargs["ti"]
        data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
        ti.xcom_push("order_data", data_string)

    extract_task = PythonOperator(
        task_id="extract",
        python_callable=extract,
    )

    # Bare expression with no effect of its own; the operator was already
    # created inside the ``with DAG(...)`` block above.
    extract_task
```

I need this fix for #45075 (part of getting [Task Context working with AIP-72](#44481))
  • Loading branch information
kaxil authored Dec 20, 2024
1 parent 7fbe16a commit 0465cdd
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 19 deletions.
7 changes: 0 additions & 7 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import collections.abc
import contextlib
import copy
import functools
import logging
from collections.abc import Collection, Iterable, Sequence
Expand Down Expand Up @@ -703,12 +702,6 @@ def get_outlet_defs(self):
extended/overridden by subclasses.
"""

def prepare_for_execution(self) -> BaseOperator:
"""Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
other = copy.copy(self)
other._lock_for_execution = True
return other

@prepare_lineage
def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
Expand Down
3 changes: 2 additions & 1 deletion airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.typing_compat import Self
from airflow.utils.context import Context

# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
Expand Down Expand Up @@ -408,7 +409,7 @@ def _get_next_poke_interval(
self.log.info("new %s interval is %s", self.mode, new_interval)
return new_interval

def prepare_for_execution(self) -> BaseOperator:
def prepare_for_execution(self) -> Self:
task = super().prepare_for_execution()

# Sensors in `poke` mode can block execution of DAGs when running
Expand Down
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.serialization.enums import DagAttributeTypes
from airflow.typing_compat import Self
from airflow.utils.operator_resources import Resources

# TODO: Task-SDK
Expand Down Expand Up @@ -974,7 +975,7 @@ def __deepcopy__(self, memo: dict[int, Any]):

def __getstate__(self):
state = dict(self.__dict__)
if self._log:
if "_log" in state:
del state["_log"]

return state
Expand Down Expand Up @@ -1219,6 +1220,12 @@ def get_serialized_fields(cls):

return cls.__serialized_fields

def prepare_for_execution(self) -> Self:
"""Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
other = copy.copy(self)
other._lock_for_execution = True
return other

def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Serialize; required by DAGNode."""
from airflow.serialization.enums import DagAttributeTypes
Expand Down
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
ti.task.execute(context) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
Expand Down
12 changes: 3 additions & 9 deletions task_sdk/tests/defintions/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,12 @@ def test_warnings_are_properly_propagated(self):
assert warning.filename == __file__

def test_setattr_performs_no_custom_action_at_execute_time(self, spy_agency):
from airflow.models.xcom_arg import XComArg

op = MockOperator(task_id="test_task")

op._lock_for_execution = True
# TODO: Task-SDK
# op_copy = op.prepare_for_execution()
op_copy = op

spy_agency.spy_on(XComArg.apply_upstream_relationship, call_original=False)
op_copy = op.prepare_for_execution()
spy_agency.spy_on(op._set_xcomargs_dependency, call_original=False)
op_copy.arg1 = "b"
assert XComArg.apply_upstream_relationship.called is False
assert op._set_xcomargs_dependency.called is False

def test_upstream_is_set_when_template_field_is_xcomarg(self):
with DAG("xcomargs_test", schedule=None):
Expand Down
10 changes: 9 additions & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_parse(test_dags_dir: Path, make_ti_context):
assert isinstance(ti.task.dag, DAG)


def test_run_basic(time_machine, mocked_parse, make_ti_context):
def test_run_basic(time_machine, mocked_parse, make_ti_context, spy_agency):
"""Test running a basic task."""
what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
Expand All @@ -163,8 +163,16 @@ def test_run_basic(time_machine, mocked_parse, make_ti_context):
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
ti = mocked_parse(what, "super_basic_run", CustomOperator(task_id="hello"))

# Ensure that task is locked for execution
spy_agency.spy_on(ti.task.prepare_for_execution)
assert not ti.task._lock_for_execution

run(ti, log=mock.MagicMock())

spy_agency.assert_spy_called(ti.task.prepare_for_execution)
assert ti.task._lock_for_execution

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
)
Expand Down

0 comments on commit 0465cdd

Please sign in to comment.