Skip to content

Commit

Permalink
AIP-72: Allow pushing and pulling XCom from Task Context
Browse files Browse the repository at this point in the history
Part of #44481
  • Loading branch information
kaxil committed Dec 20, 2024
1 parent 72ab6b4 commit 13a742c
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 7 deletions.
4 changes: 3 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,9 @@ def __setattr__(self: BaseOperator, key: str, value: Any):
# an operator, example:
# op = BashOperator()
# op.bash_command = "sleep 1"
self._set_xcomargs_dependency(key, value)
# self._set_xcomargs_dependency(key, value)
# TODO: The above line raises recursion error, so we need to find a way to resolve this.
...

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class GetXCom(BaseModel):
dag_id: str
run_id: str
task_id: str
map_index: int = -1
map_index: int | None = -1
type: Literal["GetXCom"] = "GetXCom"


Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -718,7 +717,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
else:
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
Expand All @@ -728,6 +727,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
xcom_result = XComResult.from_xcom_response(xcom)
resp = xcom_result.model_dump_json().encode()
log.info("XCom value response", resp=resp, xcom=xcom)
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.id, msg)
Expand Down
46 changes: 44 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
RescheduleTask,
SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
ToSupervisor,
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

Expand Down Expand Up @@ -113,9 +116,48 @@ def get_template_context(self):
context.update(context_from_server)
return context

def xcom_pull(self, *args, **kwargs): ...
def xcom_pull(
self,
task_ids: str | None = None, # TODO: Simplify to a single task_id (breaking change)
dag_id: str | None = None,
key: str = "return_value",
include_prior_dates: bool = False,
*,
map_index: int | None = None,
default: Any = None,
run_id: str | None = None,
) -> Any:
"""Pull XComs from the execution context."""
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id or self.dag_id,
task_id=task_ids or self.task_id,
run_id=run_id or self.run_id,
map_index=map_index or self.map_index,
),
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
isinstance(msg, XComResult)
log.info("The value is ", xcom=msg.value)
return msg.value or default

def xcom_push(self, *args, **kwargs): ...
def xcom_push(self, key: str, value: Any):
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=SetXCom(
key=key,
value=value,
dag_id=self.dag_id,
task_id=self.task_id,
run_id=self.run_id,
),
)


def parse(what: StartupDetails) -> RuntimeTaskInstance:
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_getattr_connection(self):
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; Triggers __getattr__
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
Expand Down

0 comments on commit 13a742c

Please sign in to comment.