-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #78 from opendilab/dev/jax
dev(hansbug): add register support for treevalue
- Loading branch information
Showing
12 changed files
with
124 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ treevalue.tree | |
tree | ||
func | ||
general | ||
integration |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
treevalue.tree.integration | ||
======================================= | ||
|
||
.. py:currentmodule:: treevalue.tree.integration | ||
.. _apidoc_tree_integration_register_for_jax: | ||
|
||
register_for_jax | ||
------------------------ | ||
|
||
.. autofunction:: register_for_jax | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
torch>=1.1.0 | ||
jax[cpu]>=0.3.25; platform_system != 'Windows' | ||
torch>=1.1.0; python_version < '3.11' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from unittest import skipUnless | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from treevalue import FastTreeValue | ||
|
||
try: | ||
import jax | ||
except (ModuleNotFoundError, ImportError): | ||
jax = None | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestTreeTreeIntegration: | ||
@skipUnless(jax, 'Jax required.') | ||
def test_jax_double(self): | ||
@jax.jit | ||
def double(x): | ||
return x * 2 + 1.5 | ||
|
||
t1 = FastTreeValue({ | ||
'a': np.random.randint(0, 10, (2, 3)), | ||
'b': { | ||
'x': np.asarray(233.0), | ||
'y': np.random.randn(2, 3) | ||
} | ||
}) | ||
assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \ | ||
FastTreeValue({'a': True, 'b': {'x': True, 'y': True}}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .common import raw | ||
from .func import * | ||
from .general import * | ||
from .integration import * | ||
from .tree import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .jax import register_for_jax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# distutils:language=c++ | ||
# cython:language_level=3 | ||
|
||
cdef tuple _c_flatten_for_jax(object tv) | ||
cdef object _c_unflatten_for_jax(tuple aux, tuple values) | ||
cpdef void register_for_jax(object cls) except* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# distutils:language=c++ | ||
# cython:language_level=3 | ||
|
||
import cython | ||
|
||
from ..tree.flatten cimport _c_flatten, _c_unflatten | ||
from ..tree.tree cimport TreeValue | ||
|
||
cdef inline tuple _c_flatten_for_jax(object tv): | ||
cdef list result = [] | ||
_c_flatten(tv._detach(), (), result) | ||
|
||
cdef list paths = [] | ||
cdef list values = [] | ||
for path, value in result: | ||
paths.append(path) | ||
values.append(value) | ||
|
||
return values, (type(tv), paths) | ||
|
||
cdef inline object _c_unflatten_for_jax(tuple aux, tuple values): | ||
cdef object type_ | ||
cdef list paths | ||
type_, paths = aux | ||
return type_(_c_unflatten(zip(paths, values))) | ||
|
||
@cython.binding(True) | ||
cpdef void register_for_jax(object cls) except*: | ||
""" | ||
Overview: | ||
Register treevalue class for jax. | ||
:param cls: TreeValue class. | ||
Examples:: | ||
>>> from treevalue import FastTreeValue, TreeValue, register_for_jax | ||
>>> register_for_jax(TreeValue) | ||
>>> register_for_jax(FastTreeValue) | ||
.. warning:: | ||
This method will put a warning message and then do nothing when jax is not installed. | ||
""" | ||
if isinstance(cls, type) and issubclass(cls, TreeValue): | ||
import jax | ||
jax.tree_util.register_pytree_node(cls, _c_flatten_for_jax, _c_unflatten_for_jax) | ||
else: | ||
raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import warnings | ||
from functools import wraps | ||
|
||
try: | ||
import jax | ||
except (ModuleNotFoundError, ImportError): | ||
from .cjax import register_for_jax as _original_register_for_jax | ||
|
||
|
||
@wraps(_original_register_for_jax) | ||
def register_for_jax(cls): | ||
warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.') | ||
else: | ||
from .cjax import register_for_jax | ||
from ..tree import TreeValue | ||
from ..general import FastTreeValue | ||
|
||
register_for_jax(TreeValue) | ||
register_for_jax(FastTreeValue) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters