Skip to content

Commit

Permalink
Add types to everything (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtclose authored Apr 8, 2023
1 parent b670570 commit 901442b
Show file tree
Hide file tree
Showing 21 changed files with 600 additions and 294 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.dmypy.json
*.pyc
.DS_Store
build/
Expand Down
11 changes: 10 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,16 @@ testpaths=tests
[mypy]
warn_unused_configs = True
ignore_missing_imports = False
python_version = 3.6
disallow_untyped_defs = True
disallow_incomplete_defs = True
no_implicit_optional = True
strict_equality = True
warn_unreachable = True
warn_unused_ignores = True
show_error_context = True
pretty = True
check_untyped_defs = True
python_version = 3.8
files = tasktiger

[mypy-flask_script.*]
Expand Down
56 changes: 44 additions & 12 deletions tasktiger/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,26 @@
import operator
import os
import threading
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Iterable,
List,
Optional,
Tuple,
Type,
TypedDict,
Union,
)

from .exceptions import TaskImportError

if TYPE_CHECKING:
from .task import Task
from .tasktiger import TaskTiger

# Task states (represented by different queues)
# Note some client code may rely on the string values (e.g. get_queue_stats).
QUEUED = "queued"
Expand All @@ -31,11 +48,19 @@
# Global task context. We store this globally (and not on the TaskTiger
# instance) for consistent results just in case the user has multiple TaskTiger
# instances.
g = {"tiger": None, "current_task_is_batch": None, "current_tasks": None}


class _G(TypedDict):
tiger: Optional["TaskTiger"]
current_task_is_batch: Optional[bool]
current_tasks: Optional[List["Task"]]


g: _G = {"tiger": None, "current_task_is_batch": None, "current_tasks": None}


# from rq
def import_attribute(name):
def import_attribute(name: str) -> Any:
"""Return an attribute from a dotted path name (e.g. "path.to.func")."""
try:
sep = ":" if ":" in name else "." # For backwards compatibility
Expand All @@ -46,14 +71,14 @@ def import_attribute(name):
raise TaskImportError(e)


def gen_id():
def gen_id() -> str:
"""
Generates and returns a random hex-encoded 256-bit unique ID.
"""
return binascii.b2a_hex(os.urandom(32)).decode("utf8")


def gen_unique_id(serialized_name, args, kwargs):
def gen_unique_id(serialized_name: str, args: Any, kwargs: Any) -> str:
"""
Generates and returns a hex-encoded 256-bit ID for the given task name and
args. Used to generate IDs for unique tasks or for task locks.
Expand All @@ -66,7 +91,7 @@ def gen_unique_id(serialized_name, args, kwargs):
).hexdigest()


def serialize_func_name(func):
def serialize_func_name(func: Union[Callable, Type]) -> str:
"""
Returns the dotted serialized path to the passed function.
"""
Expand All @@ -83,7 +108,7 @@ def serialize_func_name(func):
return ":".join([func.__module__, func_name])


def dotted_parts(s):
def dotted_parts(s: str) -> Iterable[str]:
"""
For a string "a.b.c", yields "a", "a.b", "a.b.c".
"""
Expand All @@ -96,7 +121,7 @@ def dotted_parts(s):
yield s[:idx]


def reversed_dotted_parts(s):
def reversed_dotted_parts(s: str) -> Iterable[str]:
"""
For a string "a.b.c", yields "a.b.c", "a.b", "a".
"""
Expand All @@ -110,14 +135,16 @@ def reversed_dotted_parts(s):
yield s[:idx]


def serialize_retry_method(retry_method):
def serialize_retry_method(retry_method: Any) -> Tuple[str, Tuple]:
if callable(retry_method):
return (serialize_func_name(retry_method), ())
else:
return (serialize_func_name(retry_method[0]), retry_method[1])


def get_timestamp(when):
def get_timestamp(
when: Optional[Union[datetime.timedelta, datetime.datetime]]
) -> Optional[float]:
# convert timedelta to datetime
if isinstance(when, datetime.timedelta):
when = datetime.datetime.utcnow() + when
Expand All @@ -126,9 +153,14 @@ def get_timestamp(when):
# Convert to unixtime: utctimetuple drops microseconds so we add
# them manually.
return calendar.timegm(when.utctimetuple()) + when.microsecond / 1.0e6
return None


def queue_matches(queue, only_queues=None, exclude_queues=None):
def queue_matches(
queue: str,
only_queues: Optional[Collection[str]] = None,
exclude_queues: Optional[Collection[str]] = None,
) -> bool:
"""Checks if the given queue matches against only/exclude constraints
Returns whether the given queue should be included by checking each part of
Expand Down Expand Up @@ -171,5 +203,5 @@ class classproperty(property):
Works like @property but on classes.
"""

def __get__(desc, self, cls):
return desc.fget(cls)
def __get__(desc, self, cls): # type:ignore[no-untyped-def]
return desc.fget(cls) # type:ignore[misc]
8 changes: 7 additions & 1 deletion tasktiger/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from typing import Any


class TaskImportError(ImportError):
Expand Down Expand Up @@ -40,7 +41,12 @@ class RetryException(BaseException):
from Redis when it completes.
"""

def __init__(self, method=None, original_traceback=False, log_error=True):
def __init__(
self,
method: Any = None,
original_traceback: Any = False,
log_error: bool = True,
):
self.method = method
self.exc_info = sys.exc_info() if original_traceback else None
self.log_error = log_error
Expand Down
16 changes: 11 additions & 5 deletions tasktiger/flask_script.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import absolute_import

import argparse
from typing import TYPE_CHECKING, Any, List

from flask_script import Command

if TYPE_CHECKING:
from tasktiger import TaskTiger


class TaskTigerCommand(Command):
"""
Expand All @@ -14,28 +18,30 @@ class TaskTigerCommand(Command):
capture_all_args = True
help = "Run a TaskTiger worker"

def __init__(self, tiger):
def __init__(self, tiger: TaskTiger) -> None:
super(TaskTigerCommand, self).__init__()
self.tiger = tiger

def create_parser(self, *args, **kwargs):
def create_parser(
self, *args: Any, **kwargs: Any
) -> argparse.ArgumentParser:
# Override the default parser so we can pass all arguments to the
# TaskTiger parser.
func_stack = kwargs.pop("func_stack", ())
parent = kwargs.pop("parent", None)
parser = argparse.ArgumentParser(*args, add_help=False, **kwargs)
parser = argparse.ArgumentParser(*args, add_help=False, **kwargs) # type: ignore[misc]
parser.set_defaults(func_stack=func_stack + (self,))
self.parser = parser
self.parent = parent
return parser

def setup(self):
def setup(self) -> None:
"""
Override this method to implement custom setup (e.g. logging) before
running the worker.
"""

def run(self, args):
def run(self, args: List[str]) -> None:
# Allow passing a callable that returns the TaskTiger instance.
if callable(self.tiger):
self.tiger = self.tiger()
Expand Down
6 changes: 5 additions & 1 deletion tasktiger/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any, Dict

from ._internal import g


def tasktiger_processor(logger, method_name, event_dict):
def tasktiger_processor(
logger: Any, method_name: Any, event_dict: Dict[str, Any]
) -> Dict[str, Any]:
"""
TaskTiger structlog processor.
Expand Down
7 changes: 6 additions & 1 deletion tasktiger/migrations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import TYPE_CHECKING

from .utils import redis_glob_escape

if TYPE_CHECKING:
from . import TaskTiger


def migrate_executions_count(tiger):
def migrate_executions_count(tiger: "TaskTiger") -> None:
"""
Backfills ``t:task:<uuid>:executions_count`` by counting
elements in ``t:task:<uuid>:executions``.
Expand Down
Loading

0 comments on commit 901442b

Please sign in to comment.