Skip to content

Commit

Permalink
Merge pull request #78 from opendilab/dev/jax
Browse files Browse the repository at this point in the history
dev(hansbug): add register support for treevalue
  • Loading branch information
HansBug authored Feb 26, 2023
2 parents 6187ad5 + ee3184d commit b61551a
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 9 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-build.txt
pip install -r requirements-test.txt
- name: Install extra PyPI dependencies
continue-on-error: true
shell: bash
run: |
pip install -r requirements-test-extra.txt
- name: Test the basic environment
shell: bash
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_doc/tree/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ treevalue.tree
tree
func
general
integration
13 changes: 13 additions & 0 deletions docs/source/api_doc/tree/integration.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
treevalue.tree.integration
=======================================

.. py:currentmodule:: treevalue.tree.integration
.. _apidoc_tree_integration_register_for_jax:

register_for_jax
------------------------

.. autofunction:: register_for_jax

3 changes: 2 additions & 1 deletion requirements-test-extra.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
torch>=1.1.0
jax[cpu]>=0.3.25; platform_system != 'Windows'
torch>=1.1.0; python_version < '3.11'
Empty file.
30 changes: 30 additions & 0 deletions test/tree/integration/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from unittest import skipUnless

import numpy as np
import pytest

from treevalue import FastTreeValue

try:
import jax
except (ModuleNotFoundError, ImportError):
jax = None


@pytest.mark.unittest
class TestTreeTreeIntegration:
@skipUnless(jax, 'Jax required.')
def test_jax_double(self):
@jax.jit
def double(x):
return x * 2 + 1.5

t1 = FastTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3)
}
})
assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \
FastTreeValue({'a': True, 'b': {'x': True, 'y': True}})
1 change: 1 addition & 0 deletions treevalue/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .common import raw
from .func import *
from .general import *
from .integration import *
from .tree import *
1 change: 1 addition & 0 deletions treevalue/tree/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .jax import register_for_jax
6 changes: 6 additions & 0 deletions treevalue/tree/integration/cjax.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# distutils:language=c++
# cython:language_level=3

cdef tuple _c_flatten_for_jax(object tv)
cdef object _c_unflatten_for_jax(tuple aux, tuple values)
cpdef void register_for_jax(object cls) except*
47 changes: 47 additions & 0 deletions treevalue/tree/integration/cjax.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# distutils:language=c++
# cython:language_level=3

import cython

from ..tree.flatten cimport _c_flatten, _c_unflatten
from ..tree.tree cimport TreeValue

cdef inline tuple _c_flatten_for_jax(object tv):
cdef list result = []
_c_flatten(tv._detach(), (), result)

cdef list paths = []
cdef list values = []
for path, value in result:
paths.append(path)
values.append(value)

return values, (type(tv), paths)

cdef inline object _c_unflatten_for_jax(tuple aux, tuple values):
cdef object type_
cdef list paths
type_, paths = aux
return type_(_c_unflatten(zip(paths, values)))

@cython.binding(True)
cpdef void register_for_jax(object cls) except*:
"""
Overview:
Register treevalue class for jax.
:param cls: TreeValue class.
Examples::
>>> from treevalue import FastTreeValue, TreeValue, register_for_jax
>>> register_for_jax(TreeValue)
>>> register_for_jax(FastTreeValue)
.. warning::
This method will put a warning message and then do nothing when jax is not installed.
"""
if isinstance(cls, type) and issubclass(cls, TreeValue):
import jax
jax.tree_util.register_pytree_node(cls, _c_flatten_for_jax, _c_unflatten_for_jax)
else:
raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.')
19 changes: 19 additions & 0 deletions treevalue/tree/integration/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import warnings
from functools import wraps

try:
import jax
except (ModuleNotFoundError, ImportError):
from .cjax import register_for_jax as _original_register_for_jax


@wraps(_original_register_for_jax)
def register_for_jax(cls):
warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.')
else:
from .cjax import register_for_jax
from ..tree import TreeValue
from ..general import FastTreeValue

register_for_jax(TreeValue)
register_for_jax(FastTreeValue)
8 changes: 4 additions & 4 deletions treevalue/tree/tree/flatten.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import cython
from .tree cimport TreeValue
from ..common.storage cimport TreeStorage, _c_undelay_data

cdef void _c_flatten(TreeStorage st, tuple path, list res) except *:
cdef inline void _c_flatten(TreeStorage st, tuple path, list res) except *:
cdef dict data = st.detach()
cdef tuple curpath

Expand Down Expand Up @@ -44,7 +44,7 @@ cpdef list flatten(TreeValue tree):
_c_flatten(tree._detach(), (), result)
return result

cdef void _c_flatten_values(TreeStorage st, list res) except *:
cdef inline void _c_flatten_values(TreeStorage st, list res) except *:
cdef dict data = st.detach()

cdef str k
Expand Down Expand Up @@ -72,7 +72,7 @@ cpdef list flatten_values(TreeValue tree):
_c_flatten_values(tree._detach(), result)
return result

cdef void _c_flatten_keys(TreeStorage st, tuple path, list res) except *:
cdef inline void _c_flatten_keys(TreeStorage st, tuple path, list res) except *:
cdef dict data = st.detach()
cdef tuple curpath

Expand Down Expand Up @@ -102,7 +102,7 @@ cpdef list flatten_keys(TreeValue tree):
_c_flatten_keys(tree._detach(), (), result)
return result

cdef TreeStorage _c_unflatten(object pairs):
cdef inline TreeStorage _c_unflatten(object pairs):
cdef dict raw_data = {}
cdef TreeStorage result = TreeStorage(raw_data)
cdef list stack = []
Expand Down

0 comments on commit b61551a

Please sign in to comment.