Skip to content

Commit 902343a

Browse files
[BugFix][UMA] Fix order issue in uma_lower (#12447)
There was a flaw in uma_lower (see issue #12410) that led, in some cases, to the argument ordering of the cached_func differing from that of the Relay function. This resulted in an incorrect lowering of the primfunc and, eventually, in a wrong result or a run-time error. This commit adds code to correct the described misbehavior, plus a unit test case that checks this functionality end-to-end with a TFLite model.
1 parent e5e05fe commit 902343a

File tree

5 files changed

+107
-30
lines changed

5 files changed

+107
-30
lines changed

apps/uma/_template/patterns.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,8 @@ def conv2d_pattern():
2323
pattern = is_op("nn.conv2d")(wildcard(), wildcard())
2424
pattern = pattern.has_attr({"strides": [1, 1], "groups": 1})
2525
return pattern
26+
27+
28+
def dense_pattern():
29+
pattern = is_op("nn.dense")(wildcard(), wildcard())
30+
return pattern

python/tvm/relay/backend/contrib/uma/api/lower.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,27 +60,7 @@ def _lower_relay_to_tir(self, relay_prim_func: relay.Function) -> tvm.tir.PrimFu
6060
"""
6161

6262
def _get_tensors(te_cached_func):
63-
outputs = list(te_cached_func.outputs)
64-
stack = []
65-
visited = set()
66-
for output_ in outputs:
67-
if output_ not in visited:
68-
visited.add(output_)
69-
stack.append(output_)
70-
71-
args = []
72-
while len(stack) != 0:
73-
tensor = stack.pop()
74-
if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp):
75-
args.append(tensor)
76-
elif isinstance(tensor.op, tvm.te.tensor.ComputeOp):
77-
inputs = tensor.op.input_tensors
78-
for input_ in inputs:
79-
if input_ not in visited:
80-
visited.add(input_)
81-
stack.append(input_)
82-
83-
return args + outputs
63+
return list(te_cached_func.inputs) + list(te_cached_func.outputs)
8464

8565
lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE")
8666
te_cached_func = lower_to_te(relay_prim_func)

python/tvm/testing/aot.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,7 @@ def _create_header_file(tensor_name, npy_data, output_path, data_linkage):
571571
header_file.write("};\n\n")
572572

573573

574-
def convert_to_relay(
575-
tflite_model_buf,
576-
):
574+
def convert_to_relay(tflite_model_buf, bind_params_by_name=True):
577575
"""Convert a tflite model buffer in a Relay module"""
578576
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
579577
try:
@@ -588,7 +586,8 @@ def convert_to_relay(
588586
raise ImportError("The tflite package must be installed")
589587

590588
mod, params = relay.frontend.from_tflite(tflite_model)
591-
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
589+
if bind_params_by_name:
590+
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
592591
return mod, params
593592

594593

@@ -931,20 +930,30 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
931930
return dict(zip(output_tensor_names, out))
932931

933932

934-
def create_relay_module_and_inputs_from_tflite_file(tflite_model_file):
933+
def create_relay_module_and_inputs_from_tflite_file(tflite_model_file, bind_params_by_name=True):
935934
"""A helper function to create a Relay IRModule with inputs
936935
and params from a tflite file"""
937936
with open(tflite_model_file, "rb") as f:
938937
tflite_model_buf = f.read()
939-
mod, params = convert_to_relay(tflite_model_buf)
938+
mod, params = convert_to_relay(tflite_model_buf, bind_params_by_name)
940939

941940
inputs = dict()
942941
for param in mod["main"].params:
943942
name = str(param.name_hint)
944943
data_shape = [int(i) for i in param.type_annotation.shape]
945944
dtype = str(param.type_annotation.dtype)
946-
in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max)
947-
data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype)
945+
if np.issubdtype(dtype, np.floating):
946+
# Since np.random.uniform only allows the ranges of float32,
947+
# at first float16 is used and scaled afterwards, if necessary.
948+
in_min, in_max = (np.finfo("float16").min, np.finfo("float16").max)
949+
data = np.random.uniform(low=in_min, high=in_max, size=data_shape).astype(dtype)
950+
scale = np.finfo(dtype).min / np.finfo("float16").min
951+
data *= scale
952+
elif np.issubdtype(dtype, np.integer):
953+
in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max)
954+
data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype)
955+
else:
956+
raise TypeError(f"Type {dtype} not supported")
948957
inputs[name] = data
949958

950959
return mod, inputs, params

tests/python/contrib/test_uma/test_uma_pipeline.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,20 @@
1616
# under the License.
1717

1818
import pytest
19+
20+
pytest.importorskip("tflite")
21+
pytest.importorskip("tensorflow")
22+
23+
import os
24+
import tensorflow as tf
1925
from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER
2026
from tvm.relay import transform, testing
2127
from tvm.testing.aot import (
2228
AOTTestModel,
2329
AOTTestRunner,
2430
generate_ref_data,
2531
compile_and_run,
32+
create_relay_module_and_inputs_from_tflite_file,
2633
)
2734

2835
import tvm
@@ -132,5 +139,80 @@ def test_mobilenet():
132139
)
133140

134141

142+
def test_tflite_model():
143+
"""
144+
End-to-end test of TF-Lite file using UMA
145+
"""
146+
tflite_file = "/tmp/model.tflite"
147+
if os.path.exists(tflite_file):
148+
os.remove(tflite_file)
149+
generate_tflite_file(tflite_file)
150+
151+
pytest.importorskip("tflite")
152+
153+
interpreter = tf.lite.Interpreter(model_path=tflite_file)
154+
tf_model_details = interpreter.get_input_details()
155+
mod, _, params = create_relay_module_and_inputs_from_tflite_file(
156+
tflite_file, bind_params_by_name=False
157+
)
158+
159+
uma_backend = VanillaAcceleratorBackend()
160+
uma_backend.register()
161+
target = tvm.target.Target("vanilla_accelerator", host=tvm.target.Target("c"))
162+
target_c = tvm.target.Target("c")
163+
164+
# Generation of test input and output
165+
data_shape = [int(x) for x in mod["main"].params[0].type_annotation.shape]
166+
data = np.random.uniform(size=data_shape).astype("float32")
167+
input_list = {str(tf_model_details[0]["name"]): data}
168+
output_list = generate_ref_data(mod, input_list, params)
169+
170+
# UMA partitioning (needs to be done after generate_ref_data)
171+
mod = uma_backend.partition(mod)
172+
173+
aot_test_model = AOTTestModel(module=mod, inputs=input_list, outputs=output_list, params=params)
174+
test_runner = AOTTestRunner(
175+
pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": "greedy_by_size"}
176+
)
177+
178+
compile_and_run(
179+
aot_test_model,
180+
test_runner,
181+
interface_api="c",
182+
use_unpacked_api=True,
183+
workspace_byte_alignment=1,
184+
debug_calculated_workspaces=False,
185+
target=[target_c, target],
186+
)
187+
188+
189+
def generate_tflite_file(tflite_filename):
190+
mnist = tf.keras.datasets.mnist
191+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
192+
x_train, x_test = x_train / 255.0, x_test / 255.0
193+
x_train, x_test = x_train.reshape(-1, 28, 28, 1), x_test.reshape(-1, 28, 28, 1)
194+
tf_model = tf.keras.models.Sequential(
195+
[
196+
tf.keras.Input(shape=(28, 28, 1)),
197+
tf.keras.layers.Conv2D(4, (3, 3), padding="same", activation="relu"),
198+
tf.keras.layers.Flatten(input_shape=(28, 28)),
199+
tf.keras.layers.Dense(32, activation="relu"),
200+
tf.keras.layers.Dropout(0.2),
201+
tf.keras.layers.Dense(10),
202+
]
203+
)
204+
output = tf_model(x_train[:1])
205+
output = output.numpy()
206+
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
207+
loss(y_train[:1], output).numpy()
208+
tf_model.compile(metrics=["accuracy"], optimizer="adam", loss=loss)
209+
tf_model.fit(x_train, y_train, epochs=1)
210+
211+
tflite_converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
212+
tflite_model = tflite_converter.convert()
213+
with open(tflite_filename, "wb") as f:
214+
f.write(tflite_model)
215+
216+
135217
if __name__ == "__main__":
136218
tvm.testing.main()

tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from apps.uma._template.codegen import gen_includes
2626

27-
from apps.uma._template.patterns import conv2d_pattern
27+
from apps.uma._template.patterns import conv2d_pattern, dense_pattern
2828
from tvm.relay.backend.contrib.uma import uma_available
2929

3030
pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available")
@@ -40,6 +40,7 @@ def __init__(self):
4040
# Relay to Relay function registration
4141
#######################################################################
4242
self._register_pattern("conv2d", conv2d_pattern())
43+
self._register_pattern("dense", dense_pattern())
4344

4445
#######################################################################
4546
# Relay to TIR function registration

0 commit comments

Comments
 (0)