Skip to content

Commit

Permalink
AIP-72: Handling failed TI state for AirflowFailException & `Airf…
Browse files Browse the repository at this point in the history
…lowSensorTimeout` (#44954)

related: #44414

We already have support for handling terminal states from the task execution side as well as the task SDK client side. (almost) and failed state is part of the terminal state.

This PR extends the task runner's run function to handle cases when we have to fail a task: `AirflowFailException, AirflowSensorTimeout`. It is functionally very similar to #44786

As part of failing a task, multiple other things also needs to be done like:
- Callbacks: which will eventually be converted to teardown tasks
- Retries: Handled in #44351
- unmapping TIs: #44351
- Handling task history: will be handled by #44952
- Handling downstream tasks and non teardown tasks: will be handled by #44951

### Testing performed
#### End to End with Postman

1. Run airflow with breeze and run any DAG
![image](https://github.com/user-attachments/assets/fafc89ea-4e28-4802-912b-d72bf401d94b)

2. Login to metadata DB and get the "id" for your task instance from TI table
![image](https://github.com/user-attachments/assets/75440f0f-f62a-4277-a2e6-cb78bd666dd4)

3. Send a request to `fail` your task
![image](https://github.com/user-attachments/assets/5991e944-f416-4b79-9954-15f1a6ebdd79)

Or using curl:
```
curl --location --request PATCH 'http://localhost:29091/execution/task-instances/0193cec2-f46b-7348-9c27-9869d835dc7b/state' \
--header 'Content-Type: application/json' \
--data '{
    "state": "failed",
    "end_date": "2024-10-31T12:00:00Z"
}'
```

4. Refresh back the Airflow UI to see that the task is in failed state.
![image](https://github.com/user-attachments/assets/bb866dc6-e1d6-435e-abe4-2d04c97280ad)
  • Loading branch information
amoghrajesh authored Dec 18, 2024
1 parent f631bef commit cfc2157
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 2 deletions.
4 changes: 4 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def ti_update_state(

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
if ti_patch_payload.state == State.FAILED:
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down
11 changes: 10 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
...
except (AirflowFailException, AirflowSensorTimeout):
# If AirflowFailException is raised, task should not retry.
...
# If a sensor in reschedule mode reaches timeout, task should not retry.

# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)

# TODO: Run task failure callbacks here
except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
...
except SystemExit:
Expand Down
2 changes: 2 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,8 @@ def watched_subprocess(self, mocker):
{"ok": True},
id="set_xcom_with_map_index",
),
# we aren't adding all states under TerminalTIState here, because this test's scope is only to check
# if it can handle TaskState message
pytest.param(
TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
b"",
Expand Down
51 changes: 50 additions & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pytest
from uuid6 import uuid7

from airflow.exceptions import AirflowSkipException
from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
Expand Down Expand Up @@ -333,6 +333,55 @@ def __init__(self, *args, **kwargs):
)


@pytest.mark.parametrize(
["dag_id", "task_id", "fail_with_exception"],
[
pytest.param(
"basic_failed", "fail-exception", AirflowFailException("Oops. Failing by AirflowFailException!")
),
pytest.param(
"basic_failed2",
"sensor-timeout-exception",
AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
),
],
)
def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
"""Test running a basic task that marks itself as failed by raising exception."""

class CustomOperator(BaseOperator):
def __init__(self, e, *args, **kwargs):
super().__init__(*args, **kwargs)
self.e = e

def execute(self, context):
print(f"raising exception {self.e}")
raise self.e

task = CustomOperator(task_id=task_id, e=fail_with_exception)

what = StartupDetails(
ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

ti = mocked_parse(what, dag_id, task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
)


class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
"""Test get_template_context without ti_context_from_server."""
Expand Down
30 changes: 30 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)

def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
from math import ceil

ti = create_task_instance(
task_id="test_ti_update_state_to_failed_table_check",
state=State.RUNNING,
)
ti.start_date = DEFAULT_START_DATE
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": State.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == State.FAILED
assert ti.next_method is None
assert ti.next_kwargs is None
# TODO: remove/amend this once https://github.com/apache/airflow/pull/45002 is merged
assert ceil(ti.duration) == 3600.00


class TestTIPutRTIF:
def setup_method(self):
Expand Down

0 comments on commit cfc2157

Please sign in to comment.