Refactor _interleave_logs and _read_from_local
jason810496 authored and potiuk committed Dec 21, 2024
1 parent 70e77ef commit 2510a63
Showing 1 changed file with 171 additions and 48 deletions.
219 changes: 171 additions & 48 deletions airflow/utils/log/file_task_handler.py
@@ -19,14 +19,16 @@

from __future__ import annotations

import heapq
import logging
import os
from collections.abc import Iterable
from collections.abc import Generator, Iterable
from contextlib import suppress
from enum import Enum
from functools import cached_property
from functools import cached_property, partial
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional
from urllib.parse import urljoin

import pendulum
@@ -49,6 +51,18 @@

logger = logging.getLogger(__name__)

CHUNK_SIZE = 1024 * 1024 * 5 # 5MB
DEFAULT_SORT_DATETIME = pendulum.datetime(2000, 1, 1)
HEAP_DUMP_SIZE = 500000
HALF_HEAP_DUMP_SIZE = HEAP_DUMP_SIZE // 2

_ParsedLogRecordType = tuple[Optional[pendulum.DateTime], int, str]
"""Tuple of timestamp, line number, and line."""
_ParsedLogStreamType = Generator[_ParsedLogRecordType, None, None]
"""Generator of parsed log streams, each yielding a tuple of timestamp, line number, and line."""
_LogSourceType = tuple[list[str], list[_ParsedLogStreamType], int]
"""Tuple of messages, parsed log streams, total size of logs."""


class LogType(str, Enum):
"""
@@ -110,30 +124,113 @@ def _parse_timestamp(line: str):
return pendulum.parse(timestamp_str.strip("[]"))


def _parse_timestamps_in_log_file(lines: Iterable[str]):
timestamp = None
next_timestamp = None
for idx, line in enumerate(lines):
if line:
with suppress(Exception):
# next_timestamp unchanged if line can't be parsed
next_timestamp = _parse_timestamp(line)
if next_timestamp:
timestamp = next_timestamp
yield timestamp, idx, line
def _get_parsed_log_stream(file_path: Path) -> _ParsedLogStreamType:
with open(file_path) as f:
for file_chunk in iter(partial(f.read, CHUNK_SIZE), ""):  # file is opened in text mode, so the sentinel is an empty str
if not file_chunk:
break
# parse log lines
lines = file_chunk.splitlines()
timestamp = None
next_timestamp = None
for idx, line in enumerate(lines):
if line:
with suppress(Exception):
# next_timestamp unchanged if line can't be parsed
next_timestamp = _parse_timestamp(line)
if next_timestamp:
timestamp = next_timestamp
yield timestamp, idx, line
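
# Minimal usage sketch (an assumption, not part of this commit): ``_count_parsed_records`` and
# its ``path`` argument are hypothetical; it only shows that the generator is consumed lazily,
# one CHUNK_SIZE slice of the file at a time.
def _count_parsed_records(path: Path) -> int:
    count = 0
    for _timestamp, _line_num, _line in _get_parsed_log_stream(path):
        # each record is (timestamp or None, line index within its chunk, raw line)
        count += 1
    return count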


def _sort_key(timestamp: pendulum.DateTime | None, line_num: int) -> int:
"""
Generate a sort key for log record, to be used in K-way merge.
:param timestamp: timestamp of the log line
:param line_num: line number of the log line
:return: a integer as sort key to avoid overhead of memory usage
"""
return (timestamp or DEFAULT_SORT_DATETIME).int_timestamp * 10000000 + line_num
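
# Worked example (illustrative): a record stamped 2024-12-21T00:00:00+00:00 (epoch 1734739200)
# on chunk line 3 gets key 1734739200 * 10000000 + 3 = 17347392000000003, so records within the
# same second are ordered by line number, and a record without a parseable timestamp falls back
# to DEFAULT_SORT_DATETIME.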


def _add_log_from_parsed_log_streams_to_heap(
heap: list[tuple[int, str]],
parsed_log_streams: list[_ParsedLogStreamType],
) -> None:
"""
Add one log record from each parsed log stream to the heap.
Remove any exhausted log stream from the list while iterating.

:param heap: heap that buffers the log records
:param parsed_log_streams: list of parsed log streams
"""
# iterate over a copy so that exhausted streams can be removed from the original list safely
for log_stream in list(parsed_log_streams):
if log_stream is None:
parsed_log_streams.remove(log_stream)
continue
record: _ParsedLogRecordType | None = next(log_stream, None)
if record is None:
parsed_log_streams.remove(log_stream)
continue
timestamp, line_num, line = record
# take int as sort key to avoid overhead of memory usage
heapq.heappush(heap, (_sort_key(timestamp, line_num), line))
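
# Illustrative heap state (an assumption, not part of this commit): after one call, the heap
# holds at most one (sort_key, line) pair per stream, e.g.
# [(17347392000000000, "INFO - step 1"), (17347392000000001, "INFO - step 2")],
# and heapq keeps the smallest sort key at index 0.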


def _interleave_logs(*parsed_log_streams: _ParsedLogStreamType) -> Generator[str, None, None]:
"""
Merge parsed log streams using K-way merge.
By yielding HALF_CHUNK_SIZE records when heap size exceeds CHUNK_SIZE, we can reduce the chance of messing up the global order.
Since there are multiple log streams, we can't guarantee that the records are in global order.
e.g.
log_stream1: ----------
log_stream2: ----
log_stream3: --------
The first record of log_stream3 is later than the fourth record of log_stream1 !
:param parsed_log_streams: parsed log streams
:return: interleaved log stream
"""
# don't need to push whole tuple into heap, which increases too much overhead
# push only sort_key and line into heap
heap: list[tuple[int, str]] = []
# to allow removing empty streams while iterating
log_streams: list[_ParsedLogStreamType] = list(parsed_log_streams)

# add first record from each log stream to heap
_add_log_from_parsed_log_streams_to_heap(heap, log_streams)

def _interleave_logs(*logs):
records = []
for log in logs:
records.extend(_parse_timestamps_in_log_file(log.splitlines()))
# keep adding records from logs until all logs are empty
last = None
for timestamp, _, line in sorted(
records, key=lambda x: (x[0], x[1]) if x[0] else (pendulum.datetime(2000, 1, 1), x[1])
):
if line != last or not timestamp: # dedupe
while heap:
if not log_streams:
break

_add_log_from_parsed_log_streams_to_heap(heap, log_streams)

# yield HALF_HEAP_DUMP_SIZE records when heap size exceeds HEAP_DUMP_SIZE
if len(heap) >= HEAP_DUMP_SIZE:
for _ in range(HALF_HEAP_DUMP_SIZE):
_, line = heapq.heappop(heap)
if line != last: # dedupe
yield line
last = line
continue

# yield remaining records in sorted order
while heap:
_, line = heapq.heappop(heap)
if line != last:  # dedupe
yield line
last = line
# free memory
del heap
del log_streams
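
# Minimal sketch (an assumption, not part of this commit): ``_demo_interleave`` and its literal
# records are hypothetical; it shows two already-parsed streams being merged in timestamp order,
# with adjacent duplicate lines dropped.
def _demo_interleave() -> list[str]:
    stream_a = (record for record in [(pendulum.datetime(2000, 1, 1, 0, 0, 1), 0, "a")])
    stream_b = (record for record in [(pendulum.datetime(2000, 1, 1, 0, 0, 0), 0, "b")])
    # "b" carries the earlier timestamp, so it is yielded first: ["b", "a"]
    return list(_interleave_logs(stream_a, stream_b))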


def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
@@ -349,11 +446,15 @@ def _read(
# is needed to get correct log path.
worker_log_rel_path = self._render_filename(ti, try_number)
messages_list: list[str] = []
remote_logs: list[str] = []
local_logs: list[str] = []
remote_parsed_logs: list[_ParsedLogStreamType] = []
remote_logs_size = 0
local_parsed_logs: list[_ParsedLogStreamType] = []
local_logs_size = 0
executor_messages: list[str] = []
executor_logs: list[str] = []
served_logs: list[str] = []
executor_parsed_logs: list[_ParsedLogStreamType] = []
executor_logs_size = 0
served_parsed_logs: list[_ParsedLogStreamType] = []
served_logs_size = 0
with suppress(NotImplementedError):
remote_messages, remote_logs = self._read_remote_logs(ti, try_number, metadata)
messages_list.extend(remote_messages)
@@ -368,27 +469,18 @@
if not (remote_logs and ti.state not in State.unfinished):
# when finished, if we have remote logs, no need to check local
worker_log_full_path = Path(self.local_base, worker_log_rel_path)
local_messages, local_logs = self._read_from_local(worker_log_full_path)
local_messages, local_parsed_logs, local_logs_size = self._read_from_local(worker_log_full_path)
messages_list.extend(local_messages)
if ti.state in (TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED) and not has_k8s_exec_pod:
served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path)
messages_list.extend(served_messages)
elif ti.state not in State.unfinished and not (local_logs or remote_logs):
elif ti.state not in State.unfinished and not (local_parsed_logs or remote_logs):
# ordinarily we don't check served logs, with the assumption that users set up
# remote logging or shared drive for logs for persistence, but that's not always true
# so even if task is done, if no local logs or remote logs are found, we'll check the worker
served_messages, served_logs = self._read_from_logs_server(ti, worker_log_rel_path)
messages_list.extend(served_messages)

logs = "\n".join(
_interleave_logs(
*local_logs,
*remote_logs,
*(executor_logs or []),
*served_logs,
)
)
log_pos = len(logs)
# Log message source details are grouped: they are not relevant for most users and can
# distract them from finding the root cause of their errors
messages = " INFO - ::group::Log message source details\n"
@@ -398,11 +490,29 @@
TaskInstanceState.RUNNING,
TaskInstanceState.DEFERRED,
)

current_total_logs_size = local_logs_size + remote_logs_size + executor_logs_size + served_logs_size
interleave_log_stream = _interleave_logs(
*local_parsed_logs,
*remote_parsed_logs,
*(executor_parsed_logs or []),
*served_parsed_logs,
)

# skip log stream until the last position
if metadata and "log_pos" in metadata:
previous_chars = metadata["log_pos"]
logs = logs[previous_chars:]  # Cut off previously passed log text as new tail
out_message = logs if "log_pos" in (metadata or {}) else messages + logs
return out_message, {"end_of_log": end_of_log, "log_pos": log_pos}
offset = metadata["log_pos"]
for _ in range(offset):
next(interleave_log_stream, None)

out_stream: Iterable[str]
if "log_pos" in (metadata or {}):
# don't need to add messages, since we're in the middle of the log
out_stream = interleave_log_stream
else:
# first time reading log, add messages before interleaved log stream
out_stream = chain((msg for msg in messages), interleave_log_stream)
return out_stream, {"end_of_log": end_of_log, "log_pos": current_total_logs_size}

@staticmethod
def _get_pod_namespace(ti: TaskInstance):
@@ -537,14 +647,27 @@ def _init_file(self, ti, *, identifier: str | None = None):
return full_path

@staticmethod
def _read_from_local(worker_log_path: Path) -> tuple[list[str], list[str]]:
messages = []
def _read_from_local(worker_log_path: Path) -> _LogSourceType:
"""
Read logs from the local worker log file.

:param worker_log_path: Path to the worker log file
:return: Tuple of messages, parsed log streams, and total size of the logs
"""
total_log_size: int = 0
messages: list[str] = []
parsed_log_streams: list[_ParsedLogStreamType] = []
paths = sorted(worker_log_path.parent.glob(worker_log_path.name + "*"))
if paths:
messages.append("Found local files:")
messages.extend(f" * {x}" for x in paths)
logs = [file.read_text() for file in paths]
return messages, logs
if not paths:
return messages, parsed_log_streams, total_log_size

messages.append("Found local files:")
for path in paths:
total_log_size += path.stat().st_size
messages.append(f" * {path}")
parsed_log_streams.append(_get_parsed_log_stream(path))

return messages, parsed_log_streams, total_log_size
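
# Illustrative usage (an assumption, not part of this commit; the path is hypothetical):
#   messages, streams, size = FileTaskHandler._read_from_local(Path("/logs/dag/task/attempt=1.log"))
# ``messages`` lists the files that were found, ``streams`` holds one lazy parser per file, and
# ``size`` is the combined byte size that feeds into the ``log_pos`` total returned by ``_read``.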

def _read_from_logs_server(self, ti, worker_log_rel_path) -> tuple[list[str], list[str]]:
messages = []
