From 7f2e8df4ef31bf3cbab598c54e5603a67882775d Mon Sep 17 00:00:00 2001 From: "Ashwin V. Mohanan" Date: Thu, 14 Mar 2024 18:05:41 +0100 Subject: [PATCH 1/7] Add tests for jax --- Makefile | 3 ++ noxfile.py | 4 +- pdm.lock | 110 +++++++++++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 9 +++- 4 files changed, 120 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 3482a43..ce4ee35 100644 --- a/Makefile +++ b/Makefile @@ -15,6 +15,9 @@ tests_pythran: tests_cython: TRANSONIC_BACKEND="cython" pytest tests data_tests/ipynb +tests_jax: + TRANSONIC_BACKEND="jax" pytest --lf tests data_tests/ipynb + tests_numba: TRANSONIC_BACKEND="numba" pytest tests data_tests/ipynb diff --git a/noxfile.py b/noxfile.py index e66d5a9..731c694 100644 --- a/noxfile.py +++ b/noxfile.py @@ -26,6 +26,8 @@ def test(session, with_pythran, with_cython): else: session.install("setuptools") + session.install("jax", "jaxlib") + if with_pythran: session.install("pythran") if with_cython: @@ -46,7 +48,7 @@ def test(session, with_pythran, with_cython): code_dependencies = 10 * with_pythran + with_cython - for backend in ("python", "pythran", "numba", "cython"): + for backend in ("python", "pythran", "numba", "jax", "cython"): print(f"TRANSONIC_BACKEND={backend}") session.run( "pytest", diff --git a/pdm.lock b/pdm.lock index 96f571e..d986998 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "base_test", "dev", "doc", "mpi", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:947d7510c400794a3ee19d3303734c783a19ced37af5e0a08260056d77e79104" +content_hash = "sha256:f64f349a54ec82b94d2f2c835130717342af4f18c7f0d2ad1955fd4070f48e66" [[package]] name = "alabaster" @@ -932,7 +932,7 @@ name = "importlib-metadata" version = "7.1.0" requires_python = ">=3.8" summary = "Read metadata from Python packages" -groups = ["base_test", "dev", "doc"] +groups = ["base_test", "dev", "doc", "test"] dependencies = [ "zipp>=0.5", ] @@ -1066,6 +1066,62 @@ files = [ {file = "jaraco.functools-4.0.0.tar.gz", hash = "sha256:c279cb24c93d694ef7270f970d499cab4d3813f4e08273f95398651a634f0925"}, ] +[[package]] +name = "jax" +version = "0.4.25" +requires_python = ">=3.9" +summary = "Differentiate, compile, and transform Numpy code." +groups = ["test"] +dependencies = [ + "importlib-metadata>=4.6; python_version < \"3.10\"", + "ml-dtypes>=0.2.0", + "numpy>=1.22", + "numpy>=1.23.2; python_version >= \"3.11\"", + "numpy>=1.26.0; python_version >= \"3.12\"", + "opt-einsum", + "scipy>=1.11.1; python_version >= \"3.12\"", + "scipy>=1.9", +] +files = [ + {file = "jax-0.4.25-py3-none-any.whl", hash = "sha256:8158c837e5ecc195074b421609e85329a962785b36f9fe5ff53e844e8ad87dbc"}, + {file = "jax-0.4.25.tar.gz", hash = "sha256:a8ee189c782de2b7b2ffb64a8916da380b882a617e2769aa429b71d79747b982"}, +] + +[[package]] +name = "jaxlib" +version = "0.4.25" +requires_python = ">=3.9" +summary = "XLA library for JAX" +groups = ["test"] +dependencies = [ + "ml-dtypes>=0.2.0", + "numpy>=1.22", + "scipy>=1.11.1; python_version >= \"3.12\"", + "scipy>=1.9", +] +files = [ + {file = "jaxlib-0.4.25-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:be1b26e96e80d42f54f77226a016717cb969d7d208d0dcb61997f19dc7b2d8e2"}, + {file = "jaxlib-0.4.25-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3b5cbd3a4f731636469cdaf06c4413208811ca458ee312647e8f3faca32f6445"}, + {file = "jaxlib-0.4.25-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:89a011330aaeaf19027bba5e3236be155cc8d73d94aa9db84d817d414f4a7647"}, + {file = "jaxlib-0.4.25-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:dcda74c7c8eb328cde8afeebcf21ec9240138fac54f9631a60b679a211f7e100"}, + {file = "jaxlib-0.4.25-cp310-cp310-win_amd64.whl", hash = "sha256:fd751b10e60c085dec42bec6c27c9905f5c57d12323190eea0df10ee14c574e0"}, + {file = "jaxlib-0.4.25-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:37da780cb545ca210bfa0402b5081452ad830bb06fe9e970fd16ad14d2fdc6a6"}, + {file = "jaxlib-0.4.25-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0df7e2193b216e195dfc7a8aa14527eb52614ec3ba4c59a199af2f17195ae1c1"}, + {file = "jaxlib-0.4.25-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:0ce2a25263e7504d575e8ba5ba4f53aef6fe274679785bcf87ab06b0aaec0b90"}, + {file = "jaxlib-0.4.25-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a0dd09cbb62583941872b6a198894e87a1b64d8e4dd6b53946dbb41d642b8f5f"}, + {file = "jaxlib-0.4.25-cp311-cp311-win_amd64.whl", hash = "sha256:dfb1ef8c2e6a01ecb60f8833552ff077cd593154fd75739050fba9148879a2a4"}, + {file = "jaxlib-0.4.25-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:425d6f3fa57ea1d1674ae84b5a3d3588ba0937f3c47fd4f166eb84c4240887b8"}, + {file = "jaxlib-0.4.25-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e97542bbd89f4316d2feb599119d8a43440ca151b7a165eff0fc127cf4512e7"}, + {file = "jaxlib-0.4.25-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:c4e3bc32aea275e025e762612216954626478c9cf5c44131e248cdd17e361efd"}, + {file = "jaxlib-0.4.25-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dcfb71a7f559c13734584769ca30373bc4b73d0fe105790462370e49f35dcbe4"}, + {file = "jaxlib-0.4.25-cp312-cp312-win_amd64.whl", hash = "sha256:f7aa9682b6806e4197ad51294e87e77f04f5eee7ced4e841aa7ccc7320c6d96b"}, + {file = "jaxlib-0.4.25-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:6660b68741286bd4b849c149d86a8c36e448f7e39e1d483e79dab79ea300bf1b"}, + {file = "jaxlib-0.4.25-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:32881f93d5de195a0fd19e091a2aa89418fa27f630d30c79b4613a51cff4d1c6"}, + {file = "jaxlib-0.4.25-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:ad1ab653265c33b8d54bdcc40867a8ffd61fea879176e4d4cd0b585fe52521fc"}, + {file = "jaxlib-0.4.25-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:0fd113ab414de856f90f07264e6ccd0cb95d392f3579c0deab4ff0943ef75f73"}, + {file = "jaxlib-0.4.25-cp39-cp39-win_amd64.whl", hash = "sha256:b11aef2bd6cf873b39399fda122170b625776d977bbc56b4635f46c396279b8b"}, +] + [[package]] name = "jedi" version = "0.19.1" @@ -1568,6 +1624,38 @@ files = [ {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, ] +[[package]] +name = "ml-dtypes" +version = "0.3.2" +requires_python = ">=3.9" +summary = "" +groups = ["test"] +dependencies = [ + "numpy>1.20", + "numpy>=1.21.2; python_version >= \"3.10\"", + "numpy>=1.23.3; python_version >= \"3.11\"", + "numpy>=1.26.0; python_version >= \"3.12\"", +] +files = [ + {file = "ml_dtypes-0.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7afde548890a92b41c0fed3a6c525f1200a5727205f73dc21181a2726571bb53"}, + {file = "ml_dtypes-0.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1a746fe5fb9cd974a91070174258f0be129c592b93f9ce7df6cc336416c3fbd"}, + {file = "ml_dtypes-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:961134ea44c7b8ca63eda902a44b58cd8bd670e21d62e255c81fba0a8e70d9b7"}, + {file = "ml_dtypes-0.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:6b35c4e8ca957c877ac35c79ffa77724ecc3702a1e4b18b08306c03feae597bb"}, + {file = "ml_dtypes-0.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:763697ab8a88d47443997a7cdf3aac7340049aed45f7521f6b0ec8a0594821fe"}, + {file = "ml_dtypes-0.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b89b194e9501a92d289c1ffd411380baf5daafb9818109a4f49b0a1b6dce4462"}, + {file = "ml_dtypes-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c34f2ba9660b21fe1034b608308a01be82bbef2a92fb8199f24dc6bad0d5226"}, + {file = "ml_dtypes-0.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:6604877d567a29bfe7cc02969ae0f2425260e5335505cf5e7fefc3e5465f5655"}, + {file = "ml_dtypes-0.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:93b78f53431c93953f7850bb1b925a17f0ab5d97527e38a7e865b5b4bc5cfc18"}, + {file = "ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33"}, + {file = "ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855"}, + {file = "ml_dtypes-0.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4"}, + {file = "ml_dtypes-0.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c"}, + {file = "ml_dtypes-0.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e"}, + {file = "ml_dtypes-0.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226"}, + {file = "ml_dtypes-0.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94"}, + {file = "ml_dtypes-0.3.2.tar.gz", hash = "sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967"}, +] + [[package]] name = "more-itertools" version = "10.2.0" @@ -1875,6 +1963,20 @@ files = [ {file = "numpydoc-1.7.0.tar.gz", hash = "sha256:866e5ae5b6509dcf873fc6381120f5c31acf13b135636c1a81d68c166a95f921"}, ] +[[package]] +name = "opt-einsum" +version = "3.3.0" +requires_python = ">=3.5" +summary = "Optimizing numpys einsum function" +groups = ["test"] +dependencies = [ + "numpy>=1.7", +] +files = [ + {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, + {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, +] + [[package]] name = "overrides" version = "7.7.0" @@ -2615,7 +2717,7 @@ name = "scipy" version = "1.12.0" requires_python = ">=3.9" summary = "Fundamental algorithms for scientific computing in Python" -groups = ["base_test"] +groups = ["base_test", "test"] dependencies = [ "numpy<1.29.0,>=1.22.4", ] @@ -3178,7 +3280,7 @@ name = "zipp" version = "3.18.1" requires_python = ">=3.8" summary = "Backport of pathlib-compatible object wrapper for zip files" -groups = ["base_test", "dev", "doc"] +groups = ["base_test", "dev", "doc", "test"] files = [ {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, diff --git a/pyproject.toml b/pyproject.toml index ee330a4..5576a67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,14 @@ base_test = [ "scipy", "-e transonic_testing @ file:///${PROJECT_ROOT}/_transonic_testing", ] -test = ["cython", "mpi4py", "pythran", "numba"] +test = [ + "cython", + "mpi4py", + "pythran", + "numba", + "jax>=0.4.25", + "jaxlib>=0.4.25", +] doc = [ "jupyterlab", # "nbsphinx", From c1fa6311660d467de17d74db5bbf151195f26f55 Mon Sep 17 00:00:00 2001 From: "Ashwin V. Mohanan" Date: Thu, 14 Mar 2024 22:06:46 +0100 Subject: [PATCH 2/7] Implement Jax JIT backend class --- src/transonic/backends/__init__.py | 8 ++- src/transonic/backends/jax.py | 110 +++++++++++++++++++++++++++++ src/transonic/config.py | 2 +- src/transonic/util.py | 5 ++ 4 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 src/transonic/backends/jax.py diff --git a/src/transonic/backends/__init__.py b/src/transonic/backends/__init__.py index 6c56254..e7cbb2d 100644 --- a/src/transonic/backends/__init__.py +++ b/src/transonic/backends/__init__.py @@ -31,16 +31,18 @@ from transonic.config import backend_default from transonic.util import get_module_name, get_frame -from .py import PythonBackend -from .pythran import PythranBackend from .cython import CythonBackend +from .jax import JaxBackend from .numba import NumbaBackend +from .py import PythonBackend +from .pythran import PythranBackend backends = dict( - pythran=PythranBackend(), cython=CythonBackend(), numba=NumbaBackend(), + jax=JaxBackend(), python=PythonBackend(), + pythran=PythranBackend(), ) backend_default_modules = {} diff --git a/src/transonic/backends/jax.py b/src/transonic/backends/jax.py new file mode 100644 index 0000000..98041c8 --- /dev/null +++ b/src/transonic/backends/jax.py @@ -0,0 +1,110 @@ +"""Jax backend +================ + +Internal API +------------ + +.. autoclass:: SubBackendJITJax + :members: + :private-members: + +.. autoclass:: JaxBackend + :members: + :private-members: + +""" + +from typing import Optional + +from transonic.analyses.extast import parse, unparse, CommentLine, gast +from transonic.util import format_str + +from .py import PythonBackend, SubBackendJITPython + + +def add_jax_comments(code): + """Add Jax code in Python comments""" + mod = parse(code) + new_body = [CommentLine("# __protected__ from jax import jit")] + + for node in mod.body: + # Replace `import numpy` -> `import jax.numpy as numpy` + # Replace `import numpy as np` -> `import jax.numpy as np` + if isinstance(node, gast.Import): + if (alias := node.names[0]).name == "numpy": + node = gast.Import([gast.alias(name="jax.numpy", asname=alias.asname or alias.name)]) + + # Replace `from numpy import eye` -> `from jax.numpy import eye` + elif isinstance(node, gast.ImportFrom): + if node.module == "numpy": + node.module = "jax.numpy" + + # Add JIT decorator + if isinstance(node, gast.FunctionDef): + new_body.append( + CommentLine("# __protected__ @jit") + ) + new_body.append(node) + + mod.body = new_body + return format_str(unparse(mod)) + + +class SubBackendJITJax(SubBackendJITPython): + def make_backend_source(self, info_analysis, func, path_backend): + src, has_to_write = super().make_backend_source( + info_analysis, func, path_backend + ) + + if not src: + return src, has_to_write + + return add_jax_comments(src), has_to_write + + +class JaxBackend(PythonBackend): + """Main class for the Jax backend""" + + backend_name = "jax" + _SubBackendJIT = SubBackendJITJax + + def compile_extension( + self, + path_backend, + name_ext_file=None, + native=False, + xsimd=False, + openmp=False, + str_accelerator_flags: Optional[str] = None, + parallel=True, + force=True, + ): + if name_ext_file is None: + name_ext_file = self.name_ext_from_path_backend(path_backend) + + with open(path_backend) as file: + source = file.read() + + source = source.replace("# __protected__ ", "") + + with open(path_backend.with_name(name_ext_file), "w") as file: + file.write(format_str(source)) + + compiling = False + process = None + return compiling, process + + def _make_backend_code(self, path_py, analysis, **kwargs): + """Create a backend code from a Python file""" + code, codes_ext, header = super()._make_backend_code(path_py, analysis) + + if not code: + return code, codes_ext, header + + code = add_jax_comments(code) + + for_meson = kwargs.get("for_meson", False) + if for_meson: + code = format_str(code.replace("# __protected__ ", "")) + + return code, codes_ext, header diff --git a/src/transonic/config.py b/src/transonic/config.py index f5278f7..bd3d7a0 100644 --- a/src/transonic/config.py +++ b/src/transonic/config.py @@ -75,7 +75,7 @@ def set_backend(backend: str): """Set the "global variable" backend_default""" backend = backend.lower() - supported_backends = ["pythran", "cython", "numba", "python"] + supported_backends = {"pythran", "cython", "jax", "numba", "python"} if backend not in supported_backends: raise ValueError(f"backend {backend} not supported") diff --git a/src/transonic/util.py b/src/transonic/util.py index cffd831..da7b4ee 100644 --- a/src/transonic/util.py +++ b/src/transonic/util.py @@ -131,6 +131,11 @@ def can_import_accelerator(backend: str = backend_default): import numba except ImportError: return False + elif backend =="jax": + try: + import jax + except ImportError: + return False elif backend == "python": return True else: From ea3020afaa8ca1c804491e1186aa151a36177fed Mon Sep 17 00:00:00 2001 From: "Ashwin V. Mohanan" Date: Thu, 14 Mar 2024 22:21:40 +0100 Subject: [PATCH 3/7] Create frozenset SUPPORTED_BACKENDS to avoid repetition --- src/transonic/analyses/__init__.py | 3 ++- src/transonic/config.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transonic/analyses/__init__.py b/src/transonic/analyses/__init__.py index c075f14..a431841 100644 --- a/src/transonic/analyses/__init__.py +++ b/src/transonic/analyses/__init__.py @@ -37,6 +37,7 @@ from .parser import parse_transonic_def_commands from .objects_from_str import replace_strings_by_objects from . import extast +from ..config import SUPPORTED_BACKENDS __all__ = ["print_dumped", "print_unparsed"] @@ -88,7 +89,7 @@ def get_decorated_dicts( kinds = ("functions", "functions_ext", "methods", "classes") - backend_names = ("__all__", "pythran", "cython", "numba", "python") + backend_names = ("__all__", *SUPPORTED_BACKENDS) decorated_dicts = { kind: {name: {} for name in backend_names} for kind in kinds } diff --git a/src/transonic/config.py b/src/transonic/config.py index bd3d7a0..032f8d4 100644 --- a/src/transonic/config.py +++ b/src/transonic/config.py @@ -47,7 +47,7 @@ from warnings import warn path_root = Path(os.environ.get("TRANSONIC_DIR", Path.home() / ".transonic")) - +SUPPORTED_BACKENDS = frozenset(("pythran", "cython", "jax", "numba", "python")) def strtobool(value): """Convert a string representation of truth to true (1) or false (0). @@ -75,8 +75,7 @@ def set_backend(backend: str): """Set the "global variable" backend_default""" backend = backend.lower() - supported_backends = {"pythran", "cython", "jax", "numba", "python"} - if backend not in supported_backends: + if backend not in SUPPORTED_BACKENDS: raise ValueError(f"backend {backend} not supported") global backend_default, backend_set_by_user From 0fdc8022fdd60f92ad3ed592dacf140b1efab2e0 Mon Sep 17 00:00:00 2001 From: paugier Date: Sat, 1 Jun 2024 05:49:06 +0200 Subject: [PATCH 4/7] Fix data_tests/package_for_test_meson for JAX --- .hgignore | 1 + data_tests/package_for_test_meson/README.md | 2 +- .../package_for_test_meson/for_test__jax__meson.build | 10 ++++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 data_tests/package_for_test_meson/src/package_for_test_meson/for_test__jax__meson.build diff --git a/.hgignore b/.hgignore index fda5ebe..0a4b5e0 100644 --- a/.hgignore +++ b/.hgignore @@ -35,6 +35,7 @@ doc/ipynb/*.rst **/__cython__/* **/__numba__/* **/__python__/* +**/__jax__/* doc/for_dev/**/*.c doc/for_dev/**/*.html diff --git a/data_tests/package_for_test_meson/README.md b/data_tests/package_for_test_meson/README.md index 36b9c89..d08a827 100644 --- a/data_tests/package_for_test_meson/README.md +++ b/data_tests/package_for_test_meson/README.md @@ -21,7 +21,7 @@ needed for the editable mode (see We can also install with another backend with: ```sh -pip install --no-build-isolation --config-settings=setup-args=-Dtransonic-backend=python . +pip install --no-build-isolation -C setup-args=-Dtransonic-backend=python . # or (but does not work here for another reason) python -m build --no-isolation -Csetup-args=-Dtransonic-backend=python . ``` diff --git a/data_tests/package_for_test_meson/src/package_for_test_meson/for_test__jax__meson.build b/data_tests/package_for_test_meson/src/package_for_test_meson/for_test__jax__meson.build new file mode 100644 index 0000000..ae40de5 --- /dev/null +++ b/data_tests/package_for_test_meson/src/package_for_test_meson/for_test__jax__meson.build @@ -0,0 +1,10 @@ +python_sources = [ + 'bar.py', + 'foo.py', +] + +py.install_sources( + python_sources, + pure: false, + subdir: 'package_for_test_meson/__jax__', +) From d01277d5464a9754f3241fe322fbf065003a3d6a Mon Sep 17 00:00:00 2001 From: paugier Date: Sat, 1 Jun 2024 05:58:02 +0200 Subject: [PATCH 5/7] Add data_tests/saved__backend__/jax --- .../__ext__MyClass2__exterior_import_boost.py | 7 +++ ..._ext__MyClass2__exterior_import_boost_2.py | 5 +++ .../jax/__ext__func__exterior_import_boost.py | 7 +++ .../__ext__func__exterior_import_boost_2.py | 5 +++ data_tests/saved__backend__/jax/add_inline.py | 19 ++++++++ .../saved__backend__/jax/assign_func_boost.py | 9 ++++ .../saved__backend__/jax/block_fluidsim.py | 18 ++++++++ .../saved__backend__/jax/blocks_type_hints.py | 19 ++++++++ .../jax/boosted_class_use_import.py | 13 ++++++ .../jax/boosted_func_use_import.py | 12 +++++ .../saved__backend__/jax/class_blocks.py | 44 +++++++++++++++++++ .../saved__backend__/jax/class_rec_calls.py | 20 +++++++++ data_tests/saved__backend__/jax/classic.py | 11 +++++ .../saved__backend__/jax/default_params.py | 10 +++++ data_tests/saved__backend__/jax/methods.py | 13 ++++++ .../jax/mixed_classic_type_hint.py | 18 ++++++++ data_tests/saved__backend__/jax/no_arg.py | 16 +++++++ .../saved__backend__/jax/row_sum_boost.py | 27 ++++++++++++ .../saved__backend__/jax/subpackages.py | 33 ++++++++++++++ .../jax/type_hint_notemplate.py | 13 ++++++ 20 files changed, 319 insertions(+) create mode 100644 data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost.py create mode 100644 data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost_2.py create mode 100644 data_tests/saved__backend__/jax/__ext__func__exterior_import_boost.py create mode 100644 data_tests/saved__backend__/jax/__ext__func__exterior_import_boost_2.py create mode 100644 data_tests/saved__backend__/jax/add_inline.py create mode 100644 data_tests/saved__backend__/jax/assign_func_boost.py create mode 100644 data_tests/saved__backend__/jax/block_fluidsim.py create mode 100644 data_tests/saved__backend__/jax/blocks_type_hints.py create mode 100644 data_tests/saved__backend__/jax/boosted_class_use_import.py create mode 100644 data_tests/saved__backend__/jax/boosted_func_use_import.py create mode 100644 data_tests/saved__backend__/jax/class_blocks.py create mode 100644 data_tests/saved__backend__/jax/class_rec_calls.py create mode 100644 data_tests/saved__backend__/jax/classic.py create mode 100644 data_tests/saved__backend__/jax/default_params.py create mode 100644 data_tests/saved__backend__/jax/methods.py create mode 100644 data_tests/saved__backend__/jax/mixed_classic_type_hint.py create mode 100644 data_tests/saved__backend__/jax/no_arg.py create mode 100644 data_tests/saved__backend__/jax/row_sum_boost.py create mode 100644 data_tests/saved__backend__/jax/subpackages.py create mode 100644 data_tests/saved__backend__/jax/type_hint_notemplate.py diff --git a/data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost.py b/data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost.py new file mode 100644 index 0000000..365c3a0 --- /dev/null +++ b/data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost.py @@ -0,0 +1,7 @@ +const = 1 +from __ext__MyClass2__exterior_import_boost_2 import func_import_2 +import numpy as np + + +def func_import(): + return const + func_import_2() + np.pi - np.pi diff --git a/data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost_2.py b/data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost_2.py new file mode 100644 index 0000000..d12395e --- /dev/null +++ b/data_tests/saved__backend__/jax/__ext__MyClass2__exterior_import_boost_2.py @@ -0,0 +1,5 @@ +const = 1 + + +def func_import_2(): + return const diff --git a/data_tests/saved__backend__/jax/__ext__func__exterior_import_boost.py b/data_tests/saved__backend__/jax/__ext__func__exterior_import_boost.py new file mode 100644 index 0000000..7be565f --- /dev/null +++ b/data_tests/saved__backend__/jax/__ext__func__exterior_import_boost.py @@ -0,0 +1,7 @@ +const = 1 +from __ext__func__exterior_import_boost_2 import func_import_2 +import numpy as np + + +def func_import(): + return const + func_import_2() + np.pi - np.pi diff --git a/data_tests/saved__backend__/jax/__ext__func__exterior_import_boost_2.py b/data_tests/saved__backend__/jax/__ext__func__exterior_import_boost_2.py new file mode 100644 index 0000000..d12395e --- /dev/null +++ b/data_tests/saved__backend__/jax/__ext__func__exterior_import_boost_2.py @@ -0,0 +1,5 @@ +const = 1 + + +def func_import_2(): + return const diff --git a/data_tests/saved__backend__/jax/add_inline.py b/data_tests/saved__backend__/jax/add_inline.py new file mode 100644 index 0000000..09065a2 --- /dev/null +++ b/data_tests/saved__backend__/jax/add_inline.py @@ -0,0 +1,19 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def add(a, b): + return a + b + + +# __protected__ @jit + + +def use_add(n=10000): + tmp = 0 + for _ in range(n): + tmp = add(tmp, 1) + return tmp + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/assign_func_boost.py b/data_tests/saved__backend__/jax/assign_func_boost.py new file mode 100644 index 0000000..a0d21a0 --- /dev/null +++ b/data_tests/saved__backend__/jax/assign_func_boost.py @@ -0,0 +1,9 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def func(x): + return x**2 + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/block_fluidsim.py b/data_tests/saved__backend__/jax/block_fluidsim.py new file mode 100644 index 0000000..3ca3746 --- /dev/null +++ b/data_tests/saved__backend__/jax/block_fluidsim.py @@ -0,0 +1,18 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def rk2_step0(state_spect_n12, state_spect, tendencies_n, diss2, dt): + # transonic block ( + # complex128[][][] state_spect_n12, state_spect, + # tendencies_n; + # float64[][] diss2; + # float dt + # ) + state_spect_n12[:] = (state_spect + dt / 2 * tendencies_n) * diss2 + + +arguments_blocks = { + "rk2_step0": ["state_spect_n12", "state_spect", "tendencies_n", "diss2", "dt"] +} +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/blocks_type_hints.py b/data_tests/saved__backend__/jax/blocks_type_hints.py new file mode 100644 index 0000000..880e5fd --- /dev/null +++ b/data_tests/saved__backend__/jax/blocks_type_hints.py @@ -0,0 +1,19 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def block0(a, b, n): + # transonic block ( + # A a; A1 b; + # int n + # ) + # transonic block ( + # int[:] a, b; + # float n + # ) + result = a**2 + b.mean() ** 3 + n + return result + + +arguments_blocks = {"block0": ["a", "b", "n"]} +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/boosted_class_use_import.py b/data_tests/saved__backend__/jax/boosted_class_use_import.py new file mode 100644 index 0000000..787ba8d --- /dev/null +++ b/data_tests/saved__backend__/jax/boosted_class_use_import.py @@ -0,0 +1,13 @@ +# __protected__ from jax import jit +import jax.numpy as np +from __ext__MyClass2__exterior_import_boost import func_import + +# __protected__ @jit + + +def __for_method__MyClass2__myfunc(self_attr0, self_attr1, arg): + return self_attr1 + self_attr0 + np.abs(arg) + func_import() + + +__code_new_method__MyClass2__myfunc = "\n\ndef new_method(self, arg):\n return backend_func(self.attr0, self.attr1, arg)\n\n" +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/boosted_func_use_import.py b/data_tests/saved__backend__/jax/boosted_func_use_import.py new file mode 100644 index 0000000..9ee6a4b --- /dev/null +++ b/data_tests/saved__backend__/jax/boosted_func_use_import.py @@ -0,0 +1,12 @@ +# __protected__ from jax import jit +import jax.numpy as np +from __ext__func__exterior_import_boost import func_import + +# __protected__ @jit + + +def func(a, b): + return (a * np.log(b)).max() + func_import() + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/class_blocks.py b/data_tests/saved__backend__/jax/class_blocks.py new file mode 100644 index 0000000..97c9342 --- /dev/null +++ b/data_tests/saved__backend__/jax/class_blocks.py @@ -0,0 +1,44 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def block0(a, b, n): + # foo + # transonic block ( + # float[][] a, b; + # int n + # ) bar + # foo + # transonic block ( + # float[][][] a, b; + # int n + # ) + # foobar + result = np.zeros_like(a) + for _ in range(n): + result += a**2 + b**3 + return result + + +# __protected__ @jit + + +def block1(a, b, n): + # transonic block ( + # float[][] a, b; + # int n + # ) + # transonic block ( + # float[][][] a, b; + # int n + # ) + result = np.zeros_like(a) + for _ in range(n): + result += a**2 + b**3 + return result + + +arguments_blocks = {"block0": ["a", "b", "n"], "block1": ["a", "b", "n"]} +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/class_rec_calls.py b/data_tests/saved__backend__/jax/class_rec_calls.py new file mode 100644 index 0000000..944af25 --- /dev/null +++ b/data_tests/saved__backend__/jax/class_rec_calls.py @@ -0,0 +1,20 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def __for_method__Myclass__func(self_attr, self_attr2, arg): + if __for_method__Myclass__func(self_attr, self_attr2, arg - 1) < 1: + return 1 + else: + a = __for_method__Myclass__func( + self_attr, self_attr2, arg - 1 + ) * __for_method__Myclass__func(self_attr, self_attr2, arg - 1) + return ( + a + + self_attr * self_attr2 * arg + + __for_method__Myclass__func(self_attr, self_attr2, arg - 1) + ) + + +__code_new_method__Myclass__func = "\n\ndef new_method(self, arg):\n return backend_func(self.attr, self.attr2, arg)\n\n" +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/classic.py b/data_tests/saved__backend__/jax/classic.py new file mode 100644 index 0000000..97ac465 --- /dev/null +++ b/data_tests/saved__backend__/jax/classic.py @@ -0,0 +1,11 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def func(a, b): + return (a * np.log(b)).max() + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/default_params.py b/data_tests/saved__backend__/jax/default_params.py new file mode 100644 index 0000000..2dacf60 --- /dev/null +++ b/data_tests/saved__backend__/jax/default_params.py @@ -0,0 +1,10 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def func(a=1, b=None, c=1.0): + print(b) + return a + c + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/methods.py b/data_tests/saved__backend__/jax/methods.py new file mode 100644 index 0000000..153f306 --- /dev/null +++ b/data_tests/saved__backend__/jax/methods.py @@ -0,0 +1,13 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def __for_method__Transmitter____call__(self_arr, self_freq, inp): + """My docstring""" + return (inp * np.exp(np.arange(len(inp)) * self_freq * 1j), self_arr) + + +__code_new_method__Transmitter____call__ = "\n\ndef new_method(self, inp):\n return backend_func(self.arr, self.freq, inp)\n\n" +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/mixed_classic_type_hint.py b/data_tests/saved__backend__/jax/mixed_classic_type_hint.py new file mode 100644 index 0000000..f81d991 --- /dev/null +++ b/data_tests/saved__backend__/jax/mixed_classic_type_hint.py @@ -0,0 +1,18 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def func(a, b): + return (a * np.log(b)).max() + + +# __protected__ @jit + + +def func1(a, b): + return a * np.cos(b) + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/no_arg.py b/data_tests/saved__backend__/jax/no_arg.py new file mode 100644 index 0000000..23a841b --- /dev/null +++ b/data_tests/saved__backend__/jax/no_arg.py @@ -0,0 +1,16 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def func(): + return 1 + + +# __protected__ @jit + + +def func2(): + return 1 + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/row_sum_boost.py b/data_tests/saved__backend__/jax/row_sum_boost.py new file mode 100644 index 0000000..e5b2683 --- /dev/null +++ b/data_tests/saved__backend__/jax/row_sum_boost.py @@ -0,0 +1,27 @@ +# __protected__ from jax import jit +import jax.numpy as np + +# __protected__ @jit + + +def row_sum(arr, columns): + return arr.T[columns].sum(0) + + +# __protected__ @jit + + +def row_sum_loops(arr, columns): + # locals type annotations are used only for Cython + # arr.dtype not supported for memoryview + dtype = type(arr[0, 0]) + res = np.empty(arr.shape[0], dtype=dtype) + for i in range(arr.shape[0]): + sum_ = dtype(0) + for j in range(columns.shape[0]): + sum_ += arr[i, columns[j]] + res[i] = sum_ + return res + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/subpackages.py b/data_tests/saved__backend__/jax/subpackages.py new file mode 100644 index 0000000..c8696ca --- /dev/null +++ b/data_tests/saved__backend__/jax/subpackages.py @@ -0,0 +1,33 @@ +# __protected__ from jax import jit +from numpy.fft import rfft +from numpy.random import randn +from numpy.linalg import matrix_power +from scipy.special import jv + +# __protected__ @jit + + +def test_np_fft(u): + u_fft = rfft(u) + return u_fft + + +# __protected__ @jit + + +def test_np_linalg_random(u): + (nx, ny) = u.shape + u[:] = randn(nx, ny) + u2 = u.T * u + u4 = matrix_power(u2, 2) + return u4 + + +# __protected__ @jit + + +def test_sp_special(v, x): + return jv(v, x) + + +__transonic__ = "0.6.4" diff --git a/data_tests/saved__backend__/jax/type_hint_notemplate.py b/data_tests/saved__backend__/jax/type_hint_notemplate.py new file mode 100644 index 0000000..225cdbb --- /dev/null +++ b/data_tests/saved__backend__/jax/type_hint_notemplate.py @@ -0,0 +1,13 @@ +# __protected__ from jax import jit +# __protected__ @jit + + +def compute(a, b, c, d, e): + print(e) + tmp = a + b + if 1 and 2: + tmp *= 2 + return tmp + + +__transonic__ = "0.6.4" From 3f39c68d507d202ae2e7813952fe403634cf7592 Mon Sep 17 00:00:00 2001 From: paugier Date: Sat, 1 Jun 2024 06:31:14 +0200 Subject: [PATCH 6/7] pyproject.toml with tool.pytest.ini_options for ipdb --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5576a67..1f76b0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,8 +82,9 @@ dev = ["pip", "build", "pylint", "twine"] [tool.pytest] addopts = "--doctest-modules" -# [tool.pytest.ini_options] -# addopts = "--cov --cov-config=pyproject.toml --no-cov-on-fail" +[tool.pytest.ini_options] +addopts = "--pdbcls=IPython.terminal.debugger:TerminalPdb" + [tool.black] line-length = 82 From 85d5d20270aaceaec4cb2026559c5d00bb579b30 Mon Sep 17 00:00:00 2001 From: paugier Date: Sat, 1 Jun 2024 07:00:54 +0200 Subject: [PATCH 7/7] Skip test_init_transonified.py::TestsInit::test_pythranize for JAX --- src/transonic/backends/jax.py | 4 +--- src/transonic/util.py | 2 +- tests/test_init_transonified.py | 3 +++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transonic/backends/jax.py b/src/transonic/backends/jax.py index 98041c8..4669b60 100644 --- a/src/transonic/backends/jax.py +++ b/src/transonic/backends/jax.py @@ -41,9 +41,7 @@ def add_jax_comments(code): # Add JIT decorator if isinstance(node, gast.FunctionDef): - new_body.append( - CommentLine("# __protected__ @jit") - ) + new_body.append(CommentLine("# __protected__ @jit")) new_body.append(node) mod.body = new_body diff --git a/src/transonic/util.py b/src/transonic/util.py index da7b4ee..05cdc03 100644 --- a/src/transonic/util.py +++ b/src/transonic/util.py @@ -131,7 +131,7 @@ def can_import_accelerator(backend: str = backend_default): import numba except ImportError: return False - elif backend =="jax": + elif backend == "jax": try: import jax except ImportError: diff --git a/tests/test_init_transonified.py b/tests/test_init_transonified.py index 8850419..cc3c870 100644 --- a/tests/test_init_transonified.py +++ b/tests/test_init_transonified.py @@ -99,6 +99,9 @@ def test_transonified(self): for_test_init.func1(1.1, 2.2) for_test_init.check_class() + @unittest.skipIf( + backend.name == "jax", "Not yet supported by our JAX backend" + ) @unittest.skipIf( sys.platform.startswith("win") or not can_import_accelerator(), f"{backend.name} is required for TRANSONIC_COMPILE_AT_IMPORT",