Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling up_for_retry task instance states #45070

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,24 @@ class TIRescheduleStatePayload(BaseModel):
end_date: UtcDateTime


class TIRetryStatePayload(BaseModel):
"""Schema for updating TaskInstance to a up_for_retry state."""

state: Annotated[
Literal[IntermediateTIState.UP_FOR_RETRY],
# Specify a default in the schema, but not in code, so Pydantic marks it as required.
WithJsonSchema(
{
"type": "string",
"enum": [IntermediateTIState.UP_FOR_RETRY],
"default": IntermediateTIState.UP_FOR_RETRY,
}
),
]
end_date: UtcDateTime
task_retries: int


def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
"""
Determine the discriminator key for TaskInstance state transitions.
Expand All @@ -129,6 +147,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
return "deferred"
elif state == TIState.UP_FOR_RESCHEDULE:
return "up_for_reschedule"
elif state == TIState.UP_FOR_RETRY:
return "up_for_retry"
return "_other_"


Expand All @@ -140,6 +160,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
Annotated[TITargetStatePayload, Tag("_other_")],
Annotated[TIDeferredStatePayload, Tag("deferred")],
Annotated[TIRescheduleStatePayload, Tag("up_for_reschedule")],
Annotated[TIRetryStatePayload, Tag("up_for_retry")],
],
Discriminator(ti_state_discriminator),
]
Expand Down
31 changes: 31 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRetryStatePayload,
TIRunContext,
TIStateUpdate,
TITerminalStatePayload,
Expand Down Expand Up @@ -167,6 +168,7 @@ def ti_run(
status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
status.HTTP_400_BAD_REQUEST: {"description": "Not a valid state transition"},
},
)
def ti_update_state(
Expand Down Expand Up @@ -252,6 +254,20 @@ def ti_update_state(
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
# clear the next_method and next_kwargs so that none of the retries pick them up
query = query.values(state=State.UP_FOR_RESCHEDULE, next_method=None, next_kwargs=None)
elif isinstance(ti_patch_payload, TIRetryStatePayload):
task_instance = session.get(TI, ti_id_str)
if not _is_eligible_to_retry(task_instance, ti_patch_payload.task_retries):
log.error("Task Instance %s cannot be retried", ti_id_str)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "bad_request",
"message": "Task Instance is not eligible to retry",
},
)
query = update(TI).where(TI.id == ti_id_str)
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=State.UP_FOR_RETRY, next_method=None, next_kwargs=None)
# TODO: Replace this with FastAPI's Custom Exception handling:
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
Expand Down Expand Up @@ -354,3 +370,18 @@ def ti_put_rtif(
_update_rtif(task_instance, put_rtif_payload, session)

return {"message": "Rendered task instance fields successfully set"}


def _is_eligible_to_retry(task_instance, task_retries: int):
"""
Is task instance is eligible for retry.

:param task_instance: the task instance

:meta private:
"""
if task_instance.state == State.RESTARTING:
# If a task is RESTARTING state it is always eligible for retry
return True

return task_retries and task_instance.try_number <= task_instance.max_tries
Comment on lines +375 to +387
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly a port over of

def _is_eligible_to_retry(*, task_instance: TaskInstance):
"""
Is task instance is eligible for retry.
:param task_instance: the task instance
:meta private:
"""
if task_instance.state == TaskInstanceState.RESTARTING:
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True
if not getattr(task_instance, "task", None):
# Couldn't load the task, don't know number of retries, guess:
return task_instance.try_number <= task_instance.max_tries
if TYPE_CHECKING:
assert task_instance.task
return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries
because we do not have "task_instance" table entries in SDK anymore.

Tried splitting it too because we do not have "task_instance.task" here

10 changes: 9 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TIEnterRunningPayload,
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRetryStatePayload,
TIRunContext,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
Expand All @@ -49,7 +50,7 @@
if TYPE_CHECKING:
from datetime import datetime

from airflow.sdk.execution_time.comms import RescheduleTask
from airflow.sdk.execution_time.comms import RescheduleTask, RetryTask
from airflow.typing_compat import ParamSpec

P = ParamSpec("P")
Expand Down Expand Up @@ -146,6 +147,13 @@ def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
# Create a reschedule state payload from msg
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def retry(self, id: uuid.UUID, msg: RetryTask):
"""Tell the API server that this TI wants to retry."""
body = TIRetryStatePayload(**msg.model_dump(exclude_unset=True))

# Create a retry state payload from msg
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> dict[str, bool]:
"""Set Rendered Task Instance Fields via the API server."""
self.client.put(f"task-instances/{id}/rtif", json=body)
Expand Down
10 changes: 10 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ class TIRescheduleStatePayload(BaseModel):
reschedule_date: Annotated[datetime, Field(title="Reschedule Date")]


class TIRetryStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a up_for_retry state.
"""

state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = "up_for_retry"
end_date: Annotated[datetime, Field(title="End Date")]
task_retries: Annotated[int, Field(title="Task Retries")]


class TITargetStatePayload(BaseModel):
"""
Schema for updating TaskInstance to a target state, excluding terminal and running states.
Expand Down
8 changes: 8 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
TerminalTIState,
TIDeferredStatePayload,
TIRescheduleStatePayload,
TIRetryStatePayload,
TIRunContext,
VariableResponse,
XComResponse,
Expand Down Expand Up @@ -122,6 +123,12 @@ class RescheduleTask(TIRescheduleStatePayload):
type: Literal["RescheduleTask"] = "RescheduleTask"


class RetryTask(TIRetryStatePayload):
"""Update a task instance state to up_for_retry."""

type: Literal["RetryTask"] = "RetryTask"


class GetXCom(BaseModel):
key: str
dag_id: str
Expand Down Expand Up @@ -200,6 +207,7 @@ class SetRenderedFields(BaseModel):
SetXCom,
SetRenderedFields,
RescheduleTask,
RetryTask,
],
Field(discriminator="type"),
]
4 changes: 4 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
GetXCom,
PutVariable,
RescheduleTask,
RetryTask,
SetXCom,
StartupDetails,
TaskState,
Expand Down Expand Up @@ -702,6 +703,9 @@ def _handle_request(self, msg, log):
elif isinstance(msg, RescheduleTask):
self._terminal_state = IntermediateTIState.UP_FOR_RESCHEDULE
self.client.task_instances.reschedule(self.id, msg)
elif isinstance(msg, RetryTask):
self._terminal_state = IntermediateTIState.UP_FOR_RETRY
self.client.task_instances.retry(self.id, msg)
elif isinstance(msg, SetXCom):
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
elif isinstance(msg, PutVariable):
Expand Down
17 changes: 16 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.sdk.execution_time.comms import (
DeferTask,
RescheduleTask,
RetryTask,
SetRenderedFields,
StartupDetails,
TaskState,
Expand Down Expand Up @@ -296,8 +297,22 @@ def run(ti: RuntimeTaskInstance, log: Logger):
)

# TODO: Run task failure callbacks here
except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
except AirflowTaskTerminated:
...
except (AirflowTaskTimeout, AirflowException):
# Couldn't load the task, don't know number of retries, guess
if not getattr(ti, "task", None):
# Let us set the task_retries to default = 0
msg = RetryTask(
end_date=datetime.now(tz=timezone.utc),
task_retries=0,
)
else:
msg = RetryTask(
end_date=datetime.now(tz=timezone.utc),
# is `or 0` needed?
task_retries=ti.task.retries or 0,
)
Comment on lines +302 to +315
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for the API and this PR, all we should care about is whether we should retry or not. The task ran, complained that it needs to retry, so we send a retry API call. The core logic of how retry works should be out of the scope of this PR.

except SystemExit:
...
except BaseException:
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
GetXCom,
PutVariable,
RescheduleTask,
RetryTask,
SetXCom,
TaskState,
VariableResult,
Expand Down Expand Up @@ -794,6 +795,14 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_deferred",
),
pytest.param(
RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z"), task_retries=1),
b"",
"task_instances.retry",
(TI_ID, RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z"), task_retries=1)),
"",
id="patch_task_instance_to_retry",
),
pytest.param(
RescheduleTask(
reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
Expand Down
37 changes: 37 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,43 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
assert trs[0].map_index == -1
assert trs[0].duration == 129600

def test_ti_update_state_to_retry(self, client, session, create_task_instance, time_machine):
"""
Test that tests if the transition to retry state is handled correctly.
"""

instant = timezone.datetime(2024, 10, 30)
time_machine.move_to(instant, tick=False)

ti = create_task_instance(
task_id="test_ti_update_state_to_retry",
state=State.RUNNING,
session=session,
)
ti.start_date = instant
session.commit()

payload = {
"state": "up_for_retry",
"end_date": DEFAULT_END_DATE.isoformat(),
# a running task moving to up_for_retry
"task_retries": 1,
}

response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

tis = session.query(TaskInstance).all()
assert len(tis) == 1
assert tis[0].state == TaskInstanceState.UP_FOR_RETRY
assert tis[0].next_method is None
assert tis[0].next_kwargs is None
assert tis[0].duration == 129600


class TestTIHealthEndpoint:
def setup_method(self):
Expand Down