Skip to content

Commit

Permalink
AIP-72: Allow pushing and pulling XCom from Task Context
Browse files Browse the repository at this point in the history
Part of #44481
  • Loading branch information
kaxil committed Dec 20, 2024
1 parent b6e3d1c commit ed34297
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 10 deletions.
4 changes: 3 additions & 1 deletion task_sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,9 @@ def __setattr__(self: BaseOperator, key: str, value: Any):
# an operator, example:
# op = BashOperator()
# op.bash_command = "sleep 1"
self._set_xcomargs_dependency(key, value)
# self._set_xcomargs_dependency(key, value)
# TODO: The line above raises a RecursionError, so we need to find a way to resolve this.
...

def __init__(
self,
Expand Down
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ class XComResult(XComResponse):

type: Literal["XComResult"] = "XComResult"

@classmethod
def from_xcom_response(cls, xcom_response: XComResponse) -> XComResult:
    """
    Build an instance of this class from the fields of ``xcom_response``.

    :param xcom_response: Response whose dumped fields become the keyword
        arguments used to construct ``cls``.
    :return: A new instance of ``cls``.
    """
    # Every field is dumped as-is: no ``exclude_*`` filtering is applied here.
    dumped_fields = xcom_response.model_dump()
    return cls(**dumped_fields)


class ConnectionResult(ConnectionResponse):
type: Literal["ConnectionResult"] = "ConnectionResult"
Expand Down Expand Up @@ -141,7 +148,7 @@ class GetXCom(BaseModel):
dag_id: str
run_id: str
task_id: str
map_index: int = -1
map_index: int | None = -1
type: Literal["GetXCom"] = "GetXCom"


Expand Down
8 changes: 5 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand All @@ -72,6 +71,7 @@
StartupDetails,
TaskState,
ToSupervisor,
XComResult,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -715,14 +715,16 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
else:
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
resp = var.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
resp = xcom.model_dump_json(exclude_unset=True).encode()
xcom_result = XComResult.from_xcom_response(xcom)
resp = xcom_result.model_dump_json().encode()
log.info("XCom value response", resp=resp, xcom=xcom)
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.id, msg)
Expand Down
46 changes: 44 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
RescheduleTask,
SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
ToSupervisor,
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

Expand Down Expand Up @@ -113,9 +116,48 @@ def get_template_context(self):
context.update(context_from_server)
return context

def xcom_pull(self, *args, **kwargs): ...
def xcom_pull(
    self,
    task_ids: str | None = None,  # TODO: Simplify to a single task_id (breaking change)
    dag_id: str | None = None,
    key: str = "return_value",
    include_prior_dates: bool = False,
    *,
    map_index: int | None = None,
    default: Any = None,
    run_id: str | None = None,
) -> Any:
    """
    Pull an XCom value by requesting it from the supervisor.

    Sends a ``GetXCom`` request through ``SUPERVISOR_COMMS`` and then reads
    the next message from it as the reply.

    :param task_ids: Task to pull from; falls back to ``self.task_id`` when falsy.
    :param dag_id: DAG to pull from; falls back to ``self.dag_id`` when falsy.
    :param key: XCom key to request.
    :param include_prior_dates: Currently unused by this implementation.
    :param map_index: Map index to pull from. ``None`` falls back to
        ``self.map_index``; an explicit ``0`` is honoured.
    :param default: Value returned when the reply's ``value`` is ``None``.
    :param run_id: Run to pull from; falls back to ``self.run_id`` when falsy.
    :return: The reply's ``value``, or ``default`` if that value is ``None``.
    """
    log = structlog.get_logger(logger_name="task")

    # Compare against None instead of using ``or`` so that an explicit
    # ``map_index=0`` is not silently replaced by this task's own map index.
    requested_map_index = self.map_index if map_index is None else map_index

    SUPERVISOR_COMMS.send_request(
        log=log,
        msg=GetXCom(
            key=key,
            dag_id=dag_id or self.dag_id,
            task_id=task_ids or self.task_id,
            run_id=run_id or self.run_id,
            map_index=requested_map_index,
        ),
    )

    msg = SUPERVISOR_COMMS.get_message()
    if TYPE_CHECKING:
        # Narrow the type for static checkers only; this never runs at runtime.
        assert isinstance(msg, XComResult)
    log.info("The value is ", xcom=msg.value)

    # Only a missing value is replaced by the default: falsy values such as
    # 0, "", [] or False are legitimate XCom payloads and must be returned.
    value = msg.value
    return default if value is None else value

def xcom_push(self, *args, **kwargs): ...
def xcom_push(self, key: str, value: Any):
    """
    Send an XCom value for this task instance to the supervisor.

    Builds a ``SetXCom`` message from ``key``, ``value`` and this instance's
    ``dag_id``, ``task_id`` and ``run_id``, then hands it to
    ``SUPERVISOR_COMMS.send_request``. No reply is read here.

    :param key: XCom key to store the value under.
    :param value: Value to send.
    """
    task_log = structlog.get_logger(logger_name="task")
    send = SUPERVISOR_COMMS.send_request
    payload = SetXCom(
        key=key,
        value=value,
        dag_id=self.dag_id,
        task_id=self.task_id,
        run_id=self.run_id,
    )
    send(log=task_log, msg=payload)


def parse(what: StartupDetails) -> RuntimeTaskInstance:
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_getattr_connection(self):
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; Triggers __getattr__
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def watched_subprocess(self, mocker):
),
pytest.param(
GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"),
b'{"key":"test_key","value":"test_value"}\n',
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", -1),
XComResult(key="test_key", value="test_value"),
Expand All @@ -823,7 +823,7 @@ def watched_subprocess(self, mocker):
GetXCom(
dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key", map_index=2
),
b'{"key":"test_key","value":"test_value"}\n',
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
("test_dag", "test_run", "test_task", "test_key", 2),
XComResult(key="test_key", value="test_value"),
Expand Down

0 comments on commit ed34297

Please sign in to comment.