Skip to content

Commit

Permalink
add generic
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored and svlandeg committed Mar 22, 2024
1 parent 1de5752 commit 3046706
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions catalogue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union
from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic
from typing import List
import inspect

Expand Down Expand Up @@ -29,7 +29,7 @@ def create(*namespace: str, entry_points: bool = False) -> "Registry":
return Registry(namespace, entry_points=entry_points)


class Registry(object):
class Registry(Generic[InFunc]):
def __init__(self, namespace: Sequence[str], entry_points: bool = False) -> None:
"""Initialize a new registry.
Expand All @@ -51,7 +51,7 @@ def __contains__(self, name: str) -> bool:
return has_entry_point or namespace in REGISTRY

def __call__(
self, name: str, func: Optional[Any] = None
self, name: str, func: Optional[InFunc] = None
) -> Callable[[InFunc], InFunc]:
"""Register a function for a given namespace. Same as Registry.register.
Expand All @@ -62,7 +62,7 @@ def __call__(
return self.register(name, func=func)

def register(
self, name: str, *, func: Optional[Any] = None
self, name: str, *, func: Optional[InFunc] = None
) -> Callable[[InFunc], InFunc]:
"""Register a function for a given namespace.
Expand All @@ -79,7 +79,7 @@ def do_registration(func):
return do_registration(func)
return do_registration

def get(self, name: str) -> Any:
def get(self, name: str) -> InFunc:
"""Get the registered function for a given name.
name (str): The name.
Expand All @@ -98,7 +98,7 @@ def get(self, name: str) -> Any:
)
return _get(namespace)

def get_all(self) -> Dict[str, Any]:
def get_all(self) -> Dict[str, InFunc]:
"""Get a all functions for a given namespace.
namespace (Tuple[str]): The namespace to get.
Expand All @@ -115,7 +115,7 @@ def get_all(self) -> Dict[str, Any]:
result[keys[-1]] = value
return result

def get_entry_points(self) -> Dict[str, Any]:
def get_entry_points(self) -> Dict[str, InFunc]:
"""Get registered entry points from other packages for this namespace.
RETURNS (Dict[str, Any]): Entry points, keyed by name.
Expand All @@ -125,7 +125,7 @@ def get_entry_points(self) -> Dict[str, Any]:
result[entry_point.name] = entry_point.load()
return result

def get_entry_point(self, name: str, default: Optional[Any] = None) -> Any:
def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> InFunc:
"""Check if registered entry point is available for a given name in the
namespace and load it. Otherwise, return the default value.
Expand Down

0 comments on commit 3046706

Please sign in to comment.