Skip to content

Commit

Permalink
chore: complete typing for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertodonato committed Jan 5, 2025
1 parent ffc00b2 commit 11d8dfb
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 33 deletions.
82 changes: 50 additions & 32 deletions tests/loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from unittest.mock import ANY

from prometheus_aioexporter import MetricsRegistry
from prometheus_client.metrics import MetricWrapperBase
import pytest
from pytest_mock import MockerFixture
from pytest_structlog import StructuredLogCapture
from sqlalchemy.sql.elements import TextClause
import yaml

from query_exporter import loop
from query_exporter.config import load_config
from query_exporter.db import DataBase, DataBaseConfig
from query_exporter.loop import MetricsLastSeen, QueryLoop

from .conftest import QueryTracker

Expand Down Expand Up @@ -42,7 +44,7 @@ def registry() -> Iterator[MetricsRegistry]:
yield MetricsRegistry()


MakeQueryLoop = Callable[[], loop.QueryLoop]
MakeQueryLoop = Callable[[], QueryLoop]


@pytest.fixture
Expand All @@ -51,12 +53,12 @@ async def make_query_loop(
) -> AsyncIterator[MakeQueryLoop]:
query_loops = []

def make_loop() -> loop.QueryLoop:
def make_loop() -> QueryLoop:
config_file = tmp_path / "config.yaml"
config_file.write_text(yaml.dump(config_data), "utf-8")
config = load_config([config_file])
registry.create_metrics(config.metrics.values())
query_loop = loop.QueryLoop(config, registry)
query_loop = QueryLoop(config, registry)
query_loops.append(query_loop)
return query_loop

Expand All @@ -70,30 +72,33 @@ def make_loop() -> loop.QueryLoop:
@pytest.fixture
async def query_loop(
make_query_loop: MakeQueryLoop,
) -> AsyncIterator[loop.QueryLoop]:
) -> AsyncIterator[QueryLoop]:
yield make_query_loop()


MetricValues = list[int | float] | dict[tuple[str], list[int | float]]
MetricValues = list[int | float] | dict[tuple[str, ...], int | float]


def metric_values(metric, by_labels: tuple[str] = ()) -> MetricValues:
def metric_values(
metric: MetricWrapperBase, by_labels: tuple[str, ...] = ()
) -> MetricValues:
"""Return values for the metric."""
if metric._type == "gauge":
suffix = ""
elif metric._type == "counter":
suffix = "_total"

values = defaultdict(list)
values_by_label: dict[tuple[str, ...], int | float] = {}
values_by_suffix: dict[str, list[int | float]] = defaultdict(list)
for sample_suffix, labels, value, *_ in metric._samples():
if sample_suffix == suffix:
if by_labels:
label_values = tuple(labels[label] for label in by_labels)
values[label_values] = value
values_by_label[label_values] = value
else:
values[sample_suffix].append(value)
values_by_suffix[sample_suffix].append(value)

return values if by_labels else values[suffix]
return values_by_label if by_labels else values_by_suffix[suffix]


async def run_queries(db_file: Path, *queries: str) -> None:
Expand All @@ -105,7 +110,7 @@ async def run_queries(db_file: Path, *queries: str) -> None:

class TestMetricsLastSeen:
def test_update(self) -> None:
last_seen = loop.MetricsLastSeen({"m1": 50, "m2": 100})
last_seen = MetricsLastSeen({"m1": 50, "m2": 100})
last_seen.update("m1", {"l1": "v1", "l2": "v2"}, 100)
last_seen.update("m1", {"l1": "v3", "l2": "v4"}, 200)
last_seen.update("other", {"l3": "v100"}, 300)
Expand All @@ -117,12 +122,12 @@ def test_update(self) -> None:
}

def test_update_label_values_sorted_by_name(self) -> None:
last_seen = loop.MetricsLastSeen({"m1": 50})
last_seen = MetricsLastSeen({"m1": 50})
last_seen.update("m1", {"l2": "v2", "l1": "v1"}, 100)
assert last_seen._last_seen == {"m1": {("v1", "v2"): 100}}

def test_expire_series_not_expired(self) -> None:
last_seen = loop.MetricsLastSeen({"m1": 50})
last_seen = MetricsLastSeen({"m1": 50})
last_seen.update("m1", {"l1": "v1", "l2": "v2"}, 10)
last_seen.update("m1", {"l1": "v3", "l2": "v4"}, 20)
assert last_seen.expire_series(30) == {}
Expand All @@ -134,7 +139,7 @@ def test_expire_series_not_expired(self) -> None:
}

def test_expire_series(self) -> None:
last_seen = loop.MetricsLastSeen({"m1": 50, "m2": 100})
last_seen = MetricsLastSeen({"m1": 50, "m2": 100})
last_seen.update("m1", {"l1": "v1", "l2": "v2"}, 10)
last_seen.update("m1", {"l1": "v3", "l2": "v4"}, 100)
last_seen.update("m2", {"l3": "v100"}, 100)
Expand All @@ -145,7 +150,7 @@ def test_expire_series(self) -> None:
}

def test_expire_no_labels(self) -> None:
last_seen = loop.MetricsLastSeen({"m1": 50})
last_seen = MetricsLastSeen({"m1": 50})
last_seen.update("m1", {}, 10)
expired = last_seen.expire_series(120)
assert expired == {"m1": [()]}
Expand All @@ -154,21 +159,26 @@ def test_expire_no_labels(self) -> None:

class TestQueryLoop:
async def test_start(
self, query_tracker: QueryTracker, query_loop
self,
query_tracker: QueryTracker,
query_loop: QueryLoop,
) -> None:
await query_loop.start()
timed_call = query_loop._timed_calls["q"]
assert timed_call.running
await query_tracker.wait_results()

async def test_stop(self, query_loop) -> None:
async def test_stop(self, query_loop: QueryLoop) -> None:
await query_loop.start()
timed_call = query_loop._timed_calls["q"]
await query_loop.stop()
assert not timed_call.running

async def test_run_query(
self, query_tracker: QueryTracker, query_loop: loop.QueryLoop, registry
self,
query_tracker: QueryTracker,
query_loop: QueryLoop,
registry: MetricsRegistry,
) -> None:
await query_loop.start()
await query_tracker.wait_results()
Expand All @@ -192,15 +202,18 @@ async def test_run_scheduled_query(
) -> None:
event_loop = asyncio.get_running_loop()

def croniter(*args: t.Any) -> float:
def croniter(*args: t.Any) -> Iterator[float]:
while True:
# sync croniter time with the loop one
yield event_loop.time() + 60

mock_croniter = mocker.patch.object(loop, "croniter")
mock_croniter = mocker.patch("query_exporter.loop.croniter")
mock_croniter.side_effect = croniter
# ensure that both clocks advance in sync
mocker.patch.object(loop.time, "time", lambda: event_loop.time()) # type: ignore
mocker.patch(
"query_exporter.loop.time.time",
lambda: event_loop.time(),
)

del config_data["queries"]["q"]["interval"]
config_data["queries"]["q"]["schedule"] = "*/2 * * * *"
Expand Down Expand Up @@ -303,7 +316,9 @@ async def test_run_query_metrics_with_database_labels(
}

async def test_update_metric_decimal_value(
self, registry: MetricsRegistry, make_query_loop
self,
registry: MetricsRegistry,
make_query_loop: MakeQueryLoop,
) -> None:
db = DataBase(DataBaseConfig(name="db", dsn="sqlite://"))
query_loop = make_query_loop()
Expand All @@ -317,7 +332,7 @@ async def test_run_query_log(
self,
log: StructuredLogCapture,
query_tracker: QueryTracker,
query_loop: loop.QueryLoop,
query_loop: QueryLoop,
) -> None:
await query_loop.start()
await query_tracker.wait_queries()
Expand Down Expand Up @@ -371,7 +386,7 @@ async def test_run_query_increase_db_error_count(
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
registry: MetricsRegistry,
) -> None:
config_data["databases"]["db"]["dsn"] = "sqlite:////invalid"
query_loop = make_query_loop()
Expand All @@ -382,11 +397,11 @@ async def test_run_query_increase_db_error_count(

async def test_run_query_increase_database_error_count(
self,
mocker,
mocker: MockerFixture,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
registry: MetricsRegistry,
) -> None:
query_loop = make_query_loop()
db = query_loop._databases["db"]
Expand All @@ -402,7 +417,7 @@ async def test_run_query_increase_query_error_count(
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
registry: MetricsRegistry,
) -> None:
config_data["queries"]["q"]["sql"] = "SELECT 100.0 AS a, 200.0 AS b"
query_loop = make_query_loop()
Expand All @@ -415,21 +430,24 @@ async def test_run_query_increase_query_error_count(

async def test_run_query_increase_timeout_count(
self,
mocker: MockerFixture,
query_tracker: QueryTracker,
config_data: dict[str, t.Any],
make_query_loop: MakeQueryLoop,
registry,
registry: MetricsRegistry,
) -> None:
config_data["queries"]["q"]["timeout"] = 0.1
query_loop = make_query_loop()
await query_loop.start()
db = query_loop._databases["db"]
await db.connect()

async def execute(sql, parameters):
async def execute(
sql: TextClause, parameters: dict[str, t.Any] | None
) -> None:
await asyncio.sleep(1) # longer than timeout

db._conn.execute = execute
mocker.patch.object(db._conn, "execute", execute)

await query_tracker.wait_failures()
queries_metric = registry.get_metric("queries")
Expand All @@ -441,7 +459,7 @@ async def test_run_query_at_interval(
self,
advance_time: AdvanceTime,
query_tracker: QueryTracker,
query_loop: loop.QueryLoop,
query_loop: QueryLoop,
) -> None:
await query_loop.start()
await advance_time(0) # kick the first run
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ deps =
.[testing]
mypy
commands =
mypy query_exporter {posargs}
mypy {[base]lint_files} {posargs}

[testenv:coverage]
deps =
Expand Down

0 comments on commit 11d8dfb

Please sign in to comment.