Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling failed TI state for AirflowFailException & AirflowSensorTimeout #44954

Merged
merged 10 commits into from
Dec 18, 2024
4 changes: 4 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def ti_update_state(

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
if ti_patch_payload.state == State.FAILED:
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
kaxil marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down
11 changes: 10 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
...
except (AirflowFailException, AirflowSensorTimeout):
# If AirflowFailException is raised, task should not retry.
...
# If a sensor in reschedule mode reaches timeout, task should not retry.

# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)

# TODO: Run task failure callbacks here
except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
...
except SystemExit:
Expand Down
2 changes: 2 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,8 @@ def watched_subprocess(self, mocker):
{"ok": True},
id="set_xcom_with_map_index",
),
# Not every TerminalTIState value is parametrized here: the scope of this test is only to check
# that a TaskState message can be handled.
pytest.param(
TaskState(state=TerminalTIState.SKIPPED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
b"",
Expand Down
51 changes: 50 additions & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pytest
from uuid6 import uuid7

from airflow.exceptions import AirflowSkipException
from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
Expand Down Expand Up @@ -333,6 +333,55 @@ def __init__(self, *args, **kwargs):
)


@pytest.mark.parametrize(
    ("dag_id", "task_id", "fail_with_exception"),
    [
        pytest.param(
            "basic_failed",
            "fail-exception",
            AirflowFailException("Oops. Failing by AirflowFailException!"),
        ),
        pytest.param(
            "basic_failed2",
            "sensor-timeout-exception",
            AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
        ),
    ],
)
def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
    """Verify that ``run`` reports FAILED when the task raises the given exception.

    Each parametrized case builds an operator whose ``execute`` raises
    ``fail_with_exception``, runs it with the clock frozen, and checks that
    exactly one ``TaskState`` message with state ``FAILED`` and an ``end_date``
    equal to the frozen instant is sent through the patched ``SUPERVISOR_COMMS``.
    """

    # Minimal operator that raises whatever exception it was constructed with.
    class CustomOperator(BaseOperator):
        def __init__(self, e, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.e = e

        def execute(self, context):
            print(f"raising exception {self.e}")
            raise self.e

    failing_task = CustomOperator(task_id=task_id, e=fail_with_exception)

    task_instance = TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1)
    startup_details = StartupDetails(
        ti=task_instance,
        file="",
        requests_fd=0,
        ti_context=make_ti_context(),
    )
    runtime_ti = mocked_parse(startup_details, dag_id, failing_task)

    # Freeze time (tick=False) so the end_date in the sent message can be
    # compared for exact equality.
    frozen_now = timezone.datetime(2024, 12, 3, 10, 0)
    time_machine.move_to(frozen_now, tick=False)

    expected_msg = TaskState(state=TerminalTIState.FAILED, end_date=frozen_now)

    comms_target = "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS"
    with mock.patch(comms_target, create=True) as supervisor_comms:
        run(runtime_ti, log=mock.MagicMock())

        supervisor_comms.send_request.assert_called_once_with(msg=expected_msg, log=mock.ANY)


class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
"""Test get_template_context without ti_context_from_server."""
Expand Down
30 changes: 30 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)

def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
    """Check the stored task-instance row after PATCHing its state to FAILED.

    Creates a RUNNING task instance, PATCHes its state to FAILED through the
    execution API, and then asserts on the reloaded row: the state is FAILED,
    ``next_method`` and ``next_kwargs`` are ``None``, and the duration rounds
    up to 3600.
    """
    from math import ceil

    ti = create_task_instance(
        task_id="test_ti_update_state_to_failed_table_check",
        state=State.RUNNING,
    )
    ti.start_date = DEFAULT_START_DATE
    session.commit()

    state_url = f"/execution/task-instances/{ti.id}/state"
    patch_body = {
        "state": State.FAILED,
        "end_date": DEFAULT_END_DATE.isoformat(),
    }
    response = client.patch(state_url, json=patch_body)

    assert response.status_code == 204
    assert response.text == ""

    # Drop cached ORM state so the assertions below read what is in the table.
    session.expire_all()

    stored_ti = session.get(TaskInstance, ti.id)
    assert stored_ti.state == State.FAILED
    assert stored_ti.next_method is None
    assert stored_ti.next_kwargs is None
    # TODO: remove/amend this once https://github.com/apache/airflow/pull/45002 is merged
    assert ceil(stored_ti.duration) == 3600.00
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved


class TestTIPutRTIF:
def setup_method(self):
Expand Down