
Releases: opendilab/treevalue

v1.5.0

16 Oct 12:12

What's Changed

  • dev(hansbug): add register custom dicts by @HansBug in #88 (sketched below)
  • drop support for Python 3.7 and add support for Python 3.12
  • use Cython >= 3 to compile the Cython extension code
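
The new registration hook can be used roughly as below. This is a minimal sketch: the helper name register_dict_type and its exact arguments are assumed from the changelog entry, not confirmed API.

from collections import UserDict

from treevalue import FastTreeValue, register_dict_type  # helper name assumed


class MyConfig(UserDict):
    """A custom mapping class that should be expanded into tree nodes."""


# Assumed call: tell treevalue to treat MyConfig like a plain dict when building trees
# (the real helper may also need an items-iteration callback).
register_dict_type(MyConfig)

t = FastTreeValue(MyConfig({'a': 1, 'b': MyConfig({'c': 2})}))
print(t)  # 'b' should show up as a nested tree node rather than an opaque leaf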

Full Changelog: v1.4.12...v1.5.0

v1.4.12

14 Aug 05:18

What's Changed

Version v1.4.12 adds support for torch >= 2, including torch.compile, for faster inference and backpropagation. Here's an example:

from typing import Tuple, Mapping

import torch
from torch import nn

from treevalue import FastTreeValue


# A simple MLP
class MLP(nn.Module):
    def __init__(self, in_features: int, out_features: int, layers: Tuple[int, ...] = (1024,)):
        nn.Module.__init__(self)
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers
        ios = [self.in_features, *self.layers, self.out_features]
        self.mlp = nn.Sequential(
            *(
                nn.Linear(in_, out_, bias=True)
                for in_, out_ in zip(ios[:-1], ios[1:])
            )
        )

    def forward(self, x):
        return self.mlp(x)


# Multiple headed MLP
class MultiHeadMLP(nn.Module):
    def __init__(self, in_features: int, out_features: Mapping[str, int], layers: Tuple[int, ...] = (1024,)):
        nn.Module.__init__(self)
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers
        _networks = {
            o_name: MLP(in_features, o_feat, layers)
            for o_name, o_feat in self.out_features.items()
        }
        self.mlps = nn.ModuleDict(_networks)  # use nn.ModuleDict to register the child MLPs as submodules
        self._t_mlps = FastTreeValue(_networks)  # use FastTreeValue for batched inference across all heads

    def forward(self, x):
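        # Calling the FastTreeValue maps x over every child MLP and returns a tree of outputs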
        return self._t_mlps(x)


if __name__ == '__main__':
    net = MultiHeadMLP(
        20,
        {'a': 10, 'b': 20, 'c': 14, 'd': 3},
    )
    net = torch.compile(net)
    print(net)

    input_ = torch.randn(1, 10, 20)
    output = net(input_)
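    # .shape is applied to every leaf tensor, so this prints a tree of torch.Size values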
    print(output.shape)

The compiled version of the MultiHeadMLP above will have the following network structure:

OptimizedModule(
  (_orig_mod): MultiHeadMLP(
    (mlps): ModuleDict(
      (a): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=10, bias=True)
        )
      )
      (b): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=20, bias=True)
        )
      )
      (c): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=14, bias=True)
        )
      )
      (d): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=3, bias=True)
        )
      )
    )
  )
)

Passing a float32[1, 10, 20] tensor as input produces an output tree with the following shapes:

<FastTreeValue 0x7fe9b197e6a0>
├── 'a' --> torch.Size([1, 10, 10])
├── 'b' --> torch.Size([1, 10, 20])
├── 'c' --> torch.Size([1, 10, 14])
└── 'd' --> torch.Size([1, 10, 3])

Full Changelog: v1.4.11...v1.4.12

v1.4.11

26 May 08:07

What's Changed

  • dev(hansbug): test for torch high version by @HansBug in #86

Full Changelog: v1.4.10...v1.4.11

v1.4.10

16 Mar 08:10

What's Changed

  • dev(hansbug): fix bug of jax integration by @HansBug in #84
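
For context, the jax integration touched by this fix is used roughly as follows. This sketch assumes register_for_jax is exported from the top-level treevalue package; adjust the import if it lives in the integration submodule.

import jax
import jax.numpy as jnp

from treevalue import FastTreeValue, register_for_jax  # import location assumed

# Register FastTreeValue as a jax pytree node so jax transforms can traverse it
register_for_jax(FastTreeValue)

t = FastTreeValue({'a': jnp.ones((2, 3)), 'b': {'c': jnp.zeros((4,))}})
doubled = jax.tree_util.tree_map(lambda x: x * 2, t)  # works once the class is registered
print(doubled)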

Full Changelog: v1.4.9...v1.4.10

v1.4.9

06 Mar 18:28

What's Changed

  • fix(hansbug): fix bug of #82, add more unittests by @HansBug in #83
  • dev(hansbug): optimize graphviz visualization by @HansBug in #75
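
As a reminder of the visualization API this touches, here is a minimal sketch; the (tree, name) calling convention of graphics is assumed rather than taken from this release.

from treevalue import FastTreeValue, graphics  # graphics signature assumed

t1 = FastTreeValue({'a': 1, 'b': {'c': 2, 'd': 3}})
t2 = FastTreeValue({'a': 1, 'b': t1.b})  # shares the 'b' subtree with t1

# graphics is assumed to accept (tree, name) pairs and return a graphviz.Digraph
g = graphics((t1, 't1'), (t2, 't2'))
g.render('tree_graph', format='svg')  # rendering requires the graphviz binary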

Full Changelog: v1.4.7...v1.4.9

v1.4.7

27 Feb 12:17

What's Changed

  • dev(hansbug): add unpack by @HansBug in #80 (see the sketch below)
  • dev(hansbug): add support for torch integration by @HansBug in #79
  • dev(hansbug): add generic_flatten, generic_unflatten, generic_mapping and register_integrate_container for integration module by @HansBug in #81
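
A rough sketch of the unpack and torch-integration additions above; treating unpack as a FastTreeValue method and register_for_torch as a top-level export are assumptions based on these entries.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from treevalue import FastTreeValue, register_for_torch  # import location assumed

t = FastTreeValue({'a': torch.randn(2, 3), 'b': {'c': torch.randn(4)}})

# unpack: pull several top-level entries out of the tree in one call (assumed method form)
a, b = t.unpack('a', 'b')

# torch integration: register FastTreeValue as a torch pytree node (assumed helper name),
# so torch's pytree utilities can flatten and rebuild trees of tensors
register_for_torch(FastTreeValue)
flat, spec = tree_flatten(t)
t2 = tree_unflatten(flat, spec)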

Full Changelog: v1.4.6...v1.4.7

v1.4.6

26 Feb 15:06

What's Changed

  • dev(hansbug): fix support of rise and subside for namedtuple by @HansBug in #76 (sketched below)
  • dev(hansbug): add register support for treevalue by @HansBug in #78
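
Since the first fix above concerns subside and rise with namedtuples, here is a minimal sketch of that pairing (behavior as described in the treevalue docs, not specific to this release):

from collections import namedtuple

from treevalue import FastTreeValue, subside, rise

Point = namedtuple('Point', ['x', 'y'])

tx = FastTreeValue({'a': 1, 'b': {'c': 2}})
ty = FastTreeValue({'a': 10, 'b': {'c': 20}})

# subside pushes the outer namedtuple structure down into the leaves:
# the result is one tree whose leaves are Point(x=..., y=...) values
merged = subside(Point(x=tx, y=ty))
print(merged)

# rise does the inverse, recovering a Point whose fields are trees again
restored = rise(merged)
print(restored.x)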

Full Changelog: v1.4.5...v1.4.6

v1.4.5

12 Feb 17:14

What's Changed

  • fix(hansbug): fix bug of constraint in with_constrains by @HansBug in #74

Full Changelog: v1.4.4...v1.4.5

v1.4.4

29 Jan 03:24

What's Changed

  • dev(hansbug): add constraint access by @HansBug in #73

Full Changelog: v1.4.3...v1.4.4

v1.4.3

31 Dec 11:22

What's Changed

  • fix(feat): fix bug about TypeConstraint by @HansBug in #72

Full Changelog: v1.4.2...v1.4.3