Skip to content

Commit

Permalink
Merge pull request #9 from KrystalDelusion/krys/destination_level
Browse files Browse the repository at this point in the history
Add destination aware log level filtering
  • Loading branch information
jix authored Jul 27, 2024
2 parents 148e3b0 + 4368c24 commit 7653205
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 8 deletions.
37 changes: 29 additions & 8 deletions src/yosys_mau/task_loop/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
current_task_or_none,
root_task,
)
from .context import task_context
from .context import (
TaskContextDict,
task_context,
)

Level = Literal["debug", "info", "warning", "error"]

Expand Down Expand Up @@ -120,14 +123,12 @@ class LogContext:

level: Level = "info"
"""The minimum log level to display/log.
Can be overridden for named destinations with `destination_levels`.
This does not stop `LogEvent` of smaller levels to be emitted. It is only used to filter which
messages to actually print/log. Hence, it does not affect any user installed `LogEvent`
handlers.
When logging to multiple destinations, currently there is no way to specify this per
destination.
"""
handlers."""

log_format: Callable[[LogEvent], str] = default_formatter
"""The formatter used to format log messages.
Expand All @@ -145,6 +146,14 @@ class LogContext:
Like `log_format` this is looked up by the log writing task, not the emitting task.
"""

destination_levels: TaskContextDict[str, Level] = TaskContextDict()
"""The minimum log level to display/log for named destinations.
Like `log_format` this is looked up by the log writing task, not the emitting task. If the
current destination has no key:value pair in this dictionary, the `level` will be looked up by
the task which emit the log.
"""


def log(*args: Any, level: Level = "info", cls: type[LogEvent] = LogEvent) -> LogEvent:
"""Produce log output.
Expand Down Expand Up @@ -297,7 +306,10 @@ def log_exception(exception: BaseException, raise_error: bool = True) -> LoggedE


def start_logging(
file: IO[Any] | None = None, err: bool = False, color: bool | None = None
file: IO[Any] | None = None,
err: bool = False,
color: bool | None = None,
destination_label: str | None = None,
) -> None:
"""Start logging all log events reaching the current task.
Expand All @@ -310,6 +322,8 @@ def start_logging(
:param color: Whether to use colors. Defaults to ``True`` for terminals and ``False`` otherwise.
When the ``NO_COLOR`` environment variable is set, this will be ignored and no colors will
be used.
:param destination_label: Used to look up destination specific log level filtering.
Used with `LogContext.destination_levels`.
"""
if _no_color:
color = False
Expand All @@ -318,7 +332,14 @@ def log_handler(event: LogEvent):
if file and file.closed:
remove_log_handler()
return
source_level = _level_order[event.source[LogContext].level]
emitter_default = event.source[LogContext].level
if destination_label:
destination_level = LogContext.destination_levels.get(
destination_label, emitter_default
)
else:
destination_level = emitter_default
source_level = _level_order[destination_level]
event_level = _level_order[event.level]
if event_level < source_level:
return
Expand Down
152 changes: 152 additions & 0 deletions tests/task_loop/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,158 @@ def main():
]


@pytest.mark.parametrize(
"label,expected",
[
("default", [2, 3, 4, 6]),
("debug", [1, 2, 3, 4, 5, 6]),
("info", [2, 4, 5, 6]),
("warning", [5, 6]),
("error", [6]),
("varied", [1, 2, 4, 6]),
],
)
def test_log_destinations(label: str, expected: list[str]):
log_output = io.StringIO()

def main():
tl.LogContext.time_format = fixed_time
tl.logging.start_logging(file=log_output, destination_label=label)

# tl.LogContext.level = "info" # implied
tl.LogContext.destination_levels["info"] = "info"
tl.LogContext.destination_levels["debug"] = "debug"
tl.LogContext.destination_levels["warning"] = "warning"
tl.LogContext.destination_levels["error"] = "error"

tl.LogContext.destination_levels["varied"] = "debug"
tl.log_debug("line 1")
tl.log("line 2")

tl.LogContext.level = "debug"
tl.LogContext.destination_levels["varied"] = "warning"
tl.log_debug("line 3")

del tl.LogContext.destination_levels["varied"]
tl.LogContext.destination_levels[""] = "warning"
tl.log("line 4")

tl.LogContext.level = "error"
tl.log_warning("line 5")
tl.log_error("line 6", raise_error=False)

tl.run_task_loop(main)

trimmed_output = [int(x[-1]) for x in log_output.getvalue().splitlines()]
assert trimmed_output == expected


@pytest.mark.parametrize("task", ["root", "task1", "task2"])
@pytest.mark.parametrize("label", ["debug", "info", "warning", "mixed1", "mixed2"])
def test_nested_destinations(task: str, label: str):
log_output = io.StringIO()

async def main():
tl.LogContext.time_format = fixed_time
tl.LogContext.scope = "?root?"
if task == "root":
tl.logging.start_logging(file=log_output, destination_label=label)
tl.LogContext.destination_levels["mixed1"] = "warning"

tl.LogContext.destination_levels["debug"] = "debug"
tl.LogContext.destination_levels["info"] = "info"
tl.LogContext.destination_levels["warning"] = "warning"
tl.LogContext.destination_levels["error"] = "error"
tl.LogContext.destination_levels["source"] = "warning"

tl.log("line 0")
sync_event = asyncio.Event()

async def run_task1():
tl.LogContext.scope = "?root?task1?"
if task == "task1":
tl.logging.start_logging(file=log_output, destination_label=label)
tl.LogContext.destination_levels["mixed1"] = "info"

tl.LogContext.destination_levels["mixed2"] = "debug" if task == "root" else "info"

task2 = tl.Task(on_run=run_task2)
tl.log("line 2")

await task2.started

tl.log_debug("line 4")

sync_event.set()

await task2.finished

tl.log("line 6")

async def run_task2():
tl.LogContext.scope = "?root?task1?task2?"
if task == "task2":
tl.logging.start_logging(file=log_output, destination_label=label)
tl.LogContext.destination_levels["mixed1"] = "debug"

tl.LogContext.destination_levels["mixed2"] = "debug" if task == "task1" else "error"

tl.log_debug("line 3")

await sync_event.wait()
tl.log_warning("line 5")

task1 = tl.Task(on_run=run_task1)

tl.log("line 1")

await task1.finished

tl.log("line 7")

tl.run_task_loop(main)

reference_list = [
"12:34:56 ?root?: line 0",
"12:34:56 ?root?: line 1",
"12:34:56 ?root?task1?: line 2",
"12:34:56 ?root?task1?task2?: DEBUG: line 3",
"12:34:56 ?root?task1?: DEBUG: line 4",
"12:34:56 ?root?task1?task2?: WARNING: line 5",
"12:34:56 ?root?task1?: line 6",
"12:34:56 ?root?: line 7",
]

label_map: dict[str, list[int]] = {
"debug": [0, 1, 2, 3, 4, 5, 6, 7],
"info": [0, 1, 2, 5, 6, 7],
"warning": [5],
}

if label in label_map:
filtered_list = [x for i, x in enumerate(reference_list) if i in label_map[label]]
expected = [x for x in filtered_list if task in x.split("?")]
else:
if label == "mixed1":
task_map: dict[str, list[int]] = {
"root": [5],
"task1": [2, 5, 6],
"task2": [3, 5],
}
elif label == "mixed2":
task_map: dict[str, list[int]] = {
"root": [0, 1, 2, 5, 6, 7],
"task1": [2, 5, 6],
"task2": [],
}
else:
assert False, f"unknown label {label}"
expected = [x for i, x in enumerate(reference_list) if i in task_map[task]]

print(log_output.getvalue())
assert log_output.getvalue().splitlines() == expected


def test_exception_logging():
log_output = io.StringIO()

Expand Down

0 comments on commit 7653205

Please sign in to comment.