Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling failed TI state for AirflowFailException & AirflowSensorTimeout #44954

Merged
merged 10 commits into from
Dec 18, 2024
6 changes: 6 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def ti_update_state(

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(
state=ti_patch_payload.state,
)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
if ti_patch_payload.state == State.FAILED:
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
kaxil marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down
11 changes: 10 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
...
except (AirflowFailException, AirflowSensorTimeout):
# If AirflowFailException is raised, task should not retry.
...
# If a sensor in reschedule mode reaches timeout, task should not retry.

# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)

# TODO: Run task failure callbacks here
except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
...
except SystemExit:
Expand Down
8 changes: 8 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,14 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
pytest.param(
TaskState(state=TerminalTIState.FAILED, end_date=timezone.parse("2024-10-31T12:00:00Z")),
b"",
"",
(),
"",
id="patch_task_instance_to_failed",
),
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_handle_requests(
Expand Down
48 changes: 47 additions & 1 deletion task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pytest
from uuid6 import uuid7

from airflow.exceptions import AirflowSkipException
from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException
from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState
Expand Down Expand Up @@ -333,6 +333,52 @@ def __init__(self, *args, **kwargs):
)


@pytest.mark.parametrize(
    ("dag_id", "task_id", "fail_with_exception"),
    [
        pytest.param(
            "basic_failed",
            "fail-exception",
            AirflowFailException("Oops. Failing by AirflowFailException!"),
        ),
        pytest.param(
            "basic_failed2",
            "sensor-timeout-exception",
            AirflowSensorTimeout("Oops. Failing by AirflowSensorTimeout!"),
        ),
    ],
)
def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
    """Check that a task raising a fail-type exception is reported as FAILED.

    Each parametrized case builds a ``PythonOperator`` whose callable raises
    ``fail_with_exception``, runs it through ``run`` with the supervisor comms
    patched out, and asserts that exactly one ``TaskState`` message with the
    ``FAILED`` terminal state and the pinned ``end_date`` was sent.
    """
    from airflow.providers.standard.operators.python import PythonOperator

    def _raise_failure():
        # The callable's only job is to fail with the parametrized exception.
        raise fail_with_exception

    task = PythonOperator(task_id=task_id, python_callable=_raise_failure)

    startup_details = StartupDetails(
        ti=TaskInstance(id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1),
        file="",
        requests_fd=0,
        ti_context=make_ti_context(),
    )
    ti = mocked_parse(startup_details, dag_id, task)

    # Pin the clock (tick=False) so the end_date in the outgoing message can be
    # compared against a known value.
    frozen_now = timezone.datetime(2024, 12, 3, 10, 0)
    time_machine.move_to(frozen_now, tick=False)

    expected_msg = TaskState(state=TerminalTIState.FAILED, end_date=frozen_now)

    with mock.patch(
        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
    ) as supervisor_comms:
        run(ti, log=mock.MagicMock())

        supervisor_comms.send_request.assert_called_once_with(msg=expected_msg, log=mock.ANY)


class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
"""Test get_template_context without ti_context_from_server."""
Expand Down
30 changes: 30 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,36 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
session.refresh(ti)
assert ti.last_heartbeat_at == time_now.add(minutes=10)

def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
    """Verify the DB row after PATCHing a running task instance to FAILED.

    The task instance is created in the RUNNING state with ``next_method`` and
    ``next_kwargs`` populated, then moved to FAILED through the
    ``/execution/task-instances/{id}/state`` endpoint. The test asserts that
    the endpoint answers 204 with an empty body and that the stored row has
    the FAILED state, cleared ``next_method`` / ``next_kwargs`` and a duration
    that rounds up to 3600.
    """
    from math import ceil

    ti = create_task_instance(
        task_id="test_ti_update_state_to_failed_table_check",
        start_date=DEFAULT_START_DATE,
        state=State.RUNNING,
    )
    ti.start_date = DEFAULT_START_DATE
    # Seed the resume fields so the "is None" assertions below prove that the
    # endpoint cleared them, instead of passing on values that were never set.
    ti.next_method = "execute_complete"
    ti.next_kwargs = {"reason": "seeded-by-test"}
    session.commit()

    response = client.patch(
        f"/execution/task-instances/{ti.id}/state",
        json={
            "state": State.FAILED,
            "end_date": DEFAULT_END_DATE.isoformat(),
        },
    )

    assert response.status_code == 204
    assert response.text == ""

    # Drop cached attribute values so the following read reflects what the
    # endpoint wrote to the database.
    session.expire_all()

    ti = session.get(TaskInstance, ti.id)
    assert ti.state == State.FAILED
    assert ti.next_method is None
    assert ti.next_kwargs is None
    assert ceil(ti.duration) == 3600.00
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved


class TestTIPutRTIF:
def setup_method(self):
Expand Down