diff --git a/tools/symbol-check.py b/tools/symbol-check.py index 37f056843d..9e26561260 100755 --- a/tools/symbol-check.py +++ b/tools/symbol-check.py @@ -1,26 +1,17 @@ #!/usr/bin/env python3 -''' -A script to check that a libsecp256k1 shared library -exports only expected symbols. - -Example usage: - -- when building with Autotools: +"""Check that a libsecp256k1 shared library exports only expected symbols. +Usage examples: + - When building with Autotools: ./tools/symbol-check.py .libs/libsecp256k1.so -or ./tools/symbol-check.py .libs/libsecp256k1-.dll -or ./tools/symbol-check.py .libs/libsecp256k1.dylib -- when building with CMake: - + - When building with CMake: ./tools/symbol-check.py build/lib/libsecp256k1.so -or ./tools/symbol-check.py build/bin/libsecp256k1-.dll -or - ./tools/symbol-check.py build/lib/libsecp256k1.dylib -''' + ./tools/symbol-check.py build/lib/libsecp256k1.dylib""" + import re import sys import subprocess @@ -28,61 +19,54 @@ import lief -def grep_exported_symbols() -> list[str]: - grep_output = subprocess.check_output(["git", "grep", r"^\s*SECP256K1_API", "--", "include"], universal_newlines=True, encoding="utf8") - lines = grep_output.split("\n") - exports: list[str] = [] - pattern = re.compile(r'\bsecp256k1_\w+') - for line in lines: - if line.strip(): - function_name = pattern.findall(line)[-1] - exports.append(function_name) - return exports - +class UnexpectedExport(RuntimeError): + pass -def check_ELF_exported_symbols(library, expected_exports) -> bool: - ok: bool = True - for symbol in library.exported_symbols: - name: str = symbol.name - if name in expected_exports: - continue - print(f'{filename}: export of symbol {name} not expected') - ok = False - return ok +def get_exported_exports(library) -> list[str]: + """Adapter function to get exported symbols based on the library format.""" + if library.format == lief.Binary.FORMATS.ELF: + return [symbol.name for symbol in library.exported_symbols] + elif library.format == lief.Binary.FORMATS.PE: + return [function.name for function in library.exported_functions] + elif library.format == lief.Binary.FORMATS.MACHO: + return [function.name[1:] for function in library.exported_functions] + raise NotImplementedError(f"Unsupported format: {library.format}") -def check_PE_exported_functions(library, expected_exports) -> bool: - ok: bool = True - for function in library.exported_functions: - name: str = function.name - if name in expected_exports: - continue - print(f'{filename}: export of function {name} not expected') - ok = False - return ok - -def check_MACHO_exported_functions(library, expected_exports) -> bool: - ok: bool = True - for function in library.exported_functions: - name: str = function.name[1:] - if name in expected_exports: - continue - print(f'{filename}: export of function {name} not expected') - ok = False - return ok - - -if __name__ == '__main__': - filename: str = sys.argv[1] - library: lief.Binary = lief.parse(filename) - exe_format: lief.Binary.FORMATS = library.format - if exe_format == lief.Binary.FORMATS.ELF: - success = check_ELF_exported_symbols(library, grep_exported_symbols()) - elif exe_format == lief.Binary.FORMATS.PE: - success = check_PE_exported_functions(library, grep_exported_symbols()) - elif exe_format == lief.Binary.FORMATS.MACHO: - success = check_MACHO_exported_functions(library, grep_exported_symbols()) - - if not success: - sys.exit(1) +def grep_expected_symbols() -> list[str]: + """Guess the list of expected exported symbols from the source code.""" + grep_output = subprocess.check_output( + ["git", "grep", "^SECP256K1_API", "--", "include"], # TODO WHITESPACE + universal_newlines=True, + encoding="utf-8" + ) + lines = grep_output.split("\n") + pattern = re.compile(r'\bsecp256k1_\w+') + exported: list[str] = [pattern.findall(line)[-1] for line in lines if line.strip()] + return exported + + +def check_symbols(library, expected_exports) -> None: + """Check that the library exports only the expected symbols.""" + actual_exports = list(get_exported_exports(library)) + unexpected_exports = set(actual_exports) - set(expected_exports) + if unexpected_exports != set(): + raise UnexpectedExport(f"Unexpected exported symbols: {unexpected_exports}") + +def main(): + if len(sys.argv) != 2: + print(__doc__) + return 1 + library = lief.parse(sys.argv[1]) + expected_exports = grep_expected_symbols() + try: + check_symbols(library, expected_exports) + except UnexpectedExport as e: + print(f"{sys.argv[0]}: In {sys.argv[1]}: {e}") + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main())