Skip to content

Commit

Permalink
Merge pull request #4 from opendilab/release/0.2.1
Browse files Browse the repository at this point in the history
release(hansbug): use version 0.2.1
  • Loading branch information
HansBug authored Mar 22, 2022
2 parents fe5f681 + 716b4b9 commit 9cf4605
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 27 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/badge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ name: Badge Creation
on:
push:
branches: [ main, 'badge/*', 'doc/*' ]
pull_request:
branches: [ main, 'badge/*', 'doc/*' ]

jobs:
update-badges:
Expand Down
39 changes: 25 additions & 14 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ name: Code Test

on:
- push
- pull_request

jobs:
unittest:
Expand All @@ -18,27 +17,31 @@ jobs:
- '3.7'
- '3.8'
- '3.9'
numpy-version:
- '1.18.0'
- '1.20.0'
- '1.22.0'
torch-version:
- '1.2.0'
- '1.4.0'
- '1.5.0'
- '1.6.0'
- '1.7.0'
- '1.8.0'
- '1.9.0'
- '1.10.0'
exclude:
- os: 'ubuntu-18.04'
python-version: '3.9'
- python-version: '3.6'
numpy-version: '1.20.0'
- python-version: '3.6'
numpy-version: '1.22.0'
- python-version: '3.7'
numpy-version: '1.22.0'
- python-version: '3.8'
torch-version: '1.2.0'
- python-version: '3.9'
torch-version: '1.2.0'
- python-version: '3.9'
torch-version: '1.4.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.5.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
- python-version: '3.9'
torch-version: '1.6.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.7.0'

steps:
- name: Checkout code
Expand All @@ -60,6 +63,14 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install --upgrade flake8 setuptools wheel twine
- name: Install latest numpy
if: ${{ matrix.numpy-version == 'latest' }}
run: |
pip install 'numpy'
- name: Install numpy v${{ matrix.numpy-version }}
if: ${{ matrix.numpy-version != 'latest' }}
run: |
pip install 'numpy==${{ matrix.numpy-version }}'
- name: Install latest pytorch
if: ${{ matrix.torch-version == 'latest' }}
run: |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api_doc/numpy/funcs.rst.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _doc_process(doc: str) -> str:
print_title(f"Description From Numpy v{_short_version}", levelc='-', file=p_func)
current_module(np.__name__, file=p_func)

_origin_doc = _doc_process(_origin.__doc__ or "")
_origin_doc = _doc_process(_origin.__doc__ or "").lstrip()
_doc_lines = _origin_doc.splitlines()
_first_line, _other_lines = _doc_lines[0], _doc_lines[1:]
if _first_line.strip():
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
treevalue>=1.2.0
treevalue>=1.3.0
torch>=1.1.0,<=1.10.0
hbutils>=0.0.1
numpy
21 changes: 21 additions & 0 deletions test/numpy/test_array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import pytest
import torch

import treetensor.numpy as tnp
import treetensor.torch as ttorch
from treetensor.common import Object


Expand Down Expand Up @@ -233,3 +235,22 @@ def test_tolist(self):
'd': [0, 0, 0.0],
}
})

def test_tensor(self):
assert ttorch.isclose(self._DEMO_1.tensor().double(), ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.Tensor([1, 3, 5, 7]),
'x': {
'c': ttorch.Tensor([[11], [23]]),
'd': ttorch.Tensor([3, 9, 11.0])
}
}).double()).all()

assert (self._DEMO_1.tensor(dtype=torch.float64) == ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float64),
'b': ttorch.Tensor([1, 3, 5, 7], dtype=torch.float64),
'x': {
'c': ttorch.Tensor([[11], [23]], dtype=torch.float64),
'd': ttorch.Tensor([3, 9, 11.0], dtype=torch.float64),
}
})).all()
94 changes: 94 additions & 0 deletions test/numpy/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,97 @@ def test_array_equal(self):
'd': True,
}
})

def test_zeros(self):
zs = tnp.zeros((2, 3))
assert isinstance(zs, np.ndarray)
assert np.allclose(zs, np.zeros((2, 3)))

zs = tnp.zeros({'a': (2, 3), 'c': {'x': (3, 4)}})
assert tnp.allclose(zs, tnp.ndarray({
'a': np.zeros((2, 3)),
'c': {'x': np.zeros((3, 4))}
}))

def test_ones(self):
zs = tnp.ones((2, 3))
assert isinstance(zs, np.ndarray)
assert np.allclose(zs, np.ones((2, 3)))

zs = tnp.ones({'a': (2, 3), 'c': {'x': (3, 4)}})
assert tnp.allclose(zs, tnp.ndarray({
'a': np.ones((2, 3)),
'c': {'x': np.zeros((3, 4))}
}))

def test_stack(self):
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])
nd = tnp.stack((a, b))
assert isinstance(nd, np.ndarray)
assert np.allclose(nd, np.array([[1, 2, 3],
[2, 3, 4]]))

a = tnp.array({
'a': [1, 2, 3],
'c': {'x': [11, 22, 33]},
})
b = tnp.array({
'a': [2, 3, 4],
'c': {'x': [22, 33, 44]},
})
nd = tnp.stack((a, b))
assert tnp.allclose(nd, tnp.array({
'a': [[1, 2, 3], [2, 3, 4]],
'c': {'x': [[11, 22, 33], [22, 33, 44]]},
}))

def test_concatenate(self):
a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6]])
nd = tnp.concatenate((a, b), axis=0)
assert isinstance(nd, np.ndarray)
assert np.allclose(nd, np.array([[1, 2],
[3, 4],
[5, 6]]))

a = tnp.array({
'a': [[1, 2], [3, 4]],
'c': {'x': [[11, 22], [33, 44]]},
})
b = tnp.array({
'a': [[5, 6]],
'c': {'x': [[55, 66]]},
})
nd = tnp.concatenate((a, b), axis=0)
assert tnp.allclose(nd, tnp.array({
'a': [[1, 2], [3, 4], [5, 6]],
'c': {'x': [[11, 22], [33, 44], [55, 66]]},
}))

def test_split(self):
x = np.arange(9.0)
ns = tnp.split(x, 3)
assert len(ns) == 3
assert isinstance(ns[0], np.ndarray)
assert np.allclose(ns[0], np.array([0.0, 1.0, 2.0]))
assert isinstance(ns[1], np.ndarray)
assert np.allclose(ns[1], np.array([3.0, 4.0, 5.0]))
assert isinstance(ns[2], np.ndarray)
assert np.allclose(ns[2], np.array([6.0, 7.0, 8.0]))

xx = tnp.arange(tnp.ndarray({'a': 9.0, 'c': {'x': 18.0}}))
ns = tnp.split(xx, 3)
assert len(ns) == 3
assert tnp.allclose(ns[0], tnp.array({
'a': [0.0, 1.0, 2.0],
'c': {'x': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]},
}))
assert tnp.allclose(ns[1], tnp.array({
'a': [3.0, 4.0, 5.0],
'c': {'x': [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]},
}))
assert tnp.allclose(ns[2], tnp.array({
'a': [6.0, 7.0, 8.0],
'c': {'x': [12.0, 13.0, 14.0, 15.0, 16.0, 17.0]},
}))
10 changes: 6 additions & 4 deletions test/torch/tensor/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import treetensor.torch as ttorch
from .base import choose_mark

bool_init_dtype = torch.tensor([True, False]).dtype


# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorReduction:
Expand All @@ -14,15 +16,15 @@ def test_all(self):
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1.dtype == bool_init_dtype
assert t1

t2 = ttorch.Tensor({
'a': [True, False],
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert t2.dtype == bool_init_dtype
assert not t2

t3 = ttorch.tensor({
Expand All @@ -48,15 +50,15 @@ def test_any(self):
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1.dtype == bool_init_dtype
assert t1

t2 = ttorch.Tensor({
'a': [False, False],
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert t2.dtype == bool_init_dtype
assert not t2

t3 = ttorch.Tensor({
Expand Down
2 changes: 1 addition & 1 deletion treetensor/common/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _load_func(name):
@doc_from_base()
@return_self_dec
@post_process(auto_tree_cls)
@func_treelize(return_type=TreeValue, rise=True)
@func_treelize(return_type=TreeValue, subside=True, rise=True)
@wraps(func, assigned=('__name__',), updated=())
def _new_func(*args, **kwargs):
return func(*args, **kwargs)
Expand Down
5 changes: 2 additions & 3 deletions treetensor/common/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from functools import wraps
from operator import itemgetter

from treevalue import TreeValue, walk
from treevalue import TreeValue, flatten_values

__all__ = [
'ireduce',
Expand All @@ -17,7 +16,7 @@ def _decorator(func):
def _new_func(*args, **kwargs):
result = func(*args, **kwargs)
if isinstance(result, TreeValue):
it = map(itemgetter(1), walk(result, include_nodes=False))
it = flatten_values(result)
return rfunc(piter(it))
else:
return result
Expand Down
2 changes: 1 addition & 1 deletion treetensor/config/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__TITLE__ = "DI-treetensor"

#: Version of this project.
__VERSION__ = "0.2.0"
__VERSION__ = "0.2.1"

#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A flexible, generalized tree-based tensor structure.'
Expand Down
16 changes: 16 additions & 0 deletions treetensor/numpy/array.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from functools import lru_cache

import numpy
import torch
from treevalue import method_treelize

from .base import TreeNumpy
Expand All @@ -12,6 +15,12 @@
_ArrayProxy, _InstanceArrayProxy = get_tree_proxy(numpy.ndarray)


@lru_cache()
def _get_tensor_class(args0):
from ..torch import Tensor
return Tensor(args0)


class _BaseArrayMeta(clsmeta(numpy.asarray, allow_dict=True)):
pass

Expand Down Expand Up @@ -92,6 +101,13 @@ def all(self: numpy.ndarray, *args, **kwargs):
def any(self: numpy.ndarray, *args, **kwargs):
return self.any(*args, **kwargs)

@method_treelize(return_type=_get_tensor_class)
def tensor(self: numpy.ndarray, *args, **kwargs):
tensor_: torch.Tensor = torch.from_numpy(self)
if args or kwargs:
tensor_ = tensor_.to(*args, **kwargs)
return tensor_

@method_treelize()
def __eq__(self, other):
"""
Expand Down
32 changes: 32 additions & 0 deletions treetensor/numpy/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
__all__ = [
'all', 'any', 'array',
'equal', 'array_equal',
'stack', 'concatenate', 'split',
'zeros', 'ones',
]

func_treelize = post_process(post_process(args_mapping(
Expand Down Expand Up @@ -71,3 +73,33 @@ def array(p_object, *args, **kwargs):
})
"""
return np.array(p_object, *args, **kwargs)


@doc_from(np.stack)
@func_treelize(subside=True)
def stack(arrays, *args, **kwargs):
return np.stack(arrays, *args, **kwargs)


@doc_from(np.concatenate)
@func_treelize(subside=True)
def concatenate(arrays, *args, **kwargs):
return np.concatenate(arrays, *args, **kwargs)


@doc_from(np.split)
@func_treelize(rise=True)
def split(ary, *args, **kwargs):
return np.split(ary, *args, **kwargs)


@doc_from(np.zeros)
@func_treelize()
def zeros(shape, *args, **kwargs):
return np.zeros(shape, *args, **kwargs)


@doc_from(np.ones)
@func_treelize()
def ones(shape, *args, **kwargs):
return np.ones(shape, *args, **kwargs)

0 comments on commit 9cf4605

Please sign in to comment.