Skip to content

Commit 073304d

Browse files
judajuda
andauthored
[TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232)
* first commit * rename * cmake * deprecated * newline * config * config * typo * skip tvm_class * rename * delete ptr * delete ptr * save progress * boolean support * cmake file * polish code * compile config * improving the codes * format * doc&errormsg * zero-cost copy * one step * to ndarray * extra output * delete extra codes * update test * boolean support * strong test * decrease memory copy * polish * reformat * polish * remove redundant import Co-authored-by: juda <[email protected]>
1 parent d2f9f25 commit 073304d

File tree

10 files changed

+844
-279
lines changed

10 files changed

+844
-279
lines changed

apps/pt_tvmdsoop/tests/test_as_torch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
# specific language governing permissions and limitations
1818
# under the License.
1919
"""Test script for tvm torch module"""
20+
import tempfile
21+
2022
import numpy as np
2123

2224
import torch
@@ -190,7 +192,10 @@ def test_tvmscript_torch_gpu():
190192
q1 = torch.arange(8, device=cuda0).type(torch.float32)
191193
q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)
192194

193-
ModuleGPU(q1, q2)
195+
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
196+
torch.save(ModuleGPU, tmp.name)
197+
loaded_mod = torch.load(tmp.name)
198+
loaded_mod(q1, q2)
194199

195200
tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)
196201

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#!/usr/bin/env python
2+
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
"""Test script for boolean tensor support"""
20+
import tempfile
21+
22+
import torch
23+
24+
import tvm
25+
import tvm.testing
26+
from tvm.contrib.torch import as_torch, optimize_torch
27+
from tvm.script import tir as T
28+
29+
30+
def negate(x):
31+
return x.logical_not()
32+
33+
34+
def sum_up_tensor(x):
35+
return x.size(dim=0) - torch.sum(x.int())
36+
37+
38+
def tensor_boolean_operation(x):
39+
arr1 = (x + 0.3).floor().bool()
40+
arr2 = (~((x + 0.7).int().bool())).bool()
41+
ret = ((arr1 & arr2).byte() + 0.5).half()
42+
return ~(ret.bool())
43+
44+
45+
def test_bool_tensor_negate():
46+
input = torch.ones(1, dtype=torch.bool)
47+
optimized_negate = optimize_torch(
48+
negate,
49+
input,
50+
)
51+
with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
52+
torch.save(optimized_negate, tmp.name)
53+
loaded_mod = torch.load(tmp.name)
54+
output = loaded_mod(negate(input))
55+
tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5)
56+
57+
58+
def test_sum_up_tensor():
59+
x = torch.randint(0, 2, (16,))
60+
y = x.bool()
61+
optimized_func = optimize_torch(
62+
sum_up_tensor,
63+
(y,),
64+
)
65+
ret1 = (x[x == 0]).size(dim=0)
66+
ret2 = optimized_func(y).numpy()
67+
tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
68+
69+
70+
def test_tensor_boolean_operation():
71+
input = torch.rand(200)
72+
model = optimize_torch(
73+
tensor_boolean_operation,
74+
input,
75+
)
76+
ret1 = tensor_boolean_operation(input)
77+
ret2 = model(input)
78+
tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
79+
80+
81+
@as_torch
82+
@T.prim_func
83+
def negate_tvmscript(
84+
X: T.Buffer[(8, 8), "bool"],
85+
Y: T.Buffer[(8, 8), "float32"],
86+
Z: T.Buffer[(8, 8), "bool"],
87+
U: T.Buffer[(8, 8), "float32"],
88+
) -> None:
89+
for i, j in T.grid(8, 8):
90+
with T.block():
91+
if Y[i, j] > 0.0:
92+
Z[i, j] = X[i, j]
93+
U[i, j] = Y[i, j]
94+
else:
95+
Z[i, j] = not X[i, j]
96+
U[i, j] = 0.0 - Y[i, j]
97+
98+
99+
def negate_vanila(x, y):
100+
z = torch.zeros(8, 8).bool()
101+
for i in range(8):
102+
for j in range(8):
103+
if y[i, j] > 0:
104+
z[i, j] = x[i, j]
105+
else:
106+
z[i, j] = ~x[i, j]
107+
return z
108+
109+
110+
def test_tvmscript_torch_decorator():
111+
q1 = (torch.rand(8, 8) + 0.5).int().bool()
112+
q2 = torch.rand(8, 8) - 0.5
113+
q3 = torch.zeros(8, 8).bool()
114+
q4 = torch.zeros(8, 8)
115+
116+
std1 = negate_vanila(q1, q2)
117+
std2 = torch.abs(q2)
118+
119+
negate_tvmscript(q1, q2, q3, q4)
120+
121+
tvm.testing.assert_allclose(std1.numpy(), q3.numpy(), atol=1e-5, rtol=1e-5)
122+
tvm.testing.assert_allclose(std2.numpy(), q4.numpy(), atol=1e-5, rtol=1e-5)
123+
124+
125+
if __name__ == "__main__":
126+
test_tvmscript_torch_decorator()
127+
test_bool_tensor_negate()
128+
test_sum_up_tensor()
129+
test_tensor_boolean_operation()

cmake/modules/contrib/PT_TVMDSOOP.cmake

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# "License"); you may not use this file except in compliance
77
# with the License. You may obtain a copy of the License at
88
#
9-
# http://www.apache.org/licenses/LICENSE-2.0
9+
# http://www.apache.org/licenses/LICENSE-2.0
1010
#
1111
# Unless required by applicable law or agreed to in writing,
1212
# software distributed under the License is distributed on an
@@ -17,42 +17,80 @@
1717

1818
if(NOT USE_PT_TVMDSOOP STREQUAL "OFF")
1919
find_package(PythonInterp REQUIRED)
20-
2120
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())"
2221
OUTPUT_VARIABLE PT_PATH
2322
RESULT_VARIABLE PT_STATUS)
24-
if (NOT ${PT_STATUS} EQUAL 0)
23+
24+
if(NOT ${PT_STATUS} EQUAL 0)
2525
message(FATAL_ERROR "Fail to get pytorch path")
2626
endif()
2727

2828
string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}")
2929
message(STATUS "PyTorch path: ${PT_PATH}")
3030

31-
set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0")
31+
execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch;print(torch.compiled_with_cxx11_abi())"
32+
OUTPUT_VARIABLE PT_CXX_FLAG
33+
RESULT_VARIABLE PT_STATUS)
34+
35+
string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}")
36+
message(STATUS "Found TORCH_BUILT_WITH_CXX_ABI=${PT_CXX_FLAG} ")
37+
38+
if(${PT_CXX_FLAG} STREQUAL "False")
39+
set(CXX_ABI_ENABLED 0)
40+
else()
41+
set(CXX_ABI_ENABLED 1)
42+
endif()
43+
44+
set_property(
45+
SOURCE
46+
${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
47+
APPEND PROPERTY
48+
COMPILE_OPTIONS
49+
"-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}"
50+
"-I${PT_PATH}/include"
51+
)
52+
53+
set_property(
54+
SOURCE
55+
${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc
56+
APPEND PROPERTY
57+
COMPILE_OPTIONS
58+
"-I${PT_PATH}/include"
59+
)
60+
3261
set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so")
3362

3463
if(NOT USE_CUDA STREQUAL "OFF")
3564
add_definitions(-DPT_TVMDSOOP_ENABLE_GPU)
3665
endif()
3766

38-
3967
string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}")
40-
separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR})
68+
separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND)
4169
separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR})
4270

71+
# This old version is depereated and will be removed after tvm 0.11
72+
set(LIBRARY_OLD_NAME pt_tvmdsoop)
4373

44-
set(LIBRARY_NAME pt_tvmdsoop)
45-
tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc)
46-
add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS})
74+
# This new library is set for pytorch integration, which solves the c++ abi imcompability issue
75+
set(LIBRARY_NEW_NAME pt_tvmdsoop_new)
76+
tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc)
77+
78+
tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc)
79+
80+
add_library(${LIBRARY_OLD_NAME} SHARED ${PTTVM_SRCS})
81+
add_library(${LIBRARY_NEW_NAME} SHARED ${PTTVM_TORCH})
4782
set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
4883

49-
if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
50-
add_dependencies(${LIBRARY_NAME} tvm)
84+
if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
85+
add_dependencies(${LIBRARY_OLD_NAME} tvm)
86+
add_dependencies(${LIBRARY_NEW_NAME} tvm)
5187
endif()
5288

53-
target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
54-
target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
55-
target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
89+
target_compile_options(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
90+
target_link_libraries(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
91+
target_compile_definitions(${LIBRARY_OLD_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
5692

93+
target_compile_options(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
94+
target_link_libraries(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
95+
target_compile_definitions(${LIBRARY_NEW_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
5796
endif()
58-

python/tvm/contrib/torch/__init__.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818
"""Module container of Pytorch custom class"""
1919
import os
2020
import platform
21+
import warnings
2122
import torch
2223
from tvm._ffi import libinfo
2324

2425

25-
def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
26+
def _load_platform_specific_library(lib_name):
2627
system = platform.system()
2728
if system == "Darwin":
2829
lib_file_name = lib_name + ".dylib"
@@ -33,11 +34,27 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
3334
lib_path = libinfo.find_lib_path()[0]
3435
lib_dir = os.path.dirname(lib_path)
3536
lib_file_path = os.path.join(lib_dir, lib_file_name)
36-
torch.classes.load_library(lib_file_path)
37+
try:
38+
torch.classes.load_library(lib_file_path)
39+
except OSError as err:
40+
errmsg = str(err)
41+
if errmsg.find("undefined symbol") != -1:
42+
reason = " ".join(
43+
(
44+
"Got undefined symbol error,",
45+
"which might be due to the CXXABI incompatibility.",
46+
)
47+
)
48+
else:
49+
reason = errmsg
50+
warnings.warn(
51+
f"The library {lib_name} is not built successfully. {reason}",
52+
RuntimeWarning,
53+
)
3754

3855

39-
_load_platform_specific_library()
40-
56+
_load_platform_specific_library("libpt_tvmdsoop")
57+
_load_platform_specific_library("libpt_tvmdsoop_new")
4158

4259
from . import module
4360

python/tvm/contrib/torch/module.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""Module container of PyTorch custom class"""
19+
import warnings
1920
from typing import List
21+
2022
import torch
2123

2224

@@ -29,6 +31,11 @@ def shape_repr(cls, input_shapes):
2931
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
3032

3133
def __init__(self, num_inputs, num_outputs, device=None):
34+
warnings.warn(
35+
"This module will be removed at TVM version 0.11",
36+
DeprecationWarning,
37+
stacklevel=2,
38+
)
3239
super().__init__()
3340
self.dummy_param = torch.nn.Parameter(torch.empty(0))
3441
self.engine = None
@@ -67,6 +74,11 @@ def shape_repr(cls, input_shapes):
6774
return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
6875

6976
def __init__(self, num_inputs, num_outputs, device=None):
77+
warnings.warn(
78+
"This module will be removed at TVM version 0.11",
79+
DeprecationWarning,
80+
stacklevel=2,
81+
)
7082
super().__init__()
7183
self.dummy_param = torch.nn.Parameter(torch.empty(0))
7284
self.engine = None
@@ -113,6 +125,11 @@ class TraceTvmModule(torch.nn.Module):
113125
"""
114126

115127
def __init__(self, tvm_module):
128+
warnings.warn(
129+
"This module will be removed at TVM version 0.11",
130+
DeprecationWarning,
131+
stacklevel=2,
132+
)
116133
super().__init__()
117134
self.tvm_module = tvm_module
118135

python/tvm/contrib/torch/pytorch_tvm.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# pylint: disable=redefined-builtin
2020
"""`compile` api that convert torch module to torch tvm module"""
2121
import os
22+
import warnings
2223
import tvm
2324
import tvm.testing
2425
from tvm import relay, autotvm
@@ -183,6 +184,16 @@ def load_tvm(self, export_dir):
183184

184185
def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None):
185186
"""Build pytorch module containing TVM Graph Module"""
187+
warnings.warn(
188+
" ".join(
189+
(
190+
"This function will be removed at TVM version 0.11,",
191+
"we suggest users to use `optimized_torch` for tuning Torch modules instead.",
192+
)
193+
),
194+
DeprecationWarning,
195+
stacklevel=2,
196+
)
186197
assert self.export_dir, "you must build_tvm or load_tvm before"
187198
input_infos = input_infos or self.input_infos
188199
assert input_infos
@@ -224,6 +235,16 @@ def compile(script_module, option):
224235
pytorch_tvm_module = compile(script_module, option)
225236
pytorch_tvm_module("model_tvm.pt")
226237
"""
238+
warnings.warn(
239+
" ".join(
240+
(
241+
"This function will be removed at TVM version 0.11,",
242+
"we suggest users to use `optimized_torch` for tuning Torch modules instead.",
243+
)
244+
),
245+
DeprecationWarning,
246+
stacklevel=2,
247+
)
227248
input_infos = option["input_infos"]
228249
default_dtype = option.get("default_dtype", "float32")
229250
export_dir = option.get("export_dir", "pytorch_compiled")

0 commit comments

Comments
 (0)