Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored and svlandeg committed Mar 22, 2024
1 parent 3046706 commit 7cc7669
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions catalogue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __contains__(self, name: str) -> bool:
RETURNS (bool): Whether the name is in the registry.
"""
namespace = tuple(list(self.namespace) + [name])
has_entry_point = self.entry_points and self.get_entry_point(name)
has_entry_point = self.entry_points and self.get_entry_point(name) is not None
return has_entry_point or namespace in REGISTRY

def __call__(
Expand All @@ -56,7 +56,7 @@ def __call__(
"""Register a function for a given namespace. Same as Registry.register.
name (str): The name to register under the namespace.
func (Any): Optional function to register (if not used as decorator).
func (InFunc): Optional function to register (if not used as decorator).
RETURNS (Callable): The decorator.
"""
return self.register(name, func=func)
Expand All @@ -67,7 +67,7 @@ def register(
"""Register a function for a given namespace.
name (str): The name to register under the namespace.
func (Any): Optional function to register (if not used as decorator).
func (InFunc): Optional function to register (if not used as decorator).
RETURNS (Callable): The decorator.
"""

Expand All @@ -83,7 +83,7 @@ def get(self, name: str) -> InFunc:
"""Get the registered function for a given name.
name (str): The name.
RETURNS (Any): The registered function.
RETURNS (InFunc): The registered function.
"""
if self.entry_points:
from_entry_point = self.get_entry_point(name)
Expand All @@ -102,7 +102,7 @@ def get_all(self) -> Dict[str, InFunc]:
"""Get a all functions for a given namespace.
namespace (Tuple[str]): The namespace to get.
RETURNS (Dict[str, Any]): The functions, keyed by name.
RETURNS (Dict[str, InFunc]): The functions, keyed by name.
"""
global REGISTRY
result = {}
Expand All @@ -118,19 +118,19 @@ def get_all(self) -> Dict[str, InFunc]:
def get_entry_points(self) -> Dict[str, InFunc]:
"""Get registered entry points from other packages for this namespace.
RETURNS (Dict[str, Any]): Entry points, keyed by name.
RETURNS (Dict[str, InFunc]): Entry points, keyed by name.
"""
result = {}
for entry_point in self._get_entry_points():
result[entry_point.name] = entry_point.load()
return result

def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> InFunc:
def get_entry_point(self, name: str, default: Optional[InFunc] = None) -> Optional[InFunc]:
"""Check if registered entry point is available for a given name in the
namespace and load it. Otherwise, return the default value.
name (str): Name of entry point to load.
default (Any): The default value to return.
default (InFunc): The default value to return.
RETURNS (Any): The loaded entry point or the default value.
"""
for entry_point in self._get_entry_points():
Expand Down

0 comments on commit 7cc7669

Please sign in to comment.