diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index a467bb182291c..78eab44df83c3 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -305,8 +305,18 @@ def run(ti: RuntimeTaskInstance, log: Logger): ) # TODO: Run task failure callbacks here - except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated): + except (AirflowTaskTimeout, AirflowException): + # TODO: handle the case of up_for_retry here ... + except AirflowTaskTerminated: + # External state updates are already handled with `ti_heartbeat` and will be + # updated already be another UI API. So, these exceptions should ideally never be thrown. + # If these are thrown, we should mark the TI state as failed. + msg = TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) + # TODO: Run task failure callbacks here except SystemExit: ... except BaseException: diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 78f8058accc0b..96ac89db5cd9d 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -26,7 +26,12 @@ import pytest from uuid6 import uuid7 -from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException +from airflow.exceptions import ( + AirflowFailException, + AirflowSensorTimeout, + AirflowSkipException, + AirflowTaskTerminated, +) from airflow.sdk import DAG, BaseOperator, Connection from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( @@ -352,6 +357,11 @@ def __init__(self, *args, **kwargs): "sensor-timeout-exception", AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"), ), + pytest.param( + "basic_failed3", + "task-terminated-exception", + AirflowTaskTerminated("Oops. Failing by AirflowTaskTerminated!"), + ), ], ) def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):