Skip to content

Commit 11fe89d

Browse files
committed
fixup! AIP-72: Fix recursion bug with XComArg
1 parent 1cf9ec3 commit 11fe89d

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

airflow/sensors/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
if TYPE_CHECKING:
5252
from sqlalchemy.orm.session import Session
5353

54+
from airflow.typing_compat import Self
5455
from airflow.utils.context import Context
5556

5657
# As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
@@ -408,7 +409,7 @@ def _get_next_poke_interval(
408409
self.log.info("new %s interval is %s", self.mode, new_interval)
409410
return new_interval
410411

411-
def prepare_for_execution(self) -> BaseOperator:
412+
def prepare_for_execution(self) -> Self:
412413
task = super().prepare_for_execution()
413414

414415
# Sensors in `poke` mode can block execution of DAGs when running

task_sdk/src/airflow/sdk/definitions/baseoperator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from airflow.sdk.definitions.dag import DAG
7070
from airflow.sdk.definitions.taskgroup import TaskGroup
7171
from airflow.serialization.enums import DagAttributeTypes
72+
from airflow.typing_compat import Self
7273
from airflow.utils.operator_resources import Resources
7374

7475
# TODO: Task-SDK
@@ -1219,7 +1220,7 @@ def get_serialized_fields(cls):
12191220

12201221
return cls.__serialized_fields
12211222

1222-
def prepare_for_execution(self) -> BaseOperator:
1223+
def prepare_for_execution(self) -> Self:
12231224
"""Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
12241225
other = copy.copy(self)
12251226
other._lock_for_execution = True

task_sdk/tests/execution_time/test_task_runner.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_parse(test_dags_dir: Path, make_ti_context):
147147
assert isinstance(ti.task.dag, DAG)
148148

149149

150-
def test_run_basic(time_machine, mocked_parse, make_ti_context):
150+
def test_run_basic(time_machine, mocked_parse, make_ti_context, spy_agency):
151151
"""Test running a basic task."""
152152
what = StartupDetails(
153153
ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1),
@@ -163,8 +163,16 @@ def test_run_basic(time_machine, mocked_parse, make_ti_context):
163163
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
164164
) as mock_supervisor_comms:
165165
ti = mocked_parse(what, "super_basic_run", CustomOperator(task_id="hello"))
166+
167+
# Ensure that task is locked for execution
168+
spy_agency.spy_on(ti.task.prepare_for_execution)
169+
assert not ti.task._lock_for_execution
170+
166171
run(ti, log=mock.MagicMock())
167172

173+
spy_agency.assert_spy_called(ti.task.prepare_for_execution)
174+
assert ti.task._lock_for_execution
175+
168176
mock_supervisor_comms.send_request.assert_called_once_with(
169177
msg=TaskState(state=TerminalTIState.SUCCESS, end_date=instant), log=mock.ANY
170178
)

0 commit comments

Comments
 (0)