Skip to content

Commit

Permalink
AIP-72: Allow pushing and pulling XCom from Task Context
Browse files Browse the repository at this point in the history
Part of #44481
  • Loading branch information
kaxil committed Dec 20, 2024
1 parent 2723508 commit d44c9f1
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 8 deletions.
2 changes: 1 addition & 1 deletion airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
"""
if not isinstance(source, Context):
# Sometimes we are passed a plain dict (usually in tests, or in User's
# custom operators) -- be lienent about what we accept so we don't
# custom operators) -- be lenient about what we accept so we don't
# break anything for users.
return source

Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class GetXCom(BaseModel):
dag_id: str
run_id: str
task_id: str
map_index: int = -1
map_index: int | None = -1
type: Literal["GetXCom"] = "GetXCom"


Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -719,7 +718,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
else:
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
Expand All @@ -729,6 +728,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
xcom_result = XComResult.from_xcom_response(xcom)
resp = xcom_result.model_dump_json().encode()
log.info("XCom value response", resp=resp, xcom=xcom)
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.id, msg)
Expand Down
52 changes: 49 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
RescheduleTask,
SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
ToSupervisor,
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

Expand All @@ -54,6 +57,7 @@ class RuntimeTaskInstance(TaskInstance):
"""The Task Instance context from the API server, if any."""

def get_template_context(self):
from airflow.utils.context import Context
# TODO: Move this to `airflow.sdk.execution_time.context`
# once we port the entire context logic from airflow/utils/context.py ?

Expand Down Expand Up @@ -111,11 +115,53 @@ def get_template_context(self):
"ts_nodash_with_tz": ts_nodash_with_tz,
}
context.update(context_from_server)
return context

def xcom_pull(self, *args, **kwargs): ...
# Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
# is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
return Context(context) # type: ignore

def xcom_pull(
self,
task_ids: str | None = None, # TODO: Simplify to a single task_id (breaking change)
dag_id: str | None = None,
key: str = "return_value",
include_prior_dates: bool = False,
*,
map_index: int | None = None,
default: Any = None,
run_id: str | None = None,
) -> Any:
"""Pull XComs from the execution context."""
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id or self.dag_id,
task_id=task_ids or self.task_id,
run_id=run_id or self.run_id,
map_index=map_index or self.map_index,
),
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
isinstance(msg, XComResult)
log.info("The value is ", xcom=msg.value)
return msg.value or default

def xcom_push(self, *args, **kwargs): ...
def xcom_push(self, key: str, value: Any):
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=SetXCom(
key=key,
value=value,
dag_id=self.dag_id,
task_id=self.task_id,
run_id=self.run_id,
),
)


def parse(what: StartupDetails) -> RuntimeTaskInstance:
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_getattr_connection(self):
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; Triggers __getattr__
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
Expand Down

0 comments on commit d44c9f1

Please sign in to comment.