diff --git a/airflow/utils/context.py b/airflow/utils/context.py index c6cf2db498532..28bcd2fe6701d 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -477,7 +477,7 @@ def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: """ if not isinstance(source, Context): # Sometimes we are passed a plain dict (usually in tests, or in User's - # custom operators) -- be lienent about what we accept so we don't + # custom operators) -- be lenient about what we accept so we don't # break anything for users. return source diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 97fadcafc409e..3912cddb211ec 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -167,7 +167,7 @@ class GetXCom(BaseModel): dag_id: str run_id: str task_id: str - map_index: int = -1 + map_index: int | None = -1 type: Literal["GetXCom"] = "GetXCom" diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 73bc446a28df8..d6d368f2df6b6 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -62,7 +62,6 @@ from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, - ErrorResponse, GetConnection, GetVariable, GetXCom, @@ -719,7 +718,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): if isinstance(conn, ConnectionResponse): conn_result = ConnectionResult.from_conn_response(conn) resp = conn_result.model_dump_json(exclude_unset=True).encode() - elif isinstance(conn, ErrorResponse): + else: resp = conn.model_dump_json().encode() elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) @@ -729,6 +728,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) xcom_result = XComResult.from_xcom_response(xcom) resp = xcom_result.model_dump_json().encode() + log.info("XCom value response", resp=resp, xcom=xcom) elif isinstance(msg, DeferTask): self._terminal_state = IntermediateTIState.DEFERRED self.client.task_instances.defer(self.id, msg) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 48dd3ecbfcd67..8369ff03e20bd 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -33,12 +33,15 @@ from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.execution_time.comms import ( DeferTask, + GetXCom, RescheduleTask, SetRenderedFields, + SetXCom, StartupDetails, TaskState, ToSupervisor, ToTask, + XComResult, ) from airflow.sdk.execution_time.context import ConnectionAccessor @@ -54,6 +57,7 @@ class RuntimeTaskInstance(TaskInstance): """The Task Instance context from the API server, if any.""" def get_template_context(self): + from airflow.utils.context import Context # TODO: Move this to `airflow.sdk.execution_time.context` # once we port the entire context logic from airflow/utils/context.py ? @@ -111,11 +115,53 @@ def get_template_context(self): "ts_nodash_with_tz": ts_nodash_with_tz, } context.update(context_from_server) - return context - def xcom_pull(self, *args, **kwargs): ... + # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it + # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890 + return Context(context) # type: ignore + + def xcom_pull( + self, + task_ids: str | None = None, # TODO: Simplify to a single task_id (breaking change) + dag_id: str | None = None, + key: str = "return_value", + include_prior_dates: bool = False, + *, + map_index: int | None = None, + default: Any = None, + run_id: str | None = None, + ) -> Any: + """Pull XComs from the execution context.""" + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id or self.dag_id, + task_id=task_ids or self.task_id, + run_id=run_id or self.run_id, + map_index=map_index or self.map_index, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if TYPE_CHECKING: + isinstance(msg, XComResult) + log.info("The value is ", xcom=msg.value) + return msg.value or default - def xcom_push(self, *args, **kwargs): ... + def xcom_push(self, key: str, value: Any): + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=self.dag_id, + task_id=self.task_id, + run_id=self.run_id, + ), + ) def parse(what: StartupDetails) -> RuntimeTaskInstance: diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index 65d2b50f8a17f..a3220c3bef1e3 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -67,7 +67,7 @@ def test_getattr_connection(self): ) as mock_supervisor_comms: mock_supervisor_comms.get_message.return_value = conn_result - # Fetch the connection; Triggers __getattr__ + # Fetch the connection; triggers __getattr__ conn = accessor.mysql_conn expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)