From b6192f4361683ff2c3dfa6a770112ffe057c9b82 Mon Sep 17 00:00:00 2001 From: HansBug Date: Sun, 26 Feb 2023 21:37:04 +0800 Subject: [PATCH 1/2] dev(hansbug): add register support for treevalue --- .github/workflows/test.yml | 4 -- docs/source/api_doc/tree/index.rst | 1 + docs/source/api_doc/tree/integration.rst | 13 +++++++ requirements-test-extra.txt | 3 +- test/tree/integration/__init__.py | 0 test/tree/integration/test_jax.py | 30 +++++++++++++++ treevalue/tree/__init__.py | 1 + treevalue/tree/integration/__init__.py | 1 + treevalue/tree/integration/cjax.pxd | 6 +++ treevalue/tree/integration/cjax.pyx | 47 ++++++++++++++++++++++++ treevalue/tree/integration/jax.py | 19 ++++++++++ treevalue/tree/tree/flatten.pyx | 8 ++-- 12 files changed, 124 insertions(+), 9 deletions(-) create mode 100644 docs/source/api_doc/tree/integration.rst create mode 100644 test/tree/integration/__init__.py create mode 100644 test/tree/integration/test_jax.py create mode 100644 treevalue/tree/integration/__init__.py create mode 100644 treevalue/tree/integration/cjax.pxd create mode 100644 treevalue/tree/integration/cjax.pyx create mode 100644 treevalue/tree/integration/jax.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 397b9151d7..837f2ddcc8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -94,10 +94,6 @@ jobs: pip install -r requirements.txt pip install -r requirements-build.txt pip install -r requirements-test.txt - - name: Install extra PyPI dependencies - continue-on-error: true - shell: bash - run: | pip install -r requirements-test-extra.txt - name: Test the basic environment shell: bash diff --git a/docs/source/api_doc/tree/index.rst b/docs/source/api_doc/tree/index.rst index 62f5b7aa14..b8d0ec112b 100644 --- a/docs/source/api_doc/tree/index.rst +++ b/docs/source/api_doc/tree/index.rst @@ -8,3 +8,4 @@ treevalue.tree tree func general + integration diff --git a/docs/source/api_doc/tree/integration.rst b/docs/source/api_doc/tree/integration.rst new file mode 100644 index 0000000000..093c27ee5a --- /dev/null +++ b/docs/source/api_doc/tree/integration.rst @@ -0,0 +1,13 @@ +treevalue.tree.integration +======================================= + +.. py:currentmodule:: treevalue.tree.integration + + +.. _apidoc_tree_integration_register_for_jax: + +register_for_jax +------------------------ + +.. autofunction:: register_for_jax + diff --git a/requirements-test-extra.txt b/requirements-test-extra.txt index 5e73e53a39..ea8327ef93 100644 --- a/requirements-test-extra.txt +++ b/requirements-test-extra.txt @@ -1 +1,2 @@ -torch>=1.1.0 +jax[cpu]>=0.3.25; platform_system != 'Windows' +torch>=1.1.0; python_version < '3.11' diff --git a/test/tree/integration/__init__.py b/test/tree/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/tree/integration/test_jax.py b/test/tree/integration/test_jax.py new file mode 100644 index 0000000000..bca2d3b5b3 --- /dev/null +++ b/test/tree/integration/test_jax.py @@ -0,0 +1,30 @@ +from unittest import skipUnless + +import numpy as np +import pytest + +from treevalue import FastTreeValue + +try: + import jax +except (ModuleNotFoundError, ImportError): + jax = None + + +@pytest.mark.unittest +@skipUnless(jax, 'Jax required.') +class TestTreeTreeIntegration: + def test_jax_double(self): + @jax.jit + def double(x): + return x * 2 + 1.5 + + t1 = FastTreeValue({ + 'a': np.random.randint(0, 10, (2, 3)), + 'b': { + 'x': np.asarray(233.0), + 'y': np.random.randn(2, 3) + } + }) + assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \ + FastTreeValue({'a': True, 'b': {'x': True, 'y': True}}) diff --git a/treevalue/tree/__init__.py b/treevalue/tree/__init__.py index f0146cc0db..6055f4e6b3 100644 --- a/treevalue/tree/__init__.py +++ b/treevalue/tree/__init__.py @@ -1,4 +1,5 @@ from .common import raw from .func import * from .general import * +from .integration import * from .tree import * diff --git a/treevalue/tree/integration/__init__.py b/treevalue/tree/integration/__init__.py new file mode 100644 index 0000000000..81a5f62d25 --- /dev/null +++ b/treevalue/tree/integration/__init__.py @@ -0,0 +1 @@ +from .jax import register_for_jax diff --git a/treevalue/tree/integration/cjax.pxd b/treevalue/tree/integration/cjax.pxd new file mode 100644 index 0000000000..e4e49bce54 --- /dev/null +++ b/treevalue/tree/integration/cjax.pxd @@ -0,0 +1,6 @@ +# distutils:language=c++ +# cython:language_level=3 + +cdef tuple _c_flatten_for_jax(object tv) +cdef object _c_unflatten_for_jax(tuple aux, tuple values) +cpdef void register_for_jax(object cls) except* diff --git a/treevalue/tree/integration/cjax.pyx b/treevalue/tree/integration/cjax.pyx new file mode 100644 index 0000000000..a170dbd97c --- /dev/null +++ b/treevalue/tree/integration/cjax.pyx @@ -0,0 +1,47 @@ +# distutils:language=c++ +# cython:language_level=3 + +import cython + +from ..tree.flatten cimport _c_flatten, _c_unflatten +from ..tree.tree cimport TreeValue + +cdef inline tuple _c_flatten_for_jax(object tv): + cdef list result = [] + _c_flatten(tv._detach(), (), result) + + cdef list paths = [] + cdef list values = [] + for path, value in result: + paths.append(path) + values.append(value) + + return values, (type(tv), paths) + +cdef inline object _c_unflatten_for_jax(tuple aux, tuple values): + cdef object type_ + cdef list paths + type_, paths = aux + return type_(_c_unflatten(zip(paths, values))) + +@cython.binding(True) +cpdef void register_for_jax(object cls) except*: + """ + Overview: + Register treevalue class for jax. + + :param cls: TreeValue class. + + Examples:: + >>> from treevalue import FastTreeValue, TreeValue, register_for_jax + >>> register_for_jax(TreeValue) + >>> register_for_jax(FastTreeValue) + + .. warning:: + This method will put a warning message and then do nothing when jax is not installed. + """ + if isinstance(cls, type) and issubclass(cls, TreeValue): + import jax + jax.tree_util.register_pytree_node(cls, _c_flatten_for_jax, _c_unflatten_for_jax) + else: + raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.') diff --git a/treevalue/tree/integration/jax.py b/treevalue/tree/integration/jax.py new file mode 100644 index 0000000000..aae995b599 --- /dev/null +++ b/treevalue/tree/integration/jax.py @@ -0,0 +1,19 @@ +import warnings +from functools import wraps + +try: + import jax +except (ModuleNotFoundError, ImportError): + from .cjax import register_for_jax as _original_register_for_jax + + + @wraps(_original_register_for_jax) + def register_for_jax(cls): + warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.') +else: + from .cjax import register_for_jax + from ..tree import TreeValue + from ..general import FastTreeValue + + register_for_jax(TreeValue) + register_for_jax(FastTreeValue) diff --git a/treevalue/tree/tree/flatten.pyx b/treevalue/tree/tree/flatten.pyx index 54449a8c2e..c195631e2e 100644 --- a/treevalue/tree/tree/flatten.pyx +++ b/treevalue/tree/tree/flatten.pyx @@ -8,7 +8,7 @@ import cython from .tree cimport TreeValue from ..common.storage cimport TreeStorage, _c_undelay_data -cdef void _c_flatten(TreeStorage st, tuple path, list res) except *: +cdef inline void _c_flatten(TreeStorage st, tuple path, list res) except *: cdef dict data = st.detach() cdef tuple curpath @@ -44,7 +44,7 @@ cpdef list flatten(TreeValue tree): _c_flatten(tree._detach(), (), result) return result -cdef void _c_flatten_values(TreeStorage st, list res) except *: +cdef inline void _c_flatten_values(TreeStorage st, list res) except *: cdef dict data = st.detach() cdef str k @@ -72,7 +72,7 @@ cpdef list flatten_values(TreeValue tree): _c_flatten_values(tree._detach(), result) return result -cdef void _c_flatten_keys(TreeStorage st, tuple path, list res) except *: +cdef inline void _c_flatten_keys(TreeStorage st, tuple path, list res) except *: cdef dict data = st.detach() cdef tuple curpath @@ -102,7 +102,7 @@ cpdef list flatten_keys(TreeValue tree): _c_flatten_keys(tree._detach(), (), result) return result -cdef TreeStorage _c_unflatten(object pairs): +cdef inline TreeStorage _c_unflatten(object pairs): cdef dict raw_data = {} cdef TreeStorage result = TreeStorage(raw_data) cdef list stack = [] From ee3184df354e2c9db046a83a1ffb653659367288 Mon Sep 17 00:00:00 2001 From: HansBug Date: Sun, 26 Feb 2023 21:43:51 +0800 Subject: [PATCH 2/2] dev(hansbug): fix skip --- test/tree/integration/test_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tree/integration/test_jax.py b/test/tree/integration/test_jax.py index bca2d3b5b3..97e6e6745f 100644 --- a/test/tree/integration/test_jax.py +++ b/test/tree/integration/test_jax.py @@ -12,8 +12,8 @@ @pytest.mark.unittest -@skipUnless(jax, 'Jax required.') class TestTreeTreeIntegration: + @skipUnless(jax, 'Jax required.') def test_jax_double(self): @jax.jit def double(x):