diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index d59548003697a..08839cc0bf720 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -25,7 +25,6 @@

 import collections.abc
 import contextlib
-import copy
 import functools
 import logging
 from collections.abc import Collection, Iterable, Sequence
@@ -703,12 +702,6 @@ def get_outlet_defs(self):
         extended/overridden by subclasses.
         """

-    def prepare_for_execution(self) -> BaseOperator:
-        """Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
-        other = copy.copy(self)
-        other._lock_for_execution = True
-        return other
-
     @prepare_lineage
     def pre_execute(self, context: Any):
         """Execute right before self.execute() is called."""
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index c6cf2db498532..28bcd2fe6701d 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
     """
     if not isinstance(source, Context):
         # Sometimes we are passed a plain dict (usually in tests, or in User's
-        # custom operators) -- be lienent about what we accept so we don't
+        # custom operators) -- be lenient about what we accept so we don't
         # break anything for users.
         return source

diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 774387a39b2b2..33dc414139e8f 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -657,9 +657,7 @@ def __setattr__(self: BaseOperator, key: str, value: Any):
             # an operator, example:
             #   op = BashOperator()
             #   op.bash_command = "sleep 1"
-            # self._set_xcomargs_dependency(key, value)
-            # TODO: The above line raises recursion error, so we need to find a way to resolve this.
-            ...
+            self._set_xcomargs_dependency(key, value)

     def __init__(
         self,
@@ -1221,6 +1219,12 @@ def get_serialized_fields(cls):

         return cls.__serialized_fields

+    def prepare_for_execution(self) -> BaseOperator:
+        """Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
+        other = copy.copy(self)
+        other._lock_for_execution = True
+        return other
+
     def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
         """Serialize; required by DAGNode."""
         from airflow.serialization.enums import DagAttributeTypes
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index b83f8f6cf919a..8369ff03e20bd 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -57,6 +57,7 @@ class RuntimeTaskInstance(TaskInstance):
     """The Task Instance context from the API server, if any."""

     def get_template_context(self):
+        from airflow.utils.context import Context
         # TODO: Move this to `airflow.sdk.execution_time.context`
         #   once we port the entire context logic from airflow/utils/context.py ?

@@ -114,7 +115,10 @@ def get_template_context(self):
             "ts_nodash_with_tz": ts_nodash_with_tz,
         }
         context.update(context_from_server)
-        return context
+
+        # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
+        # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
+        return Context(context)  # type: ignore
 
     def xcom_pull(
         self,
@@ -313,6 +317,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         # TODO: pre execute etc.
         # TODO next_method to support resuming from deferred
         # TODO: Get a real context object
+        ti.task = ti.task.prepare_for_execution()
         context = ti.get_template_context()
         ti.task.execute(context)  # type: ignore[attr-defined]
         msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
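
Illustrative sketch (not part of the patch above): a minimal, simplified model of the behaviour this diff restores. MiniOperator and XComArg below are hypothetical stand-ins for the real BaseOperator and XComArg classes; the point is that __setattr__ wires an upstream dependency when a template field is assigned another task's output, while prepare_for_execution() returns a locked copy so that custom __setattr__ behaviour is skipped once the task runner starts executing.

import copy


class XComArg:
    # Stand-in for Airflow's XComArg: a reference to another operator's output.
    def __init__(self, operator):
        self.operator = operator


class MiniOperator:
    # Stand-in for BaseOperator, reduced to the __setattr__ / lock interplay.
    template_fields = ("bash_command",)

    def __init__(self, task_id):
        self._lock_for_execution = False
        self.upstream_tasks = []
        self.task_id = task_id
        self.bash_command = None

    def __setattr__(self, key, value):
        object.__setattr__(self, key, value)
        if getattr(self, "_lock_for_execution", False):
            return  # no custom behaviour while locked for execution
        if key in self.template_fields and isinstance(value, XComArg):
            # Rough equivalent of _set_xcomargs_dependency: assigning another
            # task's output to a template field makes that task an upstream.
            self.upstream_tasks.append(value.operator)

    def prepare_for_execution(self):
        # Same shape as the method moved into the Task SDK BaseOperator above.
        other = copy.copy(self)
        other._lock_for_execution = True
        return other


producer = MiniOperator("producer")
consumer = MiniOperator("consumer")

# At DAG-definition time the assignment wires the dependency.
consumer.bash_command = XComArg(producer)
assert consumer.upstream_tasks == [producer]

# At run time the task runner works on a locked copy, so re-assigning the
# (rendered) template field has no dependency side effects.
locked = consumer.prepare_for_execution()
locked.bash_command = "echo rendered"
assert locked.upstream_tasks == [producer]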