Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Allow pushing and pulling XCom from Task Context #45075

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
"""
if not isinstance(source, Context):
# Sometimes we are passed a plain dict (usually in tests, or in User's
# custom operators) -- be lienent about what we accept so we don't
# custom operators) -- be lenient about what we accept so we don't
# break anything for users.
return source

Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class GetXCom(BaseModel):
dag_id: str
run_id: str
task_id: str
map_index: int = -1
map_index: int | None = -1
type: Literal["GetXCom"] = "GetXCom"


Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
ErrorResponse,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -719,7 +718,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
elif isinstance(conn, ErrorResponse):
else:
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
Expand All @@ -729,6 +728,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
xcom_result = XComResult.from_xcom_response(xcom)
resp = xcom_result.model_dump_json().encode()
log.info("XCom value response", resp=resp, xcom=xcom)
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.id, msg)
Expand Down
52 changes: 49 additions & 3 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
RescheduleTask,
SetRenderedFields,
SetXCom,
StartupDetails,
TaskState,
ToSupervisor,
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor

Expand All @@ -54,6 +57,7 @@ class RuntimeTaskInstance(TaskInstance):
"""The Task Instance context from the API server, if any."""

def get_template_context(self):
from airflow.utils.context import Context
# TODO: Move this to `airflow.sdk.execution_time.context`
# once we port the entire context logic from airflow/utils/context.py ?

Expand Down Expand Up @@ -111,11 +115,53 @@ def get_template_context(self):
"ts_nodash_with_tz": ts_nodash_with_tz,
}
context.update(context_from_server)
return context

def xcom_pull(self, *args, **kwargs): ...
# Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
# is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
return Context(context) # type: ignore

def xcom_pull(
self,
task_ids: str | None = None, # TODO: Simplify to a single task_id (breaking change)
dag_id: str | None = None,
key: str = "return_value",
include_prior_dates: bool = False,
*,
map_index: int | None = None,
default: Any = None,
run_id: str | None = None,
) -> Any:
"""Pull XComs from the execution context."""
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id or self.dag_id,
task_id=task_ids or self.task_id,
run_id=run_id or self.run_id,
map_index=map_index or self.map_index,
),
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
isinstance(msg, XComResult)
log.info("The value is ", xcom=msg.value)
return msg.value or default

def xcom_push(self, *args, **kwargs): ...
def xcom_push(self, key: str, value: Any):
log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=SetXCom(
key=key,
value=value,
dag_id=self.dag_id,
task_id=self.task_id,
run_id=self.run_id,
),
)


def parse(what: StartupDetails) -> RuntimeTaskInstance:
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_getattr_connection(self):
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = conn_result

# Fetch the connection; Triggers __getattr__
# Fetch the connection; triggers __getattr__
conn = accessor.mysql_conn

expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
Expand Down
Loading