Skip to content

Commit

Permalink
Refactor FileTaskHandler and TaskLogReader public method interface
Browse files Browse the repository at this point in the history
- refactor _read and read methods in FileTaskHandler
- refactor read_log_chunks method in TaskLogReader
  • Loading branch information
jason810496 authored and potiuk committed Dec 21, 2024
1 parent 2510a63 commit 1d0e6ed
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
28 changes: 16 additions & 12 deletions airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def _read(
ti: TaskInstance,
try_number: int,
metadata: dict[str, Any] | None = None,
):
) -> tuple[Iterable[str], dict[str, Any]]:
"""
Template method that contains custom logic of reading logs given the try_number.
Expand Down Expand Up @@ -545,15 +545,17 @@ def _get_log_retrieval_url(
log_relative_path,
)

def read(self, task_instance, try_number=None, metadata=None):
def read(
self, task_instance, try_number=None, metadata=None
) -> tuple[list[str], list[Generator[str, None, None]], list[dict[str, Any]]]:
"""
Read logs of given task instance from local machine.
:param task_instance: task instance object
:param try_number: task instance try_number to read logs from. If None
it returns all logs separated by try_number
:param metadata: log metadata, can be used for streaming log reading and auto-tailing.
:return: a list of listed tuples which order log string by host
:return: tuple of hosts, log streams, and metadata_array
"""
# Task instance increments its try number when it starts to run.
# So the log for a particular task try will only show up when
Expand All @@ -563,25 +565,27 @@ def read(self, task_instance, try_number=None, metadata=None):
next_try = task_instance.try_number + 1
try_numbers = list(range(1, next_try))
elif try_number < 1:
logs = [
[("default_host", f"Error fetching the logs. Try number {try_number} is invalid.")],
]
return logs, [{"end_of_log": True}]
error_logs = [(log for log in [f"Error fetching the logs. Try number {try_number} is invalid."])]
return ["default_host"], error_logs, [{"end_of_log": True}]
else:
try_numbers = [try_number]

logs = [""] * len(try_numbers)
metadata_array = [{}] * len(try_numbers)
hosts = [""] * len(try_numbers)
logs: list = [] * len(try_numbers)
metadata_array: list[dict] = [{}] * len(try_numbers)

# subclasses implement _read and may not have log_type, which was added recently
for i, try_number_element in enumerate(try_numbers):
log, out_metadata = self._read(task_instance, try_number_element, metadata)
log_stream, out_metadata = self._read(task_instance, try_number_element, metadata)
# es_task_handler returns logs grouped by host. Wrap other handlers, which return a plain
# log string, with a default/empty host so that the UI can render the response in the same way
logs[i] = log if self._read_grouped_logs() else [(task_instance.hostname, log)]
if not self._read_grouped_logs():
hosts[i] = task_instance.hostname

logs[i] = log_stream
metadata_array[i] = out_metadata

return logs, metadata_array
return hosts, logs, metadata_array

@staticmethod
def _prepare_log_folder(directory: Path, new_folder_permissions: int):
Expand Down
27 changes: 17 additions & 10 deletions airflow/utils/log/log_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import logging
import time
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from airflow.configuration import conf
from airflow.utils.helpers import render_log_filename
Expand All @@ -42,7 +42,7 @@ class TaskLogReader:

def read_log_chunks(
self, ti: TaskInstance, try_number: int | None, metadata
) -> tuple[list[tuple[tuple[str, str]]], dict[str, str]]:
) -> tuple[str, Iterable[str], dict[str, Any]]:
"""
Read chunks of Task Instance logs.
Expand All @@ -62,9 +62,14 @@ def read_log_chunks(
contain information about the task log which can enable you read logs to the
end.
"""
logs, metadatas = self.log_handler.read(ti, try_number, metadata=metadata)
metadata = metadatas[0]
return logs, metadata
hosts: list[str]
log_streams: list[Iterable[str]]
metadata_array: list[dict[str, Any]]
hosts, log_streams, metadata_array = self.log_handler.read(ti, try_number, metadata=metadata)
host = hosts[0]
log_stream = log_streams[0]
metadata = metadata_array[0]
return host, log_stream, metadata

def read_log_stream(self, ti: TaskInstance, try_number: int | None, metadata: dict) -> Iterator[str]:
"""
Expand All @@ -85,14 +90,16 @@ def read_log_stream(self, ti: TaskInstance, try_number: int | None, metadata: di
metadata.pop("offset", None)
metadata.pop("log_pos", None)
while True:
logs, metadata = self.read_log_chunks(ti, current_try_number, metadata)
for host, log in logs[0]:
yield "\n".join([host or "", log]) + "\n"
host: str
log_stream: Iterable[str]
host, log_stream, metadata = self.read_log_chunks(ti, current_try_number, metadata)
for log in log_stream:
yield f"\n{host or ''}\n{log}\n"
if "end_of_log" not in metadata or (
not metadata["end_of_log"]
and ti.state not in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED)
):
if not logs[0]:
if not log_stream:
# we did not receive any logs in this loop
# sleeping to conserve resources / limit requests on external services
time.sleep(self.STREAM_LOOP_SLEEP_SECONDS)
Expand Down

0 comments on commit 1d0e6ed

Please sign in to comment.