Skip to content

masahi/torchscript-to-tvm

Repository files navigation

This repository contains examples and work-in-progress (WIP) test cases for converting PyTorch models to TVM.

Below is an example of the full translation pipeline: PyTorch Python module -> TorchScript -> TVM Relay.

See dynamic_test.py and rnn_test.py for more examples.

PyTorch module

class LoopWithIf(torch.nn.Module):
    """Example module: a loop whose body contains a data-dependent if/else.

    In the TorchScript JIT IR for this module, the ``for`` loop is emitted as
    ``prim::Loop`` and the branch as ``prim::If``.
    """

    def forward(self, inp):
        """Repeatedly add or subtract ``3 * a`` to/from ``a`` in place.

        Runs ``inp.size(0)`` iterations. Each iteration computes
        ``b = a + a * 2`` and then updates ``a`` in place, adding ``b`` if
        ``b.sum() > 0.0`` and subtracting it otherwise.

        Returns the final ``a``. ``a`` is the same object as ``inp``, so the
        caller's tensor is modified as well.
        """
        # No copy is made: `a` aliases `inp`, so the in-place ops below
        # (aten::add_ / aten::sub_ in the JIT IR) write into the input tensor.
        a = inp
        # `i` is unused; only the trip count (size of dim 0) matters.
        for i in range(inp.size(0)):
            b = a * 2
            b = a + b
            # Data-dependent branch: the condition depends on tensor values,
            # so it is kept as control flow (prim::If) rather than unrolled.
            if b.sum() > 0.0:
                a += b  # in-place add (aten::add_)
            else:
                a -= b  # in-place subtract (aten::sub_)
        return a

PyTorch JIT IR

graph(%self : __torch__.LoopWithIf,
      %inp.1 : Tensor):
  %2 : None = prim::Constant()
  %3 : int = prim::Constant[value=1]()
  %4 : bool = prim::Constant[value=1]() # dynamic_test.py:64:8
  %5 : int = prim::Constant[value=0]() # dynamic_test.py:64:32
  %6 : int = prim::Constant[value=2]() # dynamic_test.py:65:20
  %7 : float = prim::Constant[value=0]() # dynamic_test.py:67:25
  %8 : int = aten::size(%inp.1, %5) # dynamic_test.py:64:23
  %a : Tensor = prim::Loop(%8, %4, %inp.1) # dynamic_test.py:64:8
    block0(%i : int, %a.15 : Tensor):
      %b.1 : Tensor = aten::mul(%a.15, %6) # dynamic_test.py:65:16
      %b.3 : Tensor = aten::add(%a.15, %b.1, %3) # dynamic_test.py:66:16
      %14 : Tensor = aten::sum(%b.3, %2) # dynamic_test.py:67:15
      %15 : Tensor = aten::gt(%14, %7) # dynamic_test.py:67:15
      %16 : bool = aten::Bool(%15) # dynamic_test.py:67:15
      %a.14 : Tensor = prim::If(%16) # dynamic_test.py:67:12
        block0():
          %a.4 : Tensor = aten::add_(%a.15, %b.3, %3) # dynamic_test.py:68:16
          -> (%a.4)
        block1():
          %a.7 : Tensor = aten::sub_(%a.15, %b.3, %3) # dynamic_test.py:70:16
          -> (%a.7)
      -> (%4, %a.14)
  return (%a)

TVM Relay IR

v0.0.4
def @main(%X: Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] {
  %9 = (
    let %while_loop: fn (int32, Tensor[(10, 20), float32]) -> (int32, Tensor[(10, 20), float32]) = fn (%i: int32, %a.15: Tensor[(10, 20), float32]) -> (int32, Tensor[(10, 20), float32]) {
      %0 = greater_equal(%i, 1 /* ty=int32 */) /* ty=bool */;
      %1 = less_equal(%i, 10 /* ty=int32 */) /* ty=bool */;
      %2 = logical_and(%0, %1) /* ty=bool */;
      if (%2) {
        %3 = add(%i, 1 /* ty=int32 */) /* ty=int32 */;
        %4 = multiply(%a.15, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
        %5 = add(%a.15, %4) /* ty=Tensor[(10, 20), float32] */;
        %6 = sum(%5) /* ty=float32 */;
        %7 = greater(%6, 0f /* ty=float32 */) /* ty=bool */;
        %8 = if (%7) {
          add(%a.15, %5) /* ty=Tensor[(10, 20), float32] */
        } else {
          subtract(%a.15, %5) /* ty=Tensor[(10, 20), float32] */
        };
        %while_loop(%3, %8) /* ty=(int32, Tensor[(10, 20), float32]) */
      } else {
        (%i, %a.15)
      }
    };
    %while_loop
  );
  %10 = %9(1 /* ty=int32 */, %X) /* ty=(int32, Tensor[(10, 20), float32]) */;
  %10.1
}

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published