diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 3a1545283e81b..ac3f80092a9c1 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -201,6 +201,10 @@ def ti_update_state(
 
     if isinstance(ti_patch_payload, TITerminalStatePayload):
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+        query = query.values(state=ti_patch_payload.state)
+        if ti_patch_payload.state == State.FAILED:
+            # clear the next_method and next_kwargs
+            query = query.values(next_method=None, next_kwargs=None)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 92f400d46e2bb..11341e76356d2 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -283,7 +283,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         ...
     except (AirflowFailException, AirflowSensorTimeout):
         # If AirflowFailException is raised, task should not retry.
-        ...
+        # If a sensor in reschedule mode reaches timeout, task should not retry.
+
+        # TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
+        # TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
+        msg = TaskState(
+            state=TerminalTIState.FAILED,
+            end_date=datetime.now(tz=timezone.utc),
+        )
+
+        # TODO: Run task failure callbacks here
     except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
         ...
     except SystemExit:
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index 70f9e26486408..51a31b8982fe8 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -854,6 +854,8 @@ def watched_subprocess(self, mocker):
                 {"ok": True},
                 id="set_xcom_with_map_index",
             ),
+            # We aren't adding all states under TerminalTIState here because this test's scope is only to check
+            # that it can handle a TaskState message.
            pytest.param(
                TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
                b"",
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 2b812c92a7338..35ff65414f837 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -26,7 +26,7 @@
 import pytest
 from uuid6 import uuid7
 
-from airflow.exceptions import AirflowSkipException
+from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException
 from airflow.sdk import DAG, BaseOperator
 from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
 from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
@@ -333,6 +333,55 @@ def __init__(self, *args, **kwargs):
     )
 
 
+@pytest.mark.parametrize(
+    ["dag_id", "task_id", "fail_with_exception"],
+    [
+        pytest.param(
+            "basic_failed", "fail-exception", AirflowFailException("Oops. Failing by AirflowFailException!")
+        ),
+        pytest.param(
+            "basic_failed2",
+            "sensor-timeout-exception",
+            AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
+        ),
+    ],
+)
+def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
+    """Test running a basic task that marks itself as failed by raising an exception."""
+
+    class CustomOperator(BaseOperator):
+        def __init__(self, e, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.e = e
+
+        def execute(self, context):
+            print(f"raising exception {self.e}")
+            raise self.e
+
+    task = CustomOperator(task_id=task_id, e=fail_with_exception)
+
+    what = StartupDetails(
+        ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
+        file="",
+        requests_fd=0,
+        ti_context=make_ti_context(),
+    )
+
+    ti = mocked_parse(what, dag_id, task)
+
+    instant = timezone.datetime(2024, 12, 3, 10, 0)
+    time_machine.move_to(instant, tick=False)
+
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as mock_supervisor_comms:
+        run(ti, log=mock.MagicMock())
+
+    mock_supervisor_comms.send_request.assert_called_once_with(
+        msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
+    )
+
+
 class TestRuntimeTaskInstance:
     def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
         """Test get_template_context without ti_context_from_server."""
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index e67d82a718cd6..85b6d11ee3b6f 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -455,6 +455,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
         session.refresh(ti)
         assert ti.last_heartbeat_at == time_now.add(minutes=10)
 
+    def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
+        from math import ceil
+
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_failed_table_check",
+            state=State.RUNNING,
+        )
+        ti.start_date = DEFAULT_START_DATE
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={
+                "state": State.FAILED,
+                "end_date": DEFAULT_END_DATE.isoformat(),
+            },
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == State.FAILED
+        assert ti.next_method is None
+        assert ti.next_kwargs is None
+        # TODO: remove/amend this once https://github.com/apache/airflow/pull/45002 is merged
+        assert ceil(ti.duration) == 3600.00
+
 
 class TestTIPutRTIF:
     def setup_method(self):
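
A note on the SQLAlchemy pattern used in the `ti_update_state` hunk: repeated `.values()` calls on the same `update()` construct are merged into a single UPDATE statement, which is why the conditional `next_method`/`next_kwargs` reset can simply be chained onto the query that `duration_expression_update` already returned. The sketch below is illustrative only; it uses a hypothetical table definition, not Airflow's real `TaskInstance` model.

```python
# Illustrative sketch only -- a hypothetical table, not Airflow's task_instance model.
from sqlalchemy import Column, MetaData, String, Table, update

metadata = MetaData()
ti = Table(
    "task_instance",
    metadata,
    Column("id", String, primary_key=True),
    Column("state", String),
    Column("next_method", String),
    Column("next_kwargs", String),
)

query = update(ti).where(ti.c.id == "some-ti-id")
query = query.values(state="failed")

# Chained .values() calls are merged into the same UPDATE, so this adds two more
# SET clauses instead of producing a second statement.
query = query.values(next_method=None, next_kwargs=None)

# UPDATE task_instance SET state=:state, next_method=:next_method,
# next_kwargs=:next_kwargs WHERE task_instance.id = :id_1
print(query)
```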
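For the `task_runner.py` hunk, the new test pins down the contract: when `execute()` raises `AirflowFailException` or `AirflowSensorTimeout`, the runner should report a terminal FAILED state (with the frozen `end_date`) to the supervisor exactly once. Below is a minimal sketch of that exception-to-state mapping, using the names imported in the diff (`TaskState`, `TerminalTIState`); it is not the actual `run()` implementation.

```python
# Minimal sketch, not the real task_runner.run(); it only restates the mapping
# the new except branch establishes. Names are taken from the imports in the diff.
from __future__ import annotations

from datetime import datetime, timezone

from airflow.exceptions import AirflowFailException, AirflowSensorTimeout
from airflow.sdk.api.datamodels._generated import TerminalTIState
from airflow.sdk.execution_time.comms import TaskState


def terminal_state_for(exc: BaseException) -> TaskState | None:
    """Map a non-retryable failure to the message sent back to the supervisor."""
    if isinstance(exc, (AirflowFailException, AirflowSensorTimeout)):
        # Neither exception should trigger a retry, so the task goes straight to FAILED.
        return TaskState(
            state=TerminalTIState.FAILED,
            end_date=datetime.now(tz=timezone.utc),
        )
    return None
```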
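The exact-equality assertion on `end_date` in `test_run_basic_failed` works because the `time_machine` fixture freezes the clock, so the `datetime.now(tz=timezone.utc)` call inside the runner returns the frozen instant. A standalone illustration using the time-machine library directly (rather than the pytest fixture from the diff):

```python
# Freezing the wall clock makes datetime.now() deterministic, enabling exact comparisons.
from datetime import datetime, timezone

import time_machine

instant = datetime(2024, 12, 3, 10, 0, tzinfo=timezone.utc)

with time_machine.travel(instant, tick=False):
    # Any datetime.now() call made while travelling returns the frozen time.
    assert datetime.now(tz=timezone.utc) == instant
```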