From 10c380d5eef399f7c9aff9906cfc2dbed2d9909a Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 18 Feb 2024 01:11:01 -0800 Subject: [PATCH] finish --- catalogue/__init__.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/catalogue/__init__.py b/catalogue/__init__.py index 5f40ec4..148b033 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -1,4 +1,4 @@ -from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic +from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic, Type from types import ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType from typing import List import inspect @@ -16,9 +16,9 @@ InFunc = TypeVar("InFunc") +S = TypeVar('S') - -def create(*namespace: str, entry_points: bool = False) -> "Registry": +def create(*namespace: str, entry_points: bool = False, generic_type: Optional[Type[S]] = None) -> "Registry[S]": """Create a new registry. *namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures". @@ -27,7 +27,11 @@ def create(*namespace: str, entry_points: bool = False) -> "Registry": """ if check_exists(*namespace): raise RegistryError(f"Namespace already exists: {namespace}") - return Registry(namespace, entry_points=entry_points) + + if generic_type is None: + return Registry[Any](namespace, entry_points=entry_points) + else: + return Registry[S](namespace, entry_points=entry_points) class Registry(Generic[InFunc]): @@ -53,7 +57,7 @@ def __contains__(self, name: str) -> bool: def __call__( self, name: str, func: Optional[InFunc] = None - ) -> Callable[[InFunc], InFunc]: + ) -> Union[Callable[[InFunc], InFunc], InFunc]: """Register a function for a given namespace. Same as Registry.register. name (str): The name to register under the namespace. @@ -64,7 +68,7 @@ def __call__( def register( self, name: str, *, func: Optional[InFunc] = None - ) -> Callable[[InFunc], InFunc]: + ) -> Union[Callable[[InFunc], InFunc], InFunc]: """Register a function for a given namespace. name (str): The name to register under the namespace. @@ -139,10 +143,11 @@ def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Option return entry_point.load() return default - def _get_entry_points(self) -> List[importlib_metadata.EntryPoint]: + def _get_entry_points(self) -> Union[List[importlib_metadata.EntryPoint], importlib_metadata.EntryPoints]: if hasattr(AVAILABLE_ENTRY_POINTS, "select"): return AVAILABLE_ENTRY_POINTS.select(group=self.entry_point_namespace) else: # dict + assert isinstance(AVAILABLE_ENTRY_POINTS, dict) return AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []) def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: @@ -174,7 +179,6 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: "docstring": inspect.cleandoc(docstring) if docstring else None, } - def check_exists(*namespace: str) -> bool: """Check if a namespace exists.