Skip to content

Commit

Permalink
finish
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored and svlandeg committed Mar 22, 2024
1 parent 643c116 commit 10c380d
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions catalogue/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic
from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic, Type
from types import ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType
from typing import List
import inspect
Expand All @@ -16,9 +16,9 @@


InFunc = TypeVar("InFunc")
S = TypeVar('S')


def create(*namespace: str, entry_points: bool = False) -> "Registry":
def create(*namespace: str, entry_points: bool = False, generic_type: Optional[Type[S]] = None) -> "Registry[S]":
"""Create a new registry.
*namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures".
Expand All @@ -27,7 +27,11 @@ def create(*namespace: str, entry_points: bool = False) -> "Registry":
"""
if check_exists(*namespace):
raise RegistryError(f"Namespace already exists: {namespace}")
return Registry(namespace, entry_points=entry_points)

if generic_type is None:
return Registry[Any](namespace, entry_points=entry_points)
else:
return Registry[S](namespace, entry_points=entry_points)


class Registry(Generic[InFunc]):
Expand All @@ -53,7 +57,7 @@ def __contains__(self, name: str) -> bool:

def __call__(
self, name: str, func: Optional[InFunc] = None
) -> Callable[[InFunc], InFunc]:
) -> Union[Callable[[InFunc], InFunc], InFunc]:
"""Register a function for a given namespace. Same as Registry.register.
name (str): The name to register under the namespace.
Expand All @@ -64,7 +68,7 @@ def __call__(

def register(
self, name: str, *, func: Optional[InFunc] = None
) -> Callable[[InFunc], InFunc]:
) -> Union[Callable[[InFunc], InFunc], InFunc]:
"""Register a function for a given namespace.
name (str): The name to register under the namespace.
Expand Down Expand Up @@ -139,10 +143,11 @@ def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Option
return entry_point.load()
return default

def _get_entry_points(self) -> List[importlib_metadata.EntryPoint]:
def _get_entry_points(self) -> Union[List[importlib_metadata.EntryPoint], importlib_metadata.EntryPoints]:
if hasattr(AVAILABLE_ENTRY_POINTS, "select"):
return AVAILABLE_ENTRY_POINTS.select(group=self.entry_point_namespace)
else: # dict
assert isinstance(AVAILABLE_ENTRY_POINTS, dict)
return AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, [])

def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]:
Expand Down Expand Up @@ -174,7 +179,6 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]:
"docstring": inspect.cleandoc(docstring) if docstring else None,
}


def check_exists(*namespace: str) -> bool:
"""Check if a namespace exists.
Expand Down

0 comments on commit 10c380d

Please sign in to comment.