Skip to content

Commit

Permalink
fixup! AIP-72: Allow pushing and pulling XCom from Task Context
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Dec 20, 2024
1 parent 13a742c commit 3a60948
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
7 changes: 0 additions & 7 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import collections.abc
import contextlib
import copy
import functools
import logging
from collections.abc import Collection, Iterable, Sequence
Expand Down Expand Up @@ -703,12 +702,6 @@ def get_outlet_defs(self):
extended/overridden by subclasses.
"""

def prepare_for_execution(self) -> BaseOperator:
"""Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
other = copy.copy(self)
other._lock_for_execution = True
return other

@prepare_lineage
def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
"""
if not isinstance(source, Context):
# Sometimes we are passed a plain dict (usually in tests, or in User's
# custom operators) -- be lienent about what we accept so we don't
# custom operators) -- be lenient about what we accept so we don't
# break anything for users.
return source

Expand Down
10 changes: 7 additions & 3 deletions task_sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,9 +657,7 @@ def __setattr__(self: BaseOperator, key: str, value: Any):
# an operator, example:
# op = BashOperator()
# op.bash_command = "sleep 1"
# self._set_xcomargs_dependency(key, value)
# TODO: The above line raises recursion error, so we need to find a way to resolve this.
...
self._set_xcomargs_dependency(key, value)

def __init__(
self,
Expand Down Expand Up @@ -1221,6 +1219,12 @@ def get_serialized_fields(cls):

return cls.__serialized_fields

def prepare_for_execution(self) -> BaseOperator:
"""Lock task for execution to disable custom action in ``__setattr__`` and return a copy."""
other = copy.copy(self)
other._lock_for_execution = True
return other

def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Serialize; required by DAGNode."""
from airflow.serialization.enums import DagAttributeTypes
Expand Down
7 changes: 6 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class RuntimeTaskInstance(TaskInstance):
"""The Task Instance context from the API server, if any."""

def get_template_context(self):
from airflow.utils.context import Context
# TODO: Move this to `airflow.sdk.execution_time.context`
# once we port the entire context logic from airflow/utils/context.py ?

Expand Down Expand Up @@ -114,7 +115,10 @@ def get_template_context(self):
"ts_nodash_with_tz": ts_nodash_with_tz,
}
context.update(context_from_server)
return context

# Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
# is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
return Context(context) # type: ignore

def xcom_pull(
self,
Expand Down Expand Up @@ -313,6 +317,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
ti.task.execute(context) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
Expand Down

0 comments on commit 3a60948

Please sign in to comment.