From 2510a63d5ca6993848c010ef66159e3aad62b778 Mon Sep 17 00:00:00 2001
From: jason810496
Date: Sat, 21 Dec 2024 14:00:25 +0800
Subject: [PATCH] Refactor _interleave_logs and _read_from_local

---
 airflow/utils/log/file_task_handler.py | 219 +++++++++++++++++++------
 1 file changed, 171 insertions(+), 48 deletions(-)

diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 09866de7214ed..70fc9ef9c8dfd 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -19,14 +19,16 @@
 from __future__ import annotations
 
+import heapq
 import logging
 import os
-from collections.abc import Iterable
+from collections.abc import Generator, Iterable
 from contextlib import suppress
 from enum import Enum
-from functools import cached_property
+from functools import cached_property, partial
+from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Optional
 from urllib.parse import urljoin
 
 import pendulum
 
@@ -49,6 +51,18 @@
 
 logger = logging.getLogger(__name__)
 
+CHUNK_SIZE = 1024 * 1024 * 5  # 5MB
+DEFAULT_SORT_DATETIME = pendulum.datetime(2000, 1, 1)
+HEAP_DUMP_SIZE = 500000
+HALF_HEAP_DUMP_SIZE = HEAP_DUMP_SIZE // 2
+
+_ParsedLogRecordType = tuple[Optional[pendulum.DateTime], int, str]
+"""Tuple of timestamp, line number, and line."""
+_ParsedLogStreamType = Generator[_ParsedLogRecordType, None, None]
+"""A parsed log stream: a generator yielding tuples of timestamp, line number, and line."""
+_LogSourceType = tuple[list[str], list[_ParsedLogStreamType], int]
+"""Tuple of messages, parsed log streams, total size of logs."""
+
 
 class LogType(str, Enum):
     """
@@ -110,30 +124,113 @@ def _parse_timestamp(line: str):
         return pendulum.parse(timestamp_str.strip("[]"))
 
 
-def _parse_timestamps_in_log_file(lines: Iterable[str]):
-    timestamp = None
-    next_timestamp = None
-    for idx, line in enumerate(lines):
-        if line:
-            with suppress(Exception):
-                # next_timestamp unchanged if line can't be parsed
-                next_timestamp = _parse_timestamp(line)
-            if next_timestamp:
-                timestamp = next_timestamp
-            yield timestamp, idx, line
+def _get_parsed_log_stream(file_path: Path) -> _ParsedLogStreamType:
+    """Parse a log file chunk by chunk into a stream of (timestamp, line number, line) tuples."""
+    with open(file_path) as f:
+        # carry the timestamp and line number across chunks so ordering survives chunk boundaries
+        line_num = 0
+        timestamp = None
+        next_timestamp = None
+        for file_chunk in iter(partial(f.read, CHUNK_SIZE), ""):
+            # parse the log lines in the current chunk
+            lines = file_chunk.splitlines()
+            for line in lines:
+                if line:
+                    with suppress(Exception):
+                        # next_timestamp unchanged if line can't be parsed
+                        next_timestamp = _parse_timestamp(line)
+                    if next_timestamp:
+                        timestamp = next_timestamp
+                    yield timestamp, line_num, line
+                line_num += 1
+
+
+def _sort_key(timestamp: pendulum.DateTime | None, line_num: int) -> int:
+    """
+    Generate a sort key for a log record, to be used in the K-way merge.
+
+    :param timestamp: timestamp of the log line
+    :param line_num: line number of the log line
+    :return: an integer sort key, which is much cheaper to keep in the heap than the full record
+    """
+    return (timestamp or DEFAULT_SORT_DATETIME).int_timestamp * 10000000 + line_num
+
+
+def _add_log_from_parsed_log_streams_to_heap(
+    heap: list[tuple[int, str]],
+    parsed_log_streams: list[_ParsedLogStreamType],
+) -> None:
+    """
+    Add one log record from each parsed log stream to the heap.
+
+    Remove any empty log stream from the list while iterating.
+
+    :param heap: heap to store log records
+    :param parsed_log_streams: list of parsed log streams
+    """
+    # iterate over a copy so that removing an exhausted stream does not skip the next one
+    for log_stream in list(parsed_log_streams):
+        record: _ParsedLogRecordType | None = next(log_stream, None)
+        if record is None:
+            # this stream is exhausted; drop it from the source list
+            parsed_log_streams.remove(log_stream)
+            continue
+        timestamp, line_num, line = record
+        # push only the int sort key and the line to keep the memory overhead of the heap low
+        heapq.heappush(heap, (_sort_key(timestamp, line_num), line))
+
+
+def _interleave_logs(*parsed_log_streams: _ParsedLogStreamType) -> Generator[str, None, None]:
+    """
+    Merge parsed log streams using a K-way merge.
+
+    Since the records come from multiple streams, they are not guaranteed to arrive in global order.
+    Buffering them in a heap and yielding only HALF_HEAP_DUMP_SIZE records once the heap size exceeds
+    HEAP_DUMP_SIZE reduces the chance of emitting records out of global order.
+
+    e.g.
+
+    log_stream1: ----------
+    log_stream2: ----
+    log_stream3: --------
+
+    The first record of log_stream3 may be later than the fourth record of log_stream1!
+
+    :param parsed_log_streams: parsed log streams
+    :return: interleaved log stream
+    """
+    # push only (sort_key, line) tuples instead of whole records to keep the heap small
+    heap: list[tuple[int, str]] = []
+    # materialize as a list so that exhausted streams can be removed while merging
+    log_streams: list[_ParsedLogStreamType] = list(parsed_log_streams)
+
+    # add the first record from each log stream to the heap
+    _add_log_from_parsed_log_streams_to_heap(heap, log_streams)
 
-def _interleave_logs(*logs):
-    records = []
-    for log in logs:
-        records.extend(_parse_timestamps_in_log_file(log.splitlines()))
+    # keep pulling records until every stream is exhausted
     last = None
-    for timestamp, _, line in sorted(
-        records, key=lambda x: (x[0], x[1]) if x[0] else (pendulum.datetime(2000, 1, 1), x[1])
-    ):
-        if line != last or not timestamp:  # dedupe
+    while heap:
+        if not log_streams:
+            break
+
+        _add_log_from_parsed_log_streams_to_heap(heap, log_streams)
+
+        # yield HALF_HEAP_DUMP_SIZE records when the heap size exceeds HEAP_DUMP_SIZE
+        if len(heap) >= HEAP_DUMP_SIZE:
+            for _ in range(HALF_HEAP_DUMP_SIZE):
+                _, line = heapq.heappop(heap)
+                if line != last:  # dedupe
+                    yield line
+                    last = line
+
+    # yield the remaining records in sorted order
+    while heap:
+        _, line = heapq.heappop(heap)
+        if line != last:  # dedupe
             yield line
             last = line
+    # free memory
+    del heap
+    del log_streams
 
 
 def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
@@ -349,11 +446,15 @@ def _read(
         # is needed to get correct log path.
         worker_log_rel_path = self._render_filename(ti, try_number)
         messages_list: list[str] = []
-        remote_logs: list[str] = []
-        local_logs: list[str] = []
+        remote_parsed_logs: list[_ParsedLogStreamType] = []
+        remote_logs_size = 0
+        local_parsed_logs: list[_ParsedLogStreamType] = []
+        local_logs_size = 0
         executor_messages: list[str] = []
-        executor_logs: list[str] = []
-        served_logs: list[str] = []
+        executor_parsed_logs: list[_ParsedLogStreamType] = []
+        executor_logs_size = 0
+        served_parsed_logs: list[_ParsedLogStreamType] = []
+        served_logs_size = 0
         with suppress(NotImplementedError):
             remote_messages, remote_logs = self._read_remote_logs(ti, try_number, metadata)
             messages_list.extend(remote_messages)
@@ -368,27 +469,18 @@ def _read(
         if not (remote_logs and ti.state not in State.unfinished):
             # when finished, if we have remote logs, no need to check local
             worker_log_full_path = Path(self.local_base, worker_log_rel_path)
-            local_messages, local_logs = self._read_from_local(worker_log_full_path)
+            local_messages, local_parsed_logs, local_logs_size = self._read_from_local(worker_log_full_path)
             messages_list.extend(local_messages)
         if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_k8s_exec_pod:
             served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path)
             messages_list.extend(served_messages)
-        elif ti.state not in State.unfinished and not (local_logs or remote_logs):
+        elif ti.state not in State.unfinished and not (local_parsed_logs or remote_logs):
             # ordinarily we don't check served logs, with the assumption that users set up
             # remote logging or shared drive for logs for persistence, but that's not always true
             # so even if task is done, if no local logs or remote logs are found, we'll check the worker
             served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path)
             messages_list.extend(served_messages)
 
-        logs = "\n".join(
-            _interleave_logs(
-                *local_logs,
-                *remote_logs,
-                *(executor_logs or []),
-                *served_logs,
-            )
-        )
-        log_pos = len(logs)
         # Log message source details are grouped: they are not relevant for most users and can
         # distract them from finding the root cause of their errors
         messages = " INFO - ::group::Log message source details\n"
@@ -398,11 +490,29 @@ def _read(
             TaskInstanceState.RUNNING,
             TaskInstanceState.DEFERRED,
         )
+
+        current_total_logs_size = local_logs_size + remote_logs_size + executor_logs_size + served_logs_size
+        interleave_log_stream = _interleave_logs(
+            *local_parsed_logs,
+            *remote_parsed_logs,
+            *(executor_parsed_logs or []),
+            *served_parsed_logs,
+        )
+
+        # skip the interleaved stream up to the position returned in the previous call
         if metadata and "log_pos" in metadata:
-            previous_chars = metadata["log_pos"]
-            logs = logs[previous_chars:]  # Cut off previously passed log test as new tail
-        out_message = logs if "log_pos" in (metadata or {}) else messages + logs
-        return out_message, {"end_of_log": end_of_log, "log_pos": log_pos}
+            offset = metadata["log_pos"]
+            for _ in range(offset):
+                next(interleave_log_stream, None)
+
+        out_stream: Iterable[str]
+        if "log_pos" in (metadata or {}):
+            # no need to repeat the messages; we are resuming from the middle of the log
+            out_stream = interleave_log_stream
+        else:
+            # first read: emit the source-detail message lines before the interleaved log stream
+            out_stream = chain(messages.splitlines(), interleave_log_stream)
+        return out_stream, {"end_of_log": end_of_log, "log_pos": current_total_logs_size}
 
     @staticmethod
     def _get_pod_namespace(ti: TaskInstance):
@@ -537,14 +647,27 @@ def _init_file(self, ti, *, identifier: str | None = None):
         return full_path
 
     @staticmethod
-    def _read_from_local(worker_log_path: Path) -> tuple[list[str], list[str]]:
-        messages = []
+    def _read_from_local(worker_log_path: Path) -> _LogSourceType:
+        """
+        Read logs from local files.
+
+        :param worker_log_path: Path to the worker log file
+        :return: Tuple of messages, log streams, total size of logs
+        """
+        total_log_size: int = 0
+        messages: list[str] = []
+        parsed_log_streams: list[_ParsedLogStreamType] = []
         paths = sorted(worker_log_path.parent.glob(worker_log_path.name + "*"))
-        if paths:
-            messages.append("Found local files:")
-            messages.extend(f"  * {x}" for x in paths)
-        logs = [file.read_text() for file in paths]
-        return messages, logs
+        if not paths:
+            return messages, parsed_log_streams, total_log_size
+
+        messages.append("Found local files:")
+        for path in paths:
+            total_log_size += path.stat().st_size
+            messages.append(f"  * {path}")
+            parsed_log_streams.append(_get_parsed_log_stream(path))
+
+        return messages, parsed_log_streams, total_log_size
 
     def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], list[str]]:
         messages = []
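
A standalone sketch (not part of the patch) of the K-way merge idea behind the new _interleave_logs, assuming each parsed stream yields (timestamp, line number, line) tuples that are already ordered within their own stream. It deliberately omits the HEAP_DUMP_SIZE batching and the dedupe step, and the helper names (make_stream, sort_key, interleave) are illustrative only, not part of the Airflow code.

    import heapq
    from collections.abc import Generator

    import pendulum

    def make_stream(lines: list[str]) -> Generator[tuple[pendulum.DateTime, int, str], None, None]:
        # toy "parsed log stream": every line starts with an ISO timestamp
        for idx, line in enumerate(lines):
            yield pendulum.parse(line.split(" ", 1)[0]), idx, line

    def sort_key(timestamp: pendulum.DateTime, line_num: int) -> int:
        # same encoding idea as the patch: seconds * 10**7 + line number, kept as a single int
        return timestamp.int_timestamp * 10_000_000 + line_num

    def interleave(*streams: Generator) -> Generator[str, None, None]:
        heap: list[tuple[int, str]] = []
        live = list(streams)
        while live or heap:
            # pull at most one record from every still-open stream, then emit the smallest key
            for stream in list(live):
                record = next(stream, None)
                if record is None:
                    live.remove(stream)
                    continue
                timestamp, line_num, line = record
                heapq.heappush(heap, (sort_key(timestamp, line_num), line))
            if heap:
                yield heapq.heappop(heap)[1]

    scheduler = make_stream([
        "2024-12-21T10:00:00 scheduler: task queued",
        "2024-12-21T10:00:02 scheduler: task running",
    ])
    worker = make_stream(["2024-12-21T10:00:01 worker: task picked up"])
    print("\n".join(interleave(scheduler, worker)))

Running the sketch prints the scheduler and worker lines interleaved by timestamp, which is the ordering the integer sort-key encoding in the patch is meant to preserve while keeping per-record heap entries small.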