From 643c1169e8a8612dbe20b5b56b7d382ae1f4a4c6 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 16 Feb 2024 00:11:26 -0800 Subject: [PATCH] get mypy to pass --- catalogue/__init__.py | 4 ++++ catalogue/tests/test_catalogue.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/catalogue/__init__.py b/catalogue/__init__.py index 4c946e2..5f40ec4 100644 --- a/catalogue/__init__.py +++ b/catalogue/__init__.py @@ -1,4 +1,5 @@ from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union, Generic +from types import ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType from typing import List import inspect @@ -158,6 +159,9 @@ def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: line_no: Optional[int] = None file_name: Optional[str] = None try: + if not isinstance(func, (ModuleType, MethodType, FunctionType, TracebackType, FrameType, CodeType, type)): + raise TypeError(f"func type {type(func)} is not a valid type for inspect.getsourcelines()") + _, line_no = inspect.getsourcelines(func) file_name = inspect.getfile(func) except (TypeError, ValueError): diff --git a/catalogue/tests/test_catalogue.py b/catalogue/tests/test_catalogue.py index e53ebf2..7cff281 100644 --- a/catalogue/tests/test_catalogue.py +++ b/catalogue/tests/test_catalogue.py @@ -159,3 +159,31 @@ def a(): assert info["file"] == str(Path(__file__)) assert info["docstring"] == "This is a registered function." assert info["line_no"] + +def test_registry_find_module(): + import json + + test_registry = catalogue.create("test_registry_find_module") + + test_registry.register("json", func=json) + + info = test_registry.find("json") + assert info["module"] == "json" + assert info["file"] == json.__file__ + assert info["docstring"] == json.__doc__.strip('\n') + assert info["line_no"] == 0 + +def test_registry_find_class(): + test_registry = catalogue.create("test_registry_find_class") + + class TestClass: + """This is a registered class.""" + pass + + test_registry.register("test_class", func=TestClass) + + info = test_registry.find("test_class") + assert info["module"] == "catalogue.tests.test_catalogue" + assert info["file"] == str(Path(__file__)) + assert info["docstring"] == TestClass.__doc__ + assert info["line_no"] \ No newline at end of file