diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index d59548003697a..08839cc0bf720 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -25,7 +25,6 @@ import collections.abc import contextlib -import copy import functools import logging from collections.abc import Collection, Iterable, Sequence @@ -703,12 +702,6 @@ def get_outlet_defs(self): extended/overridden by subclasses. """ - def prepare_for_execution(self) -> BaseOperator: - """Lock task for execution to disable custom action in ``__setattr__`` and return a copy.""" - other = copy.copy(self) - other._lock_for_execution = True - return other - @prepare_lineage def pre_execute(self, context: Any): """Execute right before self.execute() is called.""" diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index aaaee72a25df1..4cfd59686e7e0 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -974,7 +974,7 @@ def __deepcopy__(self, memo: dict[int, Any]): def __getstate__(self): state = dict(self.__dict__) - if self._log: + if "_log" in state: del state["_log"] return state @@ -1219,6 +1219,12 @@ def get_serialized_fields(cls): return cls.__serialized_fields + def prepare_for_execution(self) -> BaseOperator: + """Lock task for execution to disable custom action in ``__setattr__`` and return a copy.""" + other = copy.copy(self) + other._lock_for_execution = True + return other + def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Serialize; required by DAGNode.""" from airflow.serialization.enums import DagAttributeTypes diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index 427d1ee0e3efd..f9431b4fe9c3c 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -211,18 +211,12 @@ def test_warnings_are_properly_propagated(self): assert warning.filename == __file__ def test_setattr_performs_no_custom_action_at_execute_time(self, spy_agency): - from airflow.models.xcom_arg import XComArg - op = MockOperator(task_id="test_task") - op._lock_for_execution = True - # TODO: Task-SDK - # op_copy = op.prepare_for_execution() - op_copy = op - - spy_agency.spy_on(XComArg.apply_upstream_relationship, call_original=False) + op_copy = op.prepare_for_execution() + spy_agency.spy_on(op._set_xcomargs_dependency, call_original=False) op_copy.arg1 = "b" - assert XComArg.apply_upstream_relationship.called is False + assert op._set_xcomargs_dependency.called is False def test_upstream_is_set_when_template_field_is_xcomarg(self): with DAG("xcomargs_test", schedule=None):