From cfc2157d3cbefe7b26bbccc604e1ca3357dd4c87 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 18 Dec 2024 13:04:44 +0530 Subject: [PATCH] AIP-72: Handling `failed` TI state for `AirflowFailException` & `AirflowSensorTimeout` (#44954) related: https://github.com/apache/airflow/issues/44414 We already have (almost complete) support for handling terminal states on both the task execution side and the task SDK client side, and `failed` is one of the terminal states. This PR extends the task runner's run function to handle the cases in which a task must be failed: `AirflowFailException` and `AirflowSensorTimeout`. It is functionally very similar to #44786. As part of failing a task, several other things also need to be done, such as: - Callbacks: which will eventually be converted to teardown tasks - Retries: Handled in https://github.com/apache/airflow/issues/44351 - unmapping TIs: https://github.com/apache/airflow/issues/44351 - Handling task history: will be handled by https://github.com/apache/airflow/issues/44952 - Handling downstream tasks and non teardown tasks: will be handled by https://github.com/apache/airflow/issues/44951 ### Testing performed #### End to End with Postman 1. Run airflow with breeze and run any DAG ![image](https://github.com/user-attachments/assets/fafc89ea-4e28-4802-912b-d72bf401d94b) 2. Log in to the metadata DB and get the "id" of your task instance from the TI table ![image](https://github.com/user-attachments/assets/75440f0f-f62a-4277-a2e6-cb78bd666dd4) 3. Send a request to `fail` your task ![image](https://github.com/user-attachments/assets/5991e944-f416-4b79-9954-15f1a6ebdd79) Or using curl: ``` curl --location --request PATCH 'http://localhost:29091/execution/task-instances/0193cec2-f46b-7348-9c27-9869d835dc7b/state' \ --header 'Content-Type: application/json' \ --data '{ "state": "failed", "end_date": "2024-10-31T12:00:00Z" }' ``` 4. Refresh the Airflow UI to see that the task is in the failed state. 
![image](https://github.com/user-attachments/assets/bb866dc6-e1d6-435e-abe4-2d04c97280ad) --- .../execution_api/routes/task_instances.py | 4 ++ .../airflow/sdk/execution_time/task_runner.py | 11 +++- .../tests/execution_time/test_supervisor.py | 2 + .../tests/execution_time/test_task_runner.py | 51 ++++++++++++++++++- .../routes/test_task_instances.py | 30 +++++++++++ 5 files changed, 96 insertions(+), 2 deletions(-) diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 3a1545283e81b..ac3f80092a9c1 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -201,6 +201,10 @@ def ti_update_state( if isinstance(ti_patch_payload, TITerminalStatePayload): query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) + query = query.values(state=ti_patch_payload.state) + if ti_patch_payload.state == State.FAILED: + # clear the next_method and next_kwargs + query = query.values(next_method=None, next_kwargs=None) elif isinstance(ti_patch_payload, TIDeferredStatePayload): # Calculate timeout if it was passed timeout = None diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 92f400d46e2bb..11341e76356d2 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -283,7 +283,16 @@ def run(ti: RuntimeTaskInstance, log: Logger): ... except (AirflowFailException, AirflowSensorTimeout): # If AirflowFailException is raised, task should not retry. - ... + # If a sensor in reschedule mode reaches timeout, task should not retry. 
+ + # TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951 + # TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952 + msg = TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) + + # TODO: Run task failure callbacks here except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated): ... except SystemExit: diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 70f9e26486408..51a31b8982fe8 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -854,6 +854,8 @@ def watched_subprocess(self, mocker): {"ok": True}, id="set_xcom_with_map_index", ), + # we aren't adding all states under TerminalTIState here, because this test's scope is only to check + # if it can handle TaskState message pytest.param( TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")), b"", diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 2b812c92a7338..35ff65414f837 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -26,7 +26,7 @@ import pytest from uuid6 import uuid7 -from airflow.exceptions import AirflowSkipException +from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException from airflow.sdk import DAG, BaseOperator from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState @@ -333,6 +333,55 @@ def __init__(self, *args, **kwargs): ) +@pytest.mark.parametrize( + ["dag_id", "task_id", "fail_with_exception"], + [ + pytest.param( + "basic_failed", "fail-exception", AirflowFailException("Oops. 
Failing by AirflowFailException!") + ), + pytest.param( + "basic_failed2", + "sensor-timeout-exception", + AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"), + ), + ], +) +def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context): + """Test running a basic task that marks itself as failed by raising exception.""" + + class CustomOperator(BaseOperator): + def __init__(self, e, *args, **kwargs): + super().__init__(*args, **kwargs) + self.e = e + + def execute(self, context): + print(f"raising exception {self.e}") + raise self.e + + task = CustomOperator(task_id=task_id, e=fail_with_exception) + + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + + ti = mocked_parse(what, dag_id, task) + + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + run(ti, log=mock.MagicMock()) + + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY + ) + + class TestRuntimeTaskInstance: def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context): """Test get_template_context without ti_context_from_server.""" diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index e67d82a718cd6..85b6d11ee3b6f 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -455,6 +455,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m session.refresh(ti) assert ti.last_heartbeat_at == time_now.add(minutes=10) + def 
test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance): + from math import ceil + + ti = create_task_instance( + task_id="test_ti_update_state_to_failed_table_check", + state=State.RUNNING, + ) + ti.start_date = DEFAULT_START_DATE + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json={ + "state": State.FAILED, + "end_date": DEFAULT_END_DATE.isoformat(), + }, + ) + + assert response.status_code == 204 + assert response.text == "" + + session.expire_all() + + ti = session.get(TaskInstance, ti.id) + assert ti.state == State.FAILED + assert ti.next_method is None + assert ti.next_kwargs is None + # TODO: remove/amend this once https://github.com/apache/airflow/pull/45002 is merged + assert ceil(ti.duration) == 3600.00 + class TestTIPutRTIF: def setup_method(self):