From 2294b4167180a7b09996ca4816ae6be5e4a31a70 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 21 Dec 2024 14:00:25 +0800 Subject: [PATCH 1/9] Refactor _interleave_logs and _read_from_local --- airflow/utils/log/file_task_handler.py | 219 +++++++++++++++++++------ 1 file changed, 171 insertions(+), 48 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 09866de7214ed..70fc9ef9c8dfd 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -19,14 +19,16 @@ from __future__ import annotations +import heapq import logging import os -from collections.abc import Iterable +from collections.abc import Generator, Iterable from contextlib import suppress from enum import Enum -from functools import cached_property +from functools import cached_property, partial +from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional from urllib.parse import urljoin import pendulum @@ -49,6 +51,18 @@ logger = logging.getLogger(__name__) +CHUNK_SIZE = 1024 * 1024 * 5 # 5MB +DEFAULT_SORT_DATETIME = pendulum.datetime(2000, 1, 1) +HEAP_DUMP_SIZE = 500000 +HALF_HEAP_DUMP_SIZE = HEAP_DUMP_SIZE // 2 + +_ParsedLogRecordType = tuple[Optional[pendulum.DateTime], int, str] +"""Tuple of timestamp, line number, and line.""" +_ParsedLogStreamType = Generator[_ParsedLogRecordType, None, None] +"""Generator of parsed log streams, each yielding a tuple of timestamp, line number, and line.""" +_LogSourceType = tuple[list[str], list[_ParsedLogStreamType], int] +"""Tuple of messages, parsed log streams, total size of logs.""" + class LogType(str, Enum): """ @@ -110,30 +124,113 @@ def _parse_timestamp(line: str): return pendulum.parse(timestamp_str.strip("[]")) -def _parse_timestamps_in_log_file(lines: Iterable[str]): - timestamp = None - next_timestamp = None - for idx, line in enumerate(lines): - if line: - with suppress(Exception): - # next_timestamp unchanged if line can't be parsed - next_timestamp = _parse_timestamp(line) - if next_timestamp: - timestamp = next_timestamp - yield timestamp, idx, line +def _get_parsed_log_stream(file_path: Path) -> _ParsedLogStreamType: + with open(file_path) as f: + for file_chunk in iter(partial(f.read, CHUNK_SIZE), b""): + if not file_chunk: + break + # parse log lines + lines = file_chunk.splitlines() + timestamp = None + next_timestamp = None + for idx, line in enumerate(lines): + if line: + with suppress(Exception): + # next_timestamp unchanged if line can't be parsed + next_timestamp = _parse_timestamp(line) + if next_timestamp: + timestamp = next_timestamp + yield timestamp, idx, line + + +def _sort_key(timestamp: pendulum.DateTime | None, line_num: int) -> int: + """ + Generate a sort key for log record, to be used in K-way merge. + + :param timestamp: timestamp of the log line + :param line_num: line number of the log line + :return: a integer as sort key to avoid overhead of memory usage + """ + return (timestamp or DEFAULT_SORT_DATETIME).int_timestamp * 10000000 + line_num + + +def _add_log_from_parsed_log_streams_to_heap( + heap: list[tuple[int, str]], + parsed_log_streams: list[_ParsedLogStreamType], +) -> None: + """ + Add one log record from each parsed log stream to the heap. + + Remove any empty log stream from the list while iterating. + + :param heap: heap to store log records + :param parsed_log_streams: list of parsed log streams + """ + for log_stream in parsed_log_streams: + if log_stream is None: + parsed_log_streams.remove(log_stream) + continue + record: _ParsedLogRecordType | None = next(log_stream, None) + if record is None: + parsed_log_streams.remove(log_stream) + continue + timestamp, line_num, line = record + # take int as sort key to avoid overhead of memory usage + heapq.heappush(heap, (_sort_key(timestamp, line_num), line)) + + +def _interleave_logs(*parsed_log_streams: _ParsedLogStreamType) -> Generator[str, None, None]: + """ + Merge parsed log streams using K-way merge. + + By yielding HALF_CHUNK_SIZE records when heap size exceeds CHUNK_SIZE, we can reduce the chance of messing up the global order. + Since there are multiple log streams, we can't guarantee that the records are in global order. + + e.g. + + log_stream1: ---------- + log_stream2: ---- + log_stream3: -------- + + The first record of log_stream3 is later than the fourth record of log_stream1 ! + :param parsed_log_streams: parsed log streams + :return: interleaved log stream + """ + # don't need to push whole tuple into heap, which increases too much overhead + # push only sort_key and line into heap + heap: list[tuple[int, str]] = [] + # to allow removing empty streams while iterating + log_streams: list[_ParsedLogStreamType] = [log_stream for log_stream in parsed_log_streams] + + # add first record from each log stream to heap + _add_log_from_parsed_log_streams_to_heap(heap, log_streams) -def _interleave_logs(*logs): - records = [] - for log in logs: - records.extend(_parse_timestamps_in_log_file(log.splitlines())) + # keep adding records from logs until all logs are empty last = None - for timestamp, _, line in sorted( - records, key=lambda x: (x[0], x[1]) if x[0] else (pendulum.datetime(2000, 1, 1), x[1]) - ): - if line != last or not timestamp: # dedupe + while heap: + if not log_streams: + break + + _add_log_from_parsed_log_streams_to_heap(heap, log_streams) + + # yield HALF_HEAP_DUMP_SIZE records when heap size exceeds HEAP_DUMP_SIZE + if len(heap) >= HEAP_DUMP_SIZE: + for _ in range(HALF_HEAP_DUMP_SIZE): + _, line = heapq.heappop(heap) + if line != last: # dedupe + yield line + last = line + continue + + # yield remaining records + for _, line in heap: + if line != last: # dedupe yield line last = line + # free memory + del heap + del log_streams def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance: @@ -349,11 +446,15 @@ def _read( # is needed to get correct log path. worker_log_rel_path = self._render_filename(ti, try_number) messages_list: list[str] = [] - remote_logs: list[str] = [] - local_logs: list[str] = [] + remote_parsed_logs: list[_ParsedLogStreamType] = [] + remote_logs_size = 0 + local_parsed_logs: list[_ParsedLogStreamType] = [] + local_logs_size = 0 executor_messages: list[str] = [] - executor_logs: list[str] = [] - served_logs: list[str] = [] + executor_parsed_logs: list[_ParsedLogStreamType] = [] + executor_logs_size = 0 + served_parsed_logs: list[_ParsedLogStreamType] = [] + served_logs_size = 0 with suppress(NotImplementedError): remote_messages, remote_logs = self._read_remote_logs(ti, try_number, metadata) messages_list.extend(remote_messages) @@ -368,27 +469,18 @@ def _read( if not (remote_logs and ti.state not in State.unfinished): # when finished, if we have remote logs, no need to check local worker_log_full_path = Path(self.local_base, worker_log_rel_path) - local_messages, local_logs = self._read_from_local(worker_log_full_path) + local_messages, local_parsed_logs, local_logs_size = self._read_from_local(worker_log_full_path) messages_list.extend(local_messages) if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_k8s_exec_pod: served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path) messages_list.extend(served_messages) - elif ti.state not in State.unfinished and not (local_logs or remote_logs): + elif ti.state not in State.unfinished and not (local_parsed_logs or remote_logs): # ordinarily we don't check served logs, with the assumption that users set up # remote logging or shared drive for logs for persistence, but that's not always true # so even if task is done, if no local logs or remote logs are found, we'll check the worker served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path) messages_list.extend(served_messages) - logs = "\n".join( - _interleave_logs( - *local_logs, - *remote_logs, - *(executor_logs or []), - *served_logs, - ) - ) - log_pos = len(logs) # Log message source details are grouped: they are not relevant for most users and can # distract them from finding the root cause of their errors messages = " INFO - ::group::Log message source details\n" @@ -398,11 +490,29 @@ def _read( TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, ) + + current_total_logs_size = local_logs_size + remote_logs_size + executor_logs_size + served_logs_size + interleave_log_stream = _interleave_logs( + *local_parsed_logs, + *remote_parsed_logs, + *(executor_parsed_logs or []), + *served_parsed_logs, + ) + + # skip log stream until the last position if metadata and "log_pos" in metadata: - previous_chars = metadata["log_pos"] - logs = logs[previous_chars:] # Cut off previously passed log test as new tail - out_message = logs if "log_pos" in (metadata or {}) else messages + logs - return out_message, {"end_of_log": end_of_log, "log_pos": log_pos} + offset = metadata["log_pos"] + for _ in range(offset): + next(interleave_log_stream, None) + + out_stream: Iterable[str] + if "log_pos" in (metadata or {}): + # don't need to add messages, since we're in the middle of the log + out_stream = interleave_log_stream + else: + # first time reading log, add messages before interleaved log stream + out_stream = chain((msg for msg in messages), interleave_log_stream) + return out_stream, {"end_of_log": end_of_log, "log_pos": current_total_logs_size} @staticmethod def _get_pod_namespace(ti: TaskInstance): @@ -537,14 +647,27 @@ def _init_file(self, ti, *, identifier: str | None = None): return full_path @staticmethod - def _read_from_local(worker_log_path: Path) -> tuple[list[str], list[str]]: - messages = [] + def _read_from_local(worker_log_path: Path) -> _LogSourceType: + """ + Read logs from local file. + + :param worker_log_path: Path to the worker log file + :return: Tuple of messages, log streams, total size of logs + """ + total_log_size: int = 0 + messages: list[str] = [] + parsed_log_streams: list[_ParsedLogStreamType] = [] paths = sorted(worker_log_path.parent.glob(worker_log_path.name + "*")) - if paths: - messages.append("Found local files:") - messages.extend(f" * {x}" for x in paths) - logs = [file.read_text() for file in paths] - return messages, logs + if not paths: + return messages, parsed_log_streams, total_log_size + + messages.append("Found local files:") + for path in paths: + total_log_size += path.stat().st_size + messages.append(f" * {path}") + parsed_log_streams.append(_get_parsed_log_stream(path)) + + return messages, parsed_log_streams, total_log_size def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], list[str]]: messages = [] From c7fb4c13d24e8763872d36761bbe324c8a26a9aa Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 21 Dec 2024 14:45:31 +0800 Subject: [PATCH 2/9] Refactor FileTaskHandler and TaskLogReader public method interface - refactor _read and read methods in FileTaskHandler - refactor read_log_chunks method in TaskLogReader --- airflow/utils/log/file_task_handler.py | 28 +++++++++++++++----------- airflow/utils/log/log_reader.py | 27 ++++++++++++++++--------- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 70fc9ef9c8dfd..69d5499e50334 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -421,7 +421,7 @@ def _read( ti: TaskInstance, try_number: int, metadata: dict[str, Any] | None = None, - ): + ) -> tuple[Iterable[str], dict[str, Any]]: """ Template method that contains custom logic of reading logs given the try_number. @@ -545,7 +545,9 @@ def _get_log_retrieval_url( log_relative_path, ) - def read(self, task_instance, try_number=None, metadata=None): + def read( + self, task_instance, try_number=None, metadata=None + ) -> tuple[list[str], list[Generator[str, None, None]], list[dict[str, Any]]]: """ Read logs of given task instance from local machine. @@ -553,7 +555,7 @@ def read(self, task_instance, try_number=None, metadata=None): :param try_number: task instance try_number to read logs from. If None it returns all logs separated by try_number :param metadata: log metadata, can be used for steaming log reading and auto-tailing. - :return: a list of listed tuples which order log string by host + :return: tuple of hosts, log streams, and metadata_array """ # Task instance increments its try number when it starts to run. # So the log for a particular task try will only show up when @@ -563,25 +565,27 @@ def read(self, task_instance, try_number=None, metadata=None): next_try = task_instance.try_number + 1 try_numbers = list(range(1, next_try)) elif try_number < 1: - logs = [ - [("default_host", f"Error fetching the logs. Try number {try_number} is invalid.")], - ] - return logs, [{"end_of_log": True}] + error_logs = [(log for log in [f"Error fetching the logs. Try number {try_number} is invalid."])] + return ["default_host"], error_logs, [{"end_of_log": True}] else: try_numbers = [try_number] - logs = [""] * len(try_numbers) - metadata_array = [{}] * len(try_numbers) + hosts = [""] * len(try_numbers) + logs: list = [] * len(try_numbers) + metadata_array: list[dict] = [{}] * len(try_numbers) # subclasses implement _read and may not have log_type, which was added recently for i, try_number_element in enumerate(try_numbers): - log, out_metadata = self._read(task_instance, try_number_element, metadata) + log_stream, out_metadata = self._read(task_instance, try_number_element, metadata) # es_task_handler return logs grouped by host. wrap other handler returning log string # with default/ empty host so that UI can render the response in the same way - logs[i] = log if self._read_grouped_logs() else [(task_instance.hostname, log)] + if not self._read_grouped_logs(): + hosts[i] = task_instance.hostname + + logs[i] = log_stream metadata_array[i] = out_metadata - return logs, metadata_array + return hosts, logs, metadata_array @staticmethod def _prepare_log_folder(directory: Path, new_folder_permissions: int): diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py index cc60500532fb1..9acdccaa281a1 100644 --- a/airflow/utils/log/log_reader.py +++ b/airflow/utils/log/log_reader.py @@ -18,9 +18,9 @@ import logging import time -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from airflow.configuration import conf from airflow.utils.helpers import render_log_filename @@ -42,7 +42,7 @@ class TaskLogReader: def read_log_chunks( self, ti: TaskInstance, try_number: int | None, metadata - ) -> tuple[list[tuple[tuple[str, str]]], dict[str, str]]: + ) -> tuple[str, Iterable[str], dict[str, Any]]: """ Read chunks of Task Instance logs. @@ -62,9 +62,14 @@ def read_log_chunks( contain information about the task log which can enable you read logs to the end. """ - logs, metadatas = self.log_handler.read(ti, try_number, metadata=metadata) - metadata = metadatas[0] - return logs, metadata + hosts: list[str] + log_streams: list[Iterable[str]] + metadata_array: list[dict[str, Any]] + hosts, log_streams, metadata_array = self.log_handler.read(ti, try_number, metadata=metadata) + host = hosts[0] + log_stream = log_streams[0] + metadata = metadata_array[0] + return host, log_stream, metadata def read_log_stream(self, ti: TaskInstance, try_number: int | None, metadata: dict) -> Iterator[str]: """ @@ -85,14 +90,16 @@ def read_log_stream(self, ti: TaskInstance, try_number: int | None, metadata: di metadata.pop("offset", None) metadata.pop("log_pos", None) while True: - logs, metadata = self.read_log_chunks(ti, current_try_number, metadata) - for host, log in logs[0]: - yield "\n".join([host or "", log]) + "\n" + host: str + log_stream: Iterable[str] + host, log_stream, metadata = self.read_log_chunks(ti, current_try_number, metadata) + for log in log_stream: + yield f"\n{host or ''}\n{log}\n" if "end_of_log" not in metadata or ( not metadata["end_of_log"] and ti.state not in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) ): - if not logs[0]: + if not log_stream: # we did not receive any logs in this loop # sleeping to conserve resources / limit requests on external services time.sleep(self.STREAM_LOOP_SLEEP_SECONDS) From c752e90dbf48460e3ad9a1fe773c2ea9545cc7e0 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 21 Dec 2024 16:23:02 +0800 Subject: [PATCH 3/9] Refactor test for get from local, parse log timestamp --- tests/utils/test_log_handlers.py | 76 +++++++++++++++++++++++++------- 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 19a432bb73767..b0f2d6f262e55 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -45,8 +45,8 @@ FileTaskHandler, LogType, _fetch_logs_from_service, + _get_parsed_log_stream, _interleave_logs, - _parse_timestamps_in_log_file, ) from airflow.utils.log.logging_mixin import set_context from airflow.utils.net import get_hostname @@ -370,19 +370,55 @@ def test__read_when_local(self, mock_read_local, create_task_instance): def test__read_from_local(self, tmp_path): """Tests the behavior of method _read_from_local""" - path1 = tmp_path / "hello1.log" - path2 = tmp_path / "hello1.log.suffix.log" - path1.write_text("file1 content") - path2.write_text("file2 content") + path1: Path = tmp_path / "hello1.log" + path2: Path = tmp_path / "hello1.log.suffix.log" + path1.write_text( + """file1 content 1 +file1 content 2 +[2022-11-16T00:05:54.295-0800] file1 content 3""" + ) + path2.write_text( + """file2 content 1 +file2 content 2 +[2022-11-16T00:05:54.295-0800] file2 content 3""" + ) fth = FileTaskHandler("") - assert fth._read_from_local(path1) == ( + messages, parsed_log_streams, log_size = fth._read_from_local(path1) + assert messages == [ + "Found local files:", + f" * {path1}", + f" * {path2}", + ] + # Optional[datetime], int, str = record + assert [[record for record in parsed_log_stream] for parsed_log_stream in parsed_log_streams] == [ [ - "Found local files:", - f" * {path1}", - f" * {path2}", + (None, 0, "file1 content 1"), + ( + None, + 1, + "file1 content 2", + ), + ( + pendulum.parse("2022-11-16T00:05:54.295-0800"), + 2, + "[2022-11-16T00:05:54.295-0800] file1 content 3", + ), ], - ["file1 content", "file2 content"], - ) + [ + (None, 0, "file2 content 1"), + ( + None, + 1, + "file2 content 2", + ), + ( + pendulum.parse("2022-11-16T00:05:54.295-0800"), + 2, + "[2022-11-16T00:05:54.295-0800] file2 content 3", + ), + ], + ] + assert log_size == 156 @pytest.mark.parametrize( "remote_logs, local_logs, served_logs_checked", @@ -555,11 +591,12 @@ def test_log_retrieval_valid_trigger(self, create_task_instance): """ -def test_parse_timestamps(): - actual = [] - for timestamp, _, _ in _parse_timestamps_in_log_file(log_sample.splitlines()): - actual.append(timestamp) - assert actual == [ +def test_get_parsed_log_stream(tmp_path): + log_path: Path = tmp_path / "test_parsed_log_stream.log" + log_path.write_text(log_sample) + expected_line_num = 0 + expected_lines = log_sample.splitlines() + expected_timestamps = [ pendulum.parse("2022-11-16T00:05:54.278000-08:00"), pendulum.parse("2022-11-16T00:05:54.278000-08:00"), pendulum.parse("2022-11-16T00:05:54.278000-08:00"), @@ -581,6 +618,13 @@ def test_parse_timestamps(): pendulum.parse("2022-11-16T00:05:54.592000-08:00"), pendulum.parse("2022-11-16T00:05:54.604000-08:00"), ] + for record in _get_parsed_log_stream(log_path): + assert record == ( + expected_timestamps[expected_line_num], + expected_line_num, + expected_lines[expected_line_num], + ) + expected_line_num += 1 def test_interleave_interleaves(): From afa5b1bd7e1cf408cd5cf915c30bf11f79a728d9 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 21 Dec 2024 17:54:59 +0800 Subject: [PATCH 4/9] Fix line_num in parsed_log_stream, dedupe logic --- airflow/utils/log/file_task_handler.py | 29 +++++++++++++++----------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 69d5499e50334..e02e1f47ba4fb 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -53,6 +53,7 @@ CHUNK_SIZE = 1024 * 1024 * 5 # 5MB DEFAULT_SORT_DATETIME = pendulum.datetime(2000, 1, 1) +SORT_KEY_OFFSET = 10000000 HEAP_DUMP_SIZE = 500000 HALF_HEAP_DUMP_SIZE = HEAP_DUMP_SIZE // 2 @@ -126,6 +127,7 @@ def _parse_timestamp(line: str): def _get_parsed_log_stream(file_path: Path) -> _ParsedLogStreamType: with open(file_path) as f: + line_num = 0 # line number for each log line for file_chunk in iter(partial(f.read, CHUNK_SIZE), b""): if not file_chunk: break @@ -133,14 +135,16 @@ def _get_parsed_log_stream(file_path: Path) -> _ParsedLogStreamType: lines = file_chunk.splitlines() timestamp = None next_timestamp = None - for idx, line in enumerate(lines): + for line in lines: if line: with suppress(Exception): # next_timestamp unchanged if line can't be parsed next_timestamp = _parse_timestamp(line) if next_timestamp: timestamp = next_timestamp - yield timestamp, idx, line + + yield timestamp, line_num, line + line_num += 1 def _sort_key(timestamp: pendulum.DateTime | None, line_num: int) -> int: @@ -151,7 +155,7 @@ def _sort_key(timestamp: pendulum.DateTime | None, line_num: int) -> int: :param line_num: line number of the log line :return: a integer as sort key to avoid overhead of memory usage """ - return (timestamp or DEFAULT_SORT_DATETIME).int_timestamp * 10000000 + line_num + return int((timestamp or DEFAULT_SORT_DATETIME).timestamp() * 1000) * SORT_KEY_OFFSET + line_num def _add_log_from_parsed_log_streams_to_heap( @@ -167,9 +171,6 @@ def _add_log_from_parsed_log_streams_to_heap( :param parsed_log_streams: list of parsed log streams """ for log_stream in parsed_log_streams: - if log_stream is None: - parsed_log_streams.remove(log_stream) - continue record: _ParsedLogRecordType | None = next(log_stream, None) if record is None: parsed_log_streams.remove(log_stream) @@ -218,15 +219,19 @@ def _interleave_logs(*parsed_log_streams: _ParsedLogStreamType) -> Generator[str if len(heap) >= HEAP_DUMP_SIZE: for _ in range(HALF_HEAP_DUMP_SIZE): _, line = heapq.heappop(heap) - if line != last: # dedupe - yield line + if line == last: # dedupe + last = line + continue + yield line last = line - continue # yield remaining records - for _, line in heap: - if line != last: # dedupe - yield line + for _ in range(len(heap)): + _, line = heapq.heappop(heap) + if line == last: # dedupe + last = line + continue + yield line last = line # free memory del heap From 8617e5b5e02aea96e9550a876b3d5a07ddeb6b8a Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sat, 21 Dec 2024 18:05:28 +0800 Subject: [PATCH 5/9] Fix interleave releated test --- tests/utils/test_log_handlers.py | 56 ++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index b0f2d6f262e55..45a43fdba684c 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -21,6 +21,7 @@ import logging.config import os import re +from contextlib import suppress from http import HTTPStatus from importlib import reload from pathlib import Path @@ -47,6 +48,7 @@ _fetch_logs_from_service, _get_parsed_log_stream, _interleave_logs, + _parse_timestamp, ) from airflow.utils.log.logging_mixin import set_context from airflow.utils.net import get_hostname @@ -68,6 +70,20 @@ FILE_TASK_HANDLER = "task" +def _log_sample_to_parsed_log_stream(log_sample: str): + lines = log_sample.splitlines() + timestamp = None + next_timestamp = None + for idx, line in enumerate(lines): + if line: + with suppress(Exception): + # next_timestamp unchanged if line can't be parsed + next_timestamp = _parse_timestamp(line) + if next_timestamp: + timestamp = next_timestamp + yield timestamp, idx, line + + class TestFileTaskLogHandler: def clean_up(self): with create_session() as session: @@ -633,6 +649,7 @@ def test_interleave_interleaves(): "[2022-11-16T00:05:54.278-0800] {taskinstance.py:1258} INFO - Starting attempt 1 of 1", ] ) + log_sample1_stream = _log_sample_to_parsed_log_stream(log_sample1) log_sample2 = "\n".join( [ "[2022-11-16T00:05:54.295-0800] {taskinstance.py:1278} INFO - Executing on 2022-11-16 08:05:52.324532+00:00", @@ -643,6 +660,7 @@ def test_interleave_interleaves(): "[2022-11-16T00:05:54.309-0800] {standard_task_runner.py:83} INFO - Job 33648: Subtask wait", ] ) + log_sample2_stream = _log_sample_to_parsed_log_stream(log_sample2) log_sample3 = "\n".join( [ "[2022-11-16T00:05:54.457-0800] {task_command.py:376} INFO - Running on host daniels-mbp-2.lan", @@ -655,6 +673,7 @@ def test_interleave_interleaves(): "[2022-11-16T00:05:54.604-0800] {taskinstance.py:1360} INFO - Pausing task as DEFERRED. dag_id=simple_async_timedelta, task_id=wait, execution_date=20221116T080552, start_date=20221116T080554", ] ) + log_sample3_stream = _log_sample_to_parsed_log_stream(log_sample3) expected = "\n".join( [ "[2022-11-16T00:05:54.278-0800] {taskinstance.py:1258} INFO - Starting attempt 1 of 1", @@ -672,7 +691,9 @@ def test_interleave_interleaves(): "[2022-11-16T00:05:54.604-0800] {taskinstance.py:1360} INFO - Pausing task as DEFERRED. dag_id=simple_async_timedelta, task_id=wait, execution_date=20221116T080552, start_date=20221116T080554", ] ) - assert "\n".join(_interleave_logs(log_sample2, log_sample1, log_sample3)) == expected + interleave_log_stream = _interleave_logs(log_sample1_stream, log_sample2_stream, log_sample3_stream) + interleave_log_str = "\n".join(line for line in interleave_log_stream) + assert interleave_log_str == expected long_sample = """ @@ -749,22 +770,31 @@ def test_interleave_logs_correct_ordering(): [2023-01-17T12:47:11.883-0800] {triggerer_job.py:540} INFO - Trigger (ID 1) fired: TriggerEvent """ - assert sample_with_dupe == "\n".join(_interleave_logs(sample_with_dupe, "", sample_with_dupe)) + interleave_stream = _interleave_logs( + _log_sample_to_parsed_log_stream(sample_with_dupe), + _log_sample_to_parsed_log_stream(""), + _log_sample_to_parsed_log_stream(sample_with_dupe), + ) + interleave_str = "\n".join(line for line in interleave_stream) + assert interleave_str == sample_with_dupe def test_interleave_logs_correct_dedupe(): sample_without_dupe = """test, - test, - test, - test, - test, - test, - test, - test, - test, - test""" - - assert sample_without_dupe == "\n".join(_interleave_logs(",\n ".join(["test"] * 10))) +test, +test, +test, +test, +test, +test, +test, +test, +test""" + + interleave_stream = _interleave_logs( + _log_sample_to_parsed_log_stream(sample_without_dupe), + ) + assert "\n".join(line for line in interleave_stream) == "test,\ntest" def test_permissions_for_new_directories(tmp_path): From 70d6de95d8469aa66c9a81cac164c2688902382f Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sun, 22 Dec 2024 14:21:59 +0800 Subject: [PATCH 6/9] Fix format for _read and read_log_stream methods --- airflow/utils/log/file_task_handler.py | 6 ++++-- airflow/utils/log/log_reader.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index e02e1f47ba4fb..51363edc8805f 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -451,6 +451,7 @@ def _read( # is needed to get correct log path. worker_log_rel_path = self._render_filename(ti, try_number) messages_list: list[str] = [] + remote_logs: list[str] = [] # compact for running, will be remove in further commit remote_parsed_logs: list[_ParsedLogStreamType] = [] remote_logs_size = 0 local_parsed_logs: list[_ParsedLogStreamType] = [] @@ -516,7 +517,7 @@ def _read( out_stream = interleave_log_stream else: # first time reading log, add messages before interleaved log stream - out_stream = chain((msg for msg in messages), interleave_log_stream) + out_stream = chain([messages], interleave_log_stream) return out_stream, {"end_of_log": end_of_log, "log_pos": current_total_logs_size} @staticmethod @@ -566,6 +567,7 @@ def read( # So the log for a particular task try will only show up when # try number gets incremented in DB, i.e logs produced the time # after cli run and before try_number + 1 in DB will not be displayed. + try_numbers: list if try_number is None: next_try = task_instance.try_number + 1 try_numbers = list(range(1, next_try)) @@ -576,7 +578,7 @@ def read( try_numbers = [try_number] hosts = [""] * len(try_numbers) - logs: list = [] * len(try_numbers) + logs: list = [None] * len(try_numbers) metadata_array: list[dict] = [{}] * len(try_numbers) # subclasses implement _read and may not have log_type, which was added recently diff --git a/airflow/utils/log/log_reader.py b/airflow/utils/log/log_reader.py index 9acdccaa281a1..58356bb0d4655 100644 --- a/airflow/utils/log/log_reader.py +++ b/airflow/utils/log/log_reader.py @@ -94,7 +94,7 @@ def read_log_stream(self, ti: TaskInstance, try_number: int | None, metadata: di log_stream: Iterable[str] host, log_stream, metadata = self.read_log_chunks(ti, current_try_number, metadata) for log in log_stream: - yield f"\n{host or ''}\n{log}\n" + yield f"{host or ''}\n{log}\n" if "end_of_log" not in metadata or ( not metadata["end_of_log"] and ti.state not in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) From 453228dd2c94a0dab14027e4f1bdc7d4e6350a2d Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sun, 22 Dec 2024 16:09:44 +0800 Subject: [PATCH 7/9] Refactor _read_from_logs_server, read resp in chunk --- airflow/utils/log/file_task_handler.py | 44 +++++++++++++++++++++----- airflow/utils/serve_logs.py | 2 +- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 51363edc8805f..a01708ad676a3 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -111,6 +111,7 @@ def _fetch_logs_from_service(url, log_relative_path): url, timeout=timeout, headers={"Authorization": signer.generate_signed_token({"filename": log_relative_path})}, + stream=True, ) response.encoding = "utf-8" return response @@ -478,13 +479,17 @@ def _read( local_messages, local_parsed_logs, local_logs_size = self._read_from_local(worker_log_full_path) messages_list.extend(local_messages) if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_k8s_exec_pod: - served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path) + served_messages, served_parsed_logs, served_logs_size = self._read_from_logs_server( + ti, worker_log_rel_path + ) messages_list.extend(served_messages) elif ti.state not in State.unfinished and not (local_parsed_logs or remote_logs): # ordinarily we don't check served logs, with the assumption that users set up # remote logging or shared drive for logs for persistence, but that's not always true # so even if task is done, if no local logs or remote logs are found, we'll check the worker - served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path) + served_messages, served_parsed_logs, served_logs_size = self._read_from_logs_server( + ti, worker_log_rel_path + ) messages_list.extend(served_messages) # Log message source details are grouped: they are not relevant for most users and can @@ -680,9 +685,29 @@ def _read_from_local(worker_log_path: Path) -> _LogSourceType: return messages, parsed_log_streams, total_log_size - def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], list[str]]: - messages = [] - logs = [] + def _read_from_logs_server(self, ti, worker_log_rel_path) -> _LogSourceType: + total_log_size: int = 0 + messages: list[str] = [] + parsed_log_streams: list[_ParsedLogStreamType] = [] + + def _get_parsed_log_stream_from_response(response): + line_num = 0 + # read response in chunks instead of reading whole response text + for resp_chunk in response.iter_content(chunk_size=CHUNK_SIZE, decode_unicode=True): + if not resp_chunk: + break + lines = resp_chunk.splitlines() + timestamp = None + next_timestamp = None + for line in lines: + if line: + with suppress(Exception): + next_timestamp = _parse_timestamp(line) + if next_timestamp: + timestamp = next_timestamp + yield timestamp, line_num, line + line_num += 1 + try: log_type = LogType.TRIGGER if ti.triggerer_job else LogType.WORKER url, rel_path = self._get_log_retrieval_url(ti, worker_log_rel_path, log_type=log_type) @@ -698,9 +723,12 @@ def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], li ) # Check if the resource was properly fetched response.raise_for_status() - if response.text: + # get the total size of the logs + content_length = response.headers.get("Content-Length") + if content_length is not None: + total_log_size = int(content_length) messages.append(f"Found logs served from host {url}") - logs.append(response.text) + parsed_log_streams.append(_get_parsed_log_stream_from_response(response)) except Exception as e: from requests.exceptions import InvalidSchema @@ -709,7 +737,7 @@ def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], li else: messages.append(f"Could not read served logs: {e}") logger.exception("Could not read served logs") - return messages, logs + return messages, parsed_log_streams, total_log_size def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]: """ diff --git a/airflow/utils/serve_logs.py b/airflow/utils/serve_logs.py index 31ef86600da79..6b2ca12a9c574 100644 --- a/airflow/utils/serve_logs.py +++ b/airflow/utils/serve_logs.py @@ -133,7 +133,7 @@ def validate_pre_signed_url(): @flask_app.route("/log/") def serve_logs_view(filename): - return send_from_directory(log_directory, filename, mimetype="application/json", as_attachment=False) + return send_from_directory(log_directory, filename, mimetype="text/plain", as_attachment=False) return flask_app From 9cb659caf2833e8dc38c20230f19c7e91f3f8e37 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sun, 22 Dec 2024 18:03:32 +0800 Subject: [PATCH 8/9] Add compatible utility for old and new log source --- airflow/executors/base_executor.py | 13 ++++-- airflow/utils/log/file_task_handler.py | 58 ++++++++++++++++++++++---- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index a50dd801f9e68..36b5aed60b121 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -21,7 +21,7 @@ import logging import sys from collections import defaultdict, deque -from collections.abc import Sequence +from collections.abc import Generator, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional @@ -544,13 +544,20 @@ def execute_async( """ raise NotImplementedError() - def get_task_log(self, ti: TaskInstance, try_number: int) -> tuple[list[str], list[str]]: + def get_task_log( + self, ti: TaskInstance, try_number: int + ) -> ( + tuple[list[str], list[Generator[tuple[pendulum.DateTime | None, int, str], None, None]], int] + | tuple[list[str], list[str]] + ): """ Return the task logs. :param ti: A TaskInstance object :param try_number: current try_number to read log from - :return: tuple of logs and messages + :return: + - old interface: Tuple of messages and list of log lines. + - new interface: Tuple of messages, parsed log streams, total size of logs. """ return [], [] diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index a01708ad676a3..0c6cd828bdaf4 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -28,7 +28,7 @@ from functools import cached_property, partial from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from urllib.parse import urljoin import pendulum @@ -63,6 +63,10 @@ """Generator of parsed log streams, each yielding a tuple of timestamp, line number, and line.""" _LogSourceType = tuple[list[str], list[_ParsedLogStreamType], int] """Tuple of messages, parsed log streams, total size of logs.""" +_OldLogSourceType = tuple[list[str], list[str]] +"""Tuple of messages and list of log str, will be removed after all providers adapt to stream-based log reading.""" +_CompatibleLogSourceType = Union[_LogSourceType, _OldLogSourceType] +"""Compatible type hint for stream-based log reading and old log reading.""" class LogType(str, Enum): @@ -265,6 +269,27 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance: return val +def _get_compatible_parse_log_stream(remote_logs: list[str]) -> _ParsedLogStreamType: + """ + Compatible utility for new log reading(stream-based + k-way merge log) and old log reading(read whole log in memory + sorting). + + Turn old log reading into new stream-based log reading. + Will be removed after all providers adapt to stream-based log reading. + + :param remote_logs: list of log lines + :return: parsed log stream + """ + timestamp = None + next_timestamp = None + for line_num, line in enumerate(remote_logs): + with suppress(Exception): + # next_timestamp unchanged if line can't be parsed + next_timestamp = _parse_timestamp(line) + if next_timestamp: + timestamp = next_timestamp + yield timestamp, line_num, line + + class FileTaskHandler(logging.Handler): """ FileTaskHandler is a python log handler that handles and reads task instance logs. @@ -417,7 +442,7 @@ def _read_grouped_logs(self): return False @cached_property - def _executor_get_task_log(self) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]: + def _executor_get_task_log(self) -> Callable[[TaskInstance, int], _CompatibleLogSourceType]: """This cached property avoids loading executor repeatedly.""" executor = ExecutorLoader.get_default_executor() return executor.get_task_log @@ -452,7 +477,6 @@ def _read( # is needed to get correct log path. worker_log_rel_path = self._render_filename(ti, try_number) messages_list: list[str] = [] - remote_logs: list[str] = [] # compact for running, will be remove in further commit remote_parsed_logs: list[_ParsedLogStreamType] = [] remote_logs_size = 0 local_parsed_logs: list[_ParsedLogStreamType] = [] @@ -463,17 +487,35 @@ def _read( served_parsed_logs: list[_ParsedLogStreamType] = [] served_logs_size = 0 with suppress(NotImplementedError): - remote_messages, remote_logs = self._read_remote_logs(ti, try_number, metadata) + remote_log_result = self._read_remote_logs(ti, try_number, metadata) + if len(remote_log_result) == 2: + # old log reading + remote_messages, remote_logs = remote_log_result + remote_logs_size = sum(len(log) for log in remote_logs) + remote_parsed_logs = [_get_compatible_parse_log_stream(remote_logs)] + elif len(remote_log_result) == 3: + # new stream-based log reading + remote_messages, remote_parsed_logs, remote_logs_size = remote_log_result + else: + raise ValueError("Unexpected return value from _read_remote_logs") + # common logic for both old and new log reading messages_list.extend(remote_messages) has_k8s_exec_pod = False if ti.state == TaskInstanceState.RUNNING: response = self._executor_get_task_log(ti, try_number) - if response: + if response and len(response) == 2: executor_messages, executor_logs = response + executor_logs_size = sum(len(log) for log in executor_logs) + executor_parsed_logs = [_get_compatible_parse_log_stream(executor_logs)] + elif response and len(response) == 3: + executor_messages, executor_parsed_logs, executor_logs_size = response + else: + raise ValueError("Unexpected return value from executor.get_task_log") + # common logic for both old and new log reading if executor_messages: messages_list.extend(executor_messages) has_k8s_exec_pod = True - if not (remote_logs and ti.state not in State.unfinished): + if not (remote_parsed_logs and ti.state not in State.unfinished): # when finished, if we have remote logs, no need to check local worker_log_full_path = Path(self.local_base, worker_log_rel_path) local_messages, local_parsed_logs, local_logs_size = self._read_from_local(worker_log_full_path) @@ -483,7 +525,7 @@ def _read( ti, worker_log_rel_path ) messages_list.extend(served_messages) - elif ti.state not in State.unfinished and not (local_parsed_logs or remote_logs): + elif ti.state not in State.unfinished and not (local_parsed_logs or remote_parsed_logs): # ordinarily we don't check served logs, with the assumption that users set up # remote logging or shared drive for logs for persistence, but that's not always true # so even if task is done, if no local logs or remote logs are found, we'll check the worker @@ -739,7 +781,7 @@ def _get_parsed_log_stream_from_response(response): logger.exception("Could not read served logs") return messages, parsed_log_streams, total_log_size - def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], list[str]]: + def _read_remote_logs(self, ti, try_number, metadata=None) -> _CompatibleLogSourceType: """ Implement in subclasses to read from the remote service. From 268a7813b14aac7df782a14677c81debc1ca2114 Mon Sep 17 00:00:00 2001 From: jason810496 Date: Sun, 22 Dec 2024 18:05:30 +0800 Subject: [PATCH 9/9] Fix test_log_handlers - add check log_stream type utils - fix type checking for - test_file_task_handler_when_ti_value_is_invalid - test_file_task_handler - test_file_task_handler_running - test_file_task_handler_rotate_size_limit - test__read_when_local - test__read_served_logs_checked_when_done_and_no_local_or_remote_logs - also test compatible interface for test__read_served_logs_checked_when_done_and_no_local_or_remote_logs - which might call _read_remote_logs --- tests/utils/test_log_handlers.py | 146 ++++++++++++++++++++----------- 1 file changed, 94 insertions(+), 52 deletions(-) diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index 45a43fdba684c..2a1cf80764e10 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -24,7 +24,9 @@ from contextlib import suppress from http import HTTPStatus from importlib import reload +from itertools import chain from pathlib import Path +from types import GeneratorType from unittest import mock from unittest.mock import patch @@ -85,6 +87,9 @@ def _log_sample_to_parsed_log_stream(log_sample: str): class TestFileTaskLogHandler: + def _assert_is_log_stream_type(self, log_stream): + assert isinstance(log_stream, chain) or isinstance(log_stream, GeneratorType) + def clean_up(self): with create_session() as session: session.query(DagRun).delete() @@ -150,14 +155,17 @@ def task_callable(ti): assert hasattr(file_handler, "read") # Return value of read must be a tuple of list and list. # passing invalid `try_number` to read function - logs, metadatas = file_handler.read(ti, 0) - assert isinstance(logs, list) - assert isinstance(metadatas, list) - assert len(logs) == 1 - assert len(logs) == len(metadatas) - assert isinstance(metadatas[0], dict) - assert logs[0][0][0] == "default_host" - assert logs[0][0][1] == "Error fetching the logs. Try number 0 is invalid." + hosts, log_streams, metadata_array = file_handler.read(ti, 0) + assert isinstance(log_streams, list) + assert isinstance(metadata_array, list) + assert len(log_streams) == 1 + assert len(log_streams) == len(metadata_array) + self._assert_is_log_stream_type(log_streams[0]) + assert isinstance(metadata_array[0], dict) + assert hosts[0] == "default_host" + assert "Error fetching the logs. Try number 0 is invalid." in "\n".join( + line for line in log_streams[0] + ) # Remove the generated tmp log file. os.remove(log_filename) @@ -203,18 +211,21 @@ def task_callable(ti): file_handler.close() assert hasattr(file_handler, "read") - # Return value of read must be a tuple of list and list. - logs, metadatas = file_handler.read(ti) - assert isinstance(logs, list) - assert isinstance(metadatas, list) - assert len(logs) == 1 - assert len(logs) == len(metadatas) - assert isinstance(metadatas[0], dict) + # Return value of read must be a tuple of hosts, log_streams and metadata_array. + _, log_streams, metadata_array = file_handler.read(ti) + assert isinstance(log_streams, list) + assert isinstance(metadata_array, list) + assert len(log_streams) == 1 + assert len(log_streams) == len(metadata_array) + self._assert_is_log_stream_type(log_streams[0]) + assert isinstance(metadata_array[0], dict) target_re = r"\n\[[^\]]+\] {test_log_handlers.py:\d+} INFO - test\n" # We should expect our log line from the callable above to appear in # the logs we read back - assert re.search(target_re, logs[0][0][-1]), "Logs were " + str(logs) + log_str = "\n".join(line for line in log_streams[0]) + log_lines = log_str.splitlines() + assert re.search(target_re, log_lines[-2]), "Logs were " + log_str # Remove the generated tmp log file. os.remove(log_filename) @@ -259,15 +270,16 @@ def task_callable(ti): logger.info("Test") - # Return value of read must be a tuple of list and list. - logs, metadatas = file_handler.read(ti) - assert isinstance(logs, list) + # Return value of read must be a tuple of hosts, log_streams and metadata_array. + _, log_streams, metadata_array = file_handler.read(ti) + assert isinstance(log_streams, list) # Logs for running tasks should show up too. - assert isinstance(logs, list) - assert isinstance(metadatas, list) - assert len(logs) == 2 - assert len(logs) == len(metadatas) - assert isinstance(metadatas[0], dict) + assert isinstance(log_streams, list) + assert isinstance(metadata_array, list) + assert len(log_streams) == 2 + assert len(log_streams) == len(metadata_array) + self._assert_is_log_stream_type(log_streams[0]) + assert isinstance(metadata_array[0], dict) # Remove the generated tmp log file. os.remove(log_filename) @@ -336,25 +348,26 @@ def task_callable(ti): assert current_file_size < max_bytes_size # Return value of read must be a tuple of list and list. - logs, metadatas = file_handler.read(ti) + hosts, log_streams, metadata_array = file_handler.read(ti) # the log content should have the filename of both current log file and rotate log file. find_current_log = False find_rotate_log_1 = False - for log in logs: - if log_filename in str(log): + for log_stream in log_streams: + if log_filename in "\n".join(line for line in log_stream): find_current_log = True - if log_rotate_1_name in str(log): + if log_rotate_1_name in "\n".join(line for line in log_stream): find_rotate_log_1 = True assert find_current_log is True assert find_rotate_log_1 is True - assert isinstance(logs, list) + assert isinstance(hosts, list) # Logs for running tasks should show up too. - assert isinstance(logs, list) - assert isinstance(metadatas, list) - assert len(logs) == len(metadatas) - assert isinstance(metadatas[0], dict) + assert isinstance(log_streams, list) + assert isinstance(metadata_array, list) + assert len(log_streams) == len(metadata_array) + self._assert_is_log_stream_type(log_streams[0]) + assert isinstance(metadata_array[0], dict) # Remove the two generated tmp log files. os.remove(log_filename) @@ -369,7 +382,12 @@ def test__read_when_local(self, mock_read_local, create_task_instance): path = Path( "dag_id=dag_for_testing_local_log_read/run_id=scheduled__2016-01-01T00:00:00+00:00/task_id=task_for_testing_local_log_read/attempt=1.log" ) - mock_read_local.return_value = (["the messages"], ["the log"]) + # messages, parsed_log_streams, log_size + mock_read_local.return_value = ( + ["the messages"], + [_log_sample_to_parsed_log_stream("the log")], + len("the log"), + ) local_log_file_read = create_task_instance( dag_id="dag_for_testing_local_log_read", task_id="task_for_testing_local_log_read", @@ -377,11 +395,12 @@ def test__read_when_local(self, mock_read_local, create_task_instance): logical_date=DEFAULT_DATE, ) fth = FileTaskHandler("") - actual = fth._read(ti=local_log_file_read, try_number=1) + log_stream, metadata_array = fth._read(ti=local_log_file_read, try_number=1) mock_read_local.assert_called_with(path) - assert "*** the messages\n" in actual[0] - assert actual[0].endswith("the log") - assert actual[1] == {"end_of_log": True, "log_pos": 7} + log_stream_str = "".join(line for line in log_stream) + assert "*** the messages\n" in log_stream_str + assert log_stream_str.endswith("the log") + assert metadata_array == {"end_of_log": True, "log_pos": 7} def test__read_from_local(self, tmp_path): """Tests the behavior of method _read_from_local""" @@ -439,10 +458,12 @@ def test__read_from_local(self, tmp_path): @pytest.mark.parametrize( "remote_logs, local_logs, served_logs_checked", [ - (True, True, False), - (True, False, False), - (False, True, False), - (False, False, True), + ((True, True), True, False), + ((True, True), False, False), + ((True, False), True, False), + ((True, False), False, False), + ((False, None), True, False), + ((False, None), False, True), ], ) def test__read_served_logs_checked_when_done_and_no_local_or_remote_logs( @@ -467,24 +488,45 @@ def test__read_served_logs_checked_when_done_and_no_local_or_remote_logs( with conf_vars({("core", "executor"): executor_name}): reload(executor_loader) fth = FileTaskHandler("") - if remote_logs: + has_remote_logs, stream_based_remote_logs = remote_logs + if has_remote_logs: fth._read_remote_logs = mock.Mock() - fth._read_remote_logs.return_value = ["found remote logs"], ["remote\nlog\ncontent"] + if stream_based_remote_logs: + # testing for providers already migrated to stream based logs + # new implementation returns: messages, parsed_log_streams, log_size + fth._read_remote_logs.return_value = ( + ["found remote logs"], + [_log_sample_to_parsed_log_stream("remote\nlog\ncontent")], + 16, + ) + else: + # old implementation returns: messages, log_lines + fth._read_remote_logs.return_value = ["found remote logs"], ["remote\nlog\ncontent"] if local_logs: fth._read_from_local = mock.Mock() - fth._read_from_local.return_value = ["found local logs"], ["local\nlog\ncontent"] + fth._read_from_local.return_value = ( + ["found local logs"], + [_log_sample_to_parsed_log_stream("local\nlog\ncontent")], + 16, + ) fth._read_from_logs_server = mock.Mock() - fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"] - actual = fth._read(ti=ti, try_number=1) + fth._read_from_logs_server.return_value = ( + ["this message"], + [_log_sample_to_parsed_log_stream("this\nlog\ncontent")], + 16, + ) + + log_stream, metadata_array = fth._read(ti=ti, try_number=1) + log_stream_str = "\n".join(line for line in log_stream) if served_logs_checked: fth._read_from_logs_server.assert_called_once() - assert "*** this message\n" in actual[0] - assert actual[0].endswith("this\nlog\ncontent") - assert actual[1] == {"end_of_log": True, "log_pos": 16} + assert "*** this message\n" in log_stream_str + assert log_stream_str.endswith("this\nlog\ncontent") + assert metadata_array == {"end_of_log": True, "log_pos": 16} else: fth._read_from_logs_server.assert_not_called() - assert actual[0] - assert actual[1] + assert log_stream_str + assert metadata_array def test_add_triggerer_suffix(self): sample = "any/path/to/thing.txt"