diff --git a/catalogue/__init__.py b/catalogue/__init__.py index e468685..4c946e2 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -47,7 +47,7 @@ def __contains__(self, name: str) -> bool: RETURNS (bool): Whether the name is in the registry. """ namespace = tuple(list(self.namespace) + [name]) - has_entry_point = self.entry_points and self.get_entry_point(name) + has_entry_point = self.entry_points and self.get_entry_point(name) is not None return has_entry_point or namespace in REGISTRY def __call__( @@ -56,7 +56,7 @@ def __call__( """Register a function for a given namespace. Same as Registry.register. name (str): The name to register under the namespace. - func (Any): Optional function to register (if not used as decorator). + func (InFunc): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ return self.register(name, func=func) @@ -67,7 +67,7 @@ def register( """Register a function for a given namespace. name (str): The name to register under the namespace. - func (Any): Optional function to register (if not used as decorator). + func (InFunc): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ @@ -83,7 +83,7 @@ def get(self, name: str) -> InFunc: """Get the registered function for a given name. name (str): The name. - RETURNS (Any): The registered function. + RETURNS (InFunc): The registered function. """ if self.entry_points: from_entry_point = self.get_entry_point(name) @@ -102,7 +102,7 @@ def get_all(self) -> Dict[str, InFunc]: """Get a all functions for a given namespace. namespace (Tuple[str]): The namespace to get. - RETURNS (Dict[str, Any]): The functions, keyed by name. + RETURNS (Dict[str, InFunc]): The functions, keyed by name. """ global REGISTRY result = {} @@ -118,19 +118,19 @@ def get_all(self) -> Dict[str, InFunc]: def get_entry_points(self) -> Dict[str, InFunc]: """Get registered entry points from other packages for this namespace. - RETURNS (Dict[str, Any]): Entry points, keyed by name. + RETURNS (Dict[str, InFunc]): Entry points, keyed by name. """ result = {} for entry_point in self._get_entry_points(): result[entry_point.name] = entry_point.load() return result - def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> InFunc: + def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Optional[InFunc]: """Check if registered entry point is available for a given name in the namespace and load it. Otherwise, return the default value. name (str): Name of entry point to load. - default (Any): The default value to return. + default (InFunc): The default value to return. RETURNS (Any): The loaded entry point or the default value. """ for entry_point in self._get_entry_points():