Skip to content

Commit

Permalink
enable strict mode for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
albertodonato committed Oct 21, 2024
1 parent e4cb28a commit ee638b3
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 75 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,4 @@ omit = [
ignore_missing_imports = true
install_types = true
non_interactive = true
warn_return_any = true
warn_unused_configs = true
strict = true
36 changes: 7 additions & 29 deletions query_exporter/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

from collections import defaultdict
from collections.abc import Mapping
from dataclasses import (
dataclass,
field,
)
from dataclasses import dataclass
from functools import reduce
from importlib import resources
import itertools
Expand All @@ -27,9 +24,8 @@
import yaml

from .db import (
create_db_engine,
DATABASE_LABEL,
DataBaseError,
DataBaseConfig,
InvalidQueryParameters,
InvalidQuerySchedule,
Query,
Expand Down Expand Up @@ -85,24 +81,6 @@ class ConfigError(Exception):
"""Configuration is invalid."""


@dataclass(frozen=True)
class DataBaseConfig:
"""Configuration for a database."""

name: str
dsn: str
connect_sql: list[str] = field(default_factory=list)
labels: dict[str, str] = field(default_factory=dict)
keep_connected: bool = True
autocommit: bool = True

def __post_init__(self):
try:
create_db_engine(self.dsn)
except DataBaseError as e:
raise ConfigError(str(e))


@dataclass(frozen=True)
class Config:
"""Top-level configuration."""
Expand All @@ -120,7 +98,7 @@ class Config:


def load_config(
config_fd: IO, logger: Logger, env: Environ = os.environ
config_fd: IO[str], logger: Logger, env: Environ = os.environ
) -> Config:
"""Load YAML config from file."""
data = defaultdict(dict, yaml.safe_load(config_fd))
Expand Down Expand Up @@ -202,7 +180,7 @@ def _get_metrics(

def _validate_metric_config(
name: str, config: dict[str, Any], extra_labels: frozenset[str]
):
) -> None:
"""Validate a metric configuration stanza."""
if name in GLOBAL_METRICS:
raise ConfigError(f'Label name "{name} is reserved for builtin metric')
Expand Down Expand Up @@ -281,7 +259,7 @@ def _validate_query_config(
config: dict[str, Any],
database_names: frozenset[str],
metric_names: frozenset[str],
):
) -> None:
"""Validate a query configuration stanza."""
unknown_databases = set(config["databases"]) - database_names
if unknown_databases:
Expand Down Expand Up @@ -397,7 +375,7 @@ def _build_dsn(details: dict[str, Any]) -> str:
return url


def _validate_config(config: dict[str, Any]):
def _validate_config(config: dict[str, Any]) -> None:
schema_file = resources.files("query_exporter") / "schemas" / "config.yaml"
schema = yaml.safe_load(schema_file.read_bytes())
try:
Expand All @@ -407,7 +385,7 @@ def _validate_config(config: dict[str, Any]):
raise ConfigError(f"Invalid config at {path}: {e.message}")


def _warn_if_unused(config: Config, logger: Logger):
def _warn_if_unused(config: Config, logger: Logger) -> None:
"""Warn if there are unused databases or metrics defined."""
used_dbs: set[str] = set()
used_metrics: set[str] = set()
Expand Down
70 changes: 47 additions & 23 deletions query_exporter/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import asyncio
from collections.abc import Iterable
from dataclasses import (
dataclass,
field,
)
from itertools import chain
import logging
import sys
Expand All @@ -10,6 +14,7 @@
time,
)
from traceback import format_tb
from types import TracebackType
from typing import (
Any,
cast,
Expand Down Expand Up @@ -47,7 +52,7 @@ class DataBaseError(Exception):
if `fatal` is True, it means the Query will never succeed.
"""

def __init__(self, message: str, fatal: bool = False):
def __init__(self, message: str, fatal: bool = False) -> None:
super().__init__(message)
self.fatal = fatal

Expand All @@ -63,7 +68,7 @@ class DataBaseQueryError(DataBaseError):
class QueryTimeoutExpired(Exception):
"""Query execution timeout expired."""

def __init__(self, query_name: str, timeout: QueryTimeout):
def __init__(self, query_name: str, timeout: QueryTimeout) -> None:
super().__init__(
f'Execution for query "{query_name}" expired after {timeout} seconds'
)
Expand All @@ -72,7 +77,7 @@ def __init__(self, query_name: str, timeout: QueryTimeout):
class InvalidResultCount(Exception):
"""Number of results from a query don't match metrics count."""

def __init__(self, expected: int, got: int):
def __init__(self, expected: int, got: int) -> None:
super().__init__(
f"Wrong result count from query: expected {expected}, got {got}"
)
Expand All @@ -95,7 +100,7 @@ def _names(self, names: list[str]) -> str:
class InvalidQueryParameters(Exception):
"""Query parameter names don't match those in query SQL."""

def __init__(self, query_name: str):
def __init__(self, query_name: str) -> None:
super().__init__(
f'Parameters for query "{query_name}" don\'t match those from SQL'
)
Expand All @@ -104,7 +109,7 @@ def __init__(self, query_name: str):
class InvalidQuerySchedule(Exception):
"""Query schedule is wrong or both schedule and interval specified."""

def __init__(self, query_name: str, message: str):
def __init__(self, query_name: str, message: str) -> None:
super().__init__(
f'Invalid schedule for query "{query_name}": {message}'
)
Expand All @@ -118,7 +123,23 @@ def __init__(self, query_name: str, message: str):
FATAL_ERRORS = (InvalidResultCount, InvalidResultColumnNames)


def create_db_engine(dsn: str, **kwargs) -> AsyncioEngine:
@dataclass(frozen=True)
class DataBaseConfig:
"""Configuration for a database."""

name: str
dsn: str
connect_sql: list[str] = field(default_factory=list)
labels: dict[str, str] = field(default_factory=dict)
keep_connected: bool = True
autocommit: bool = True

def __post_init__(self) -> None:
# raise DatabaseError error if the DSN in invalid
create_db_engine(self.dsn)


def create_db_engine(dsn: str, **kwargs: Any) -> AsyncioEngine:
"""Create the database engine, validating the DSN"""
try:
return create_engine(dsn, **kwargs)
Expand All @@ -139,12 +160,12 @@ class QueryResults(NamedTuple):
"""Results of a database query."""

keys: list[str]
rows: list[tuple]
rows: list[tuple[Any]]
timestamp: float | None = None
latency: float | None = None

@classmethod
async def from_results(cls, results: AsyncResultProxy):
async def from_results(cls, results: AsyncResultProxy) -> "QueryResults":
"""Return a QueryResults from results for a query."""
timestamp = time()
conn_info = results._result_proxy.connection.info
Expand Down Expand Up @@ -187,7 +208,7 @@ def __init__(
interval: int | None = None,
schedule: str | None = None,
config_name: str | None = None,
):
) -> None:
self.name = name
self.databases = databases
self.metrics = metrics
Expand Down Expand Up @@ -239,15 +260,15 @@ def results(self, query_results: QueryResults) -> MetricResults:
latency=query_results.latency,
)

def _check_schedule(self):
def _check_schedule(self) -> None:
if self.interval and self.schedule:
raise InvalidQuerySchedule(
self.name, "both interval and schedule specified"
)
if self.schedule and not croniter.is_valid(self.schedule):
raise InvalidQuerySchedule(self.name, "invalid schedule format")

def _check_query_parameters(self):
def _check_query_parameters(self) -> None:
expr = text(self.sql)
query_params = set(expr.compile().params)
if set(self.parameters) != query_params:
Expand All @@ -263,9 +284,9 @@ class DataBase:

def __init__(
self,
config,
config: DataBaseConfig,
logger: logging.Logger = logging.getLogger(),
):
) -> None:
self.config = config
self.logger = logger
self._connect_lock = asyncio.Lock()
Expand All @@ -277,19 +298,21 @@ def __init__(

self._setup_query_latency_tracking()

async def __aenter__(self):
async def __aenter__(self) -> "DataBase":
await self.connect()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
async def __aexit__(
self, exc_type: type, exc_value: Exception, traceback: TracebackType
) -> None:
await self.close()

@property
def connected(self) -> bool:
"""Whether the database is connected."""
return self._conn is not None

async def connect(self):
async def connect(self) -> None:
"""Connect to the database."""
async with self._connect_lock:
if self.connected:
Expand All @@ -311,7 +334,7 @@ async def connect(self):
exc_class=DataBaseQueryError,
)

async def close(self):
async def close(self) -> None:
"""Close the database connection."""
async with self._connect_lock:
if not self.connected:
Expand Down Expand Up @@ -364,27 +387,28 @@ async def _execute_query(self, query: Query) -> AsyncResultProxy:
query.sql, parameters=query.parameters, timeout=query.timeout
)

async def _close(self):
async def _close(self) -> None:
# ensure the connection with the DB is actually closed
self._conn: AsyncConnection
self._conn.sync_connection.detach()
await self._conn.close()
self._conn = None
self._pending_queries = 0
self.logger.debug(f'disconnected from database "{self.config.name}"')

def _setup_query_latency_tracking(self):
def _setup_query_latency_tracking(self) -> None:
engine = self._engine.sync_engine

@event.listens_for(engine, "before_cursor_execute")
@event.listens_for(engine, "before_cursor_execute") # type: ignore
def before_cursor_execute(
conn, cursor, statement, parameters, context, executemany
):
) -> None:
conn.info["query_start_time"] = perf_counter()

@event.listens_for(engine, "after_cursor_execute")
@event.listens_for(engine, "after_cursor_execute") # type: ignore
def after_cursor_execute(
conn, cursor, statement, parameters, context, executemany
):
) -> None:
conn.info["query_latency"] = perf_counter() - conn.info.pop(
"query_start_time"
)
Expand Down
Loading

0 comments on commit ee638b3

Please sign in to comment.