Skip to content

Commit

Permalink
fix typing for sort keys (finally)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilius committed Dec 5, 2024
1 parent 8075e2a commit 5419402
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 64 deletions.
9 changes: 4 additions & 5 deletions pyglossary/sort_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
from typing import TYPE_CHECKING, Any, NamedTuple

if TYPE_CHECKING:

from .icu_types import T_Collator, T_Locale
from .sort_keys_types import (
LocaleSortKeyMakerType,
SortKeyMakerType,
SQLiteLocaleSortKeyMakerType,
SQLiteSortKeyMakerType,
)

Expand All @@ -45,8 +43,8 @@
class NamedSortKey(NamedTuple):
name: str
desc: str
normal: SortKeyMakerType
sqlite: SQLiteSortKeyMakerType
normal: SortKeyMakerType | None
sqlite: SQLiteSortKeyMakerType | None


@dataclass(slots=True) # not frozen because of mod
Expand All @@ -66,6 +64,7 @@ def module(self): # noqa: ANN201
self.mod = mod
return mod

# mypy seems to have problems with @property
@property
def normal(self) -> SortKeyMakerType:
return self.module.normal
Expand All @@ -79,7 +78,7 @@ def locale(self) -> LocaleSortKeyMakerType | None:
return getattr(self.module, "locale", None)

@property
def sqlite_locale(self) -> SQLiteLocaleSortKeyMakerType | None:
def sqlite_locale(self) -> SQLiteSortKeyMakerType | None:
return getattr(self.module, "sqlite_locale", None)


Expand Down
25 changes: 2 additions & 23 deletions pyglossary/sort_keys_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,15 @@
Any,
]


RawSortKeyType: TypeAlias = Callable[
[bytes],
Any,
]

SQLiteSortKeyType: TypeAlias = list[tuple[str, str, SortKeyType]]


class SortKeyMakerType(Protocol):
def __call__(
self,
sortEncoding: str,
**kwargs
) -> RawSortKeyType: ...
def __call__(self, sortEncoding: str = "utf-8", **kwargs) -> SortKeyType: ...


class SQLiteSortKeyMakerType(Protocol):
def __call__(
self,
sortEncoding: str,
**kwargs
) -> SQLiteSortKeyType: ...
def __call__(self, sortEncoding: str = "utf-8", **kwargs) -> SQLiteSortKeyType: ...


class LocaleSortKeyMakerType(Protocol):
Expand All @@ -43,13 +29,6 @@ def __call__(
) -> SortKeyMakerType: ...


class SQLiteLocaleSortKeyMakerType(Protocol):
def __call__(
self,
collator: T_Collator, # noqa: F821
) -> SQLiteSortKeyType: ...


__all__ = [
"LocaleSortKeyMakerType",
"SQLiteSortKeyType",
Expand Down
8 changes: 4 additions & 4 deletions pyglossary/sort_modules/ebook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from pyglossary.sort_keys_types import SortKeyType, SQLiteSortKeyType
Expand All @@ -17,8 +17,8 @@ def normal(
) -> SortKeyType:
length = options.get("group_by_prefix_length", 2)

# FIXME: return bytes
def sortKey(words: list[str]) -> tuple[str, str]:
# FIXME: return bytes?
def sortKey(words: list[str]) -> Any:
word = words[0]
if not word:
return "", ""
Expand All @@ -42,7 +42,7 @@ def getPrefix(words: list[str]) -> str:
return "SPECIAL"
return prefix

def headword(words: list[str]) -> bytes:
def headword(words: list[str]) -> Any:
return words[0].encode(sortEncoding, errors="replace")

_type = "TEXT" if sortEncoding == "utf-8" else "BLOB"
Expand Down
22 changes: 12 additions & 10 deletions pyglossary/sort_modules/headword.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable

from pyglossary.icu_types import T_Collator
from pyglossary.sort_keys_types import (
SortKeyMakerType,
SortKeyType,
SQLiteSortKeyMakerType,
SQLiteSortKeyType,
)

Expand All @@ -17,7 +16,7 @@


def normal(sortEncoding: str = "utf-8", **_options) -> SortKeyType:
def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return words[0].encode(sortEncoding, errors="replace")

return sortKey
Expand All @@ -28,17 +27,17 @@ def locale(
) -> SortKeyMakerType:
cSortKey = collator.getSortKey

def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return cSortKey(words[0])

def warpper(_sortEncoding: str = "utf-8", **_options) -> SortKeyType:
def warpper(sortEncoding: str = "utf-8", **_options) -> SortKeyType: # noqa: ARG001
return sortKey

return warpper


def sqlite(sortEncoding: str = "utf-8", **_options) -> SQLiteSortKeyType:
def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return words[0].encode(sortEncoding, errors="replace")

return [
Expand All @@ -52,10 +51,13 @@ def sortKey(words: list[str]) -> bytes:

def sqlite_locale(
collator: T_Collator, # noqa: F821
) -> Callable[..., SQLiteSortKeyType]:
) -> SQLiteSortKeyMakerType:
cSortKey = collator.getSortKey

def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return cSortKey(words[0])

return lambda **_options: [("sortkey", "BLOB", sortKey)]
def wrapper(sortEncoding="", **_options): # noqa: ARG001
return [("sortkey", "BLOB", sortKey)]

return wrapper
6 changes: 3 additions & 3 deletions pyglossary/sort_modules/headword_bytes_lower.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from pyglossary.sort_keys_types import SortKeyType, SQLiteSortKeyType
Expand All @@ -13,7 +13,7 @@ def normal(
sortEncoding: str = "utf-8",
**_options,
) -> SortKeyType:
def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return words[0].encode(sortEncoding, errors="replace").lower()

return sortKey
Expand All @@ -26,7 +26,7 @@ def sortKey(words: list[str]) -> bytes:


def sqlite(sortEncoding: str = "utf-8", **_options) -> SQLiteSortKeyType:
def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return words[0].encode(sortEncoding, errors="replace").lower()

return [
Expand Down
22 changes: 12 additions & 10 deletions pyglossary/sort_modules/headword_lower.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable

from pyglossary.icu_types import T_Collator
from pyglossary.sort_keys_types import (
SortKeyMakerType,
SortKeyType,
SQLiteSortKeyMakerType,
SQLiteSortKeyType,
)

Expand All @@ -17,7 +16,7 @@


def normal(sortEncoding: str = "utf-8", **_options) -> SortKeyType:
def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
# assert isinstance(words, list) # OK
return words[0].lower().encode(sortEncoding, errors="replace")

Expand All @@ -29,11 +28,11 @@ def locale(
) -> SortKeyMakerType:
cSortKey = collator.getSortKey

def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
# assert isinstance(words, list) # OK
return cSortKey(words[0].lower())

def warpper(_sortEncoding: str = "utf-8", **_options) -> SortKeyType:
def warpper(sortEncoding: str = "utf-8", **_options) -> SortKeyType: # noqa: ARG001
return sortKey

return warpper
Expand All @@ -43,7 +42,7 @@ def sqlite(
sortEncoding: str = "utf-8",
**_options,
) -> SQLiteSortKeyType:
def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return words[0].lower().encode(sortEncoding, errors="replace")

return [
Expand All @@ -57,10 +56,13 @@ def sortKey(words: list[str]) -> bytes:

def sqlite_locale(
collator: T_Collator, # noqa: F821
) -> Callable[..., SQLiteSortKeyType]:
) -> SQLiteSortKeyMakerType:
cSortKey = collator.getSortKey

def sortKey(words: list[str]) -> bytes:
def sortKey(words: list[str]) -> Any:
return cSortKey(words[0].lower())

return lambda **_options: [("sortkey", "BLOB", sortKey)]
def wrapper(sortEncoding="", **_options): # noqa: ARG001
return [("sortkey", "BLOB", sortKey)]

return wrapper
20 changes: 15 additions & 5 deletions pyglossary/sort_modules/random.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable

from pyglossary.icu_types import T_Collator
from pyglossary.sort_keys_types import SortKeyType, SQLiteSortKeyType
from pyglossary.sort_keys_types import (
SortKeyMakerType,
SortKeyType,
SQLiteSortKeyType,
)


desc = "Random"
Expand All @@ -19,11 +23,17 @@ def normal(**_options) -> SortKeyType:


def locale(
_collator: T_Collator, # noqa: F821
) -> SortKeyType:
collator: T_Collator, # noqa: ARG001 # noqa: F821
) -> SortKeyMakerType:
from random import random

return lambda **_options: lambda _words: random()
def sortKey(words: list[str]) -> Any: # noqa: ARG001
return random()

def warpper(sortEncoding: str = "utf-8", **_options) -> SortKeyType: # noqa: ARG001
return sortKey

return warpper


def sqlite(**_options) -> SQLiteSortKeyType:
Expand Down
8 changes: 4 additions & 4 deletions pyglossary/sort_modules/stardict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from pyglossary.sort_keys_types import SortKeyType, SQLiteSortKeyType
Expand All @@ -10,18 +10,18 @@


def normal(sortEncoding: str = "utf-8", **_options) -> SortKeyType:
def sortKey(words: list[str]) -> tuple[bytes, bytes]:
def sortKey(words: list[str]) -> Any:
b_word = words[0].encode(sortEncoding, errors="replace")
return (b_word.lower(), b_word)

return sortKey


def sqlite(sortEncoding: str = "utf-8", **_options) -> SQLiteSortKeyType:
def headword_lower(words: list[str]) -> bytes:
def headword_lower(words: list[str]) -> Any:
return words[0].encode(sortEncoding, errors="replace").lower()

def headword(words: list[str]) -> bytes:
def headword(words: list[str]) -> Any:
return words[0].encode(sortEncoding, errors="replace")

_type = "TEXT" if sortEncoding == "utf-8" else "BLOB"
Expand Down

0 comments on commit 5419402

Please sign in to comment.