Skip to content

Commit f13fed5

Browse files
authored
[Format] Convert all Python code w/o CI (#6448)
* Add black setup * Tweak pyproject.toml * Fix syntax issues * Fix * Tweak * Black all Python code
1 parent 01460e0 commit f13fed5

File tree

1,013 files changed

+60654
-43142
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,013 files changed

+60654
-43142
lines changed

apps/android_camera/models/prepare_model.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,52 +28,62 @@
2828
from tvm.contrib import util, ndk, graph_runtime as runtime
2929
from tvm.contrib.download import download_testdata, download
3030

31-
target = 'llvm -mtriple=arm64-linux-android'
31+
target = "llvm -mtriple=arm64-linux-android"
3232
target_host = None
3333

34+
3435
def del_dir(target: Union[Path, str], only_if_empty: bool = False):
3536
target = Path(target).expanduser()
3637
assert target.is_dir()
37-
for p in sorted(target.glob('**/*'), reverse=True):
38+
for p in sorted(target.glob("**/*"), reverse=True):
3839
if not p.exists():
3940
continue
4041
p.chmod(0o666)
4142
if p.is_dir():
4243
p.rmdir()
4344
else:
4445
if only_if_empty:
45-
raise RuntimeError(f'{p.parent} is not empty!')
46+
raise RuntimeError(f"{p.parent} is not empty!")
4647
p.unlink()
4748
target.rmdir()
4849

50+
4951
def get_model(model_name, batch_size=1):
50-
if model_name == 'resnet18_v1':
52+
if model_name == "resnet18_v1":
5153
import mxnet as mx
5254
from mxnet import gluon
5355
from mxnet.gluon.model_zoo import vision
56+
5457
gluon_model = vision.get_model(model_name, pretrained=True)
5558
img_size = 224
5659
data_shape = (batch_size, 3, img_size, img_size)
5760
net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
5861
return (net, params)
59-
elif model_name == 'mobilenet_v2':
62+
elif model_name == "mobilenet_v2":
6063
import keras
6164
from keras.applications.mobilenet_v2 import MobileNetV2
65+
6266
keras.backend.clear_session() # Destroys the current TF graph and creates a new one.
63-
weights_url = ''.join(['https://github.com/JonathanCMitchell/',
64-
'mobilenet_v2_keras/releases/download/v1.1/',
65-
'mobilenet_v2_weights_tf_dim_ordering_tf_kernels_0.5_224.h5'])
66-
weights_file = 'mobilenet_v2_weights.h5'
67-
weights_path = download_testdata(weights_url, weights_file, module='keras')
68-
keras_mobilenet_v2 = MobileNetV2(alpha=0.5, include_top=True, weights=None,
69-
input_shape=(224, 224, 3), classes=1000)
67+
weights_url = "".join(
68+
[
69+
"https://github.com/JonathanCMitchell/",
70+
"mobilenet_v2_keras/releases/download/v1.1/",
71+
"mobilenet_v2_weights_tf_dim_ordering_tf_kernels_0.5_224.h5",
72+
]
73+
)
74+
weights_file = "mobilenet_v2_weights.h5"
75+
weights_path = download_testdata(weights_url, weights_file, module="keras")
76+
keras_mobilenet_v2 = MobileNetV2(
77+
alpha=0.5, include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
78+
)
7079
keras_mobilenet_v2.load_weights(weights_path)
71-
80+
7281
img_size = 224
7382
data_shape = (batch_size, 3, img_size, img_size)
74-
mod, params = relay.frontend.from_keras(keras_mobilenet_v2, {'input_1': data_shape})
83+
mod, params = relay.frontend.from_keras(keras_mobilenet_v2, {"input_1": data_shape})
7584
return (mod, params)
7685

86+
7787
def main(model_str, output_path):
7888
if output_path.exists():
7989
del_dir(output_path)
@@ -90,34 +100,40 @@ def main(model_str, output_path):
90100
with tvm.transform.PassContext(opt_level=3):
91101
graph, lib, params = relay.build(net, target, target_host=target_host, params=params)
92102
print("dumping lib...")
93-
lib.export_library(output_path_str + '/' + 'deploy_lib_cpu.so', ndk.create_shared)
103+
lib.export_library(output_path_str + "/" + "deploy_lib_cpu.so", ndk.create_shared)
94104
print("dumping graph...")
95-
with open(output_path_str + '/' + 'deploy_graph.json', 'w') as f:
105+
with open(output_path_str + "/" + "deploy_graph.json", "w") as f:
96106
f.write(graph)
97107
print("dumping params...")
98-
with open(output_path_str + '/' + 'deploy_param.params', 'wb') as f:
108+
with open(output_path_str + "/" + "deploy_param.params", "wb") as f:
99109
f.write(relay.save_param_dict(params))
100110
print("dumping labels...")
101-
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
102-
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
103-
'596b27d23537e5a1b5751d2b0481ef172f58b539/',
104-
'imagenet1000_clsid_to_human.txt'])
105-
synset_path = output_path_str + '/image_net_labels'
106-
download(synset_url, output_path_str + '/image_net_labels')
111+
synset_url = "".join(
112+
[
113+
"https://gist.githubusercontent.com/zhreshold/",
114+
"4d0b62f3d01426887599d4f7ede23ee5/raw/",
115+
"596b27d23537e5a1b5751d2b0481ef172f58b539/",
116+
"imagenet1000_clsid_to_human.txt",
117+
]
118+
)
119+
synset_path = output_path_str + "/image_net_labels"
120+
download(synset_url, output_path_str + "/image_net_labels")
107121
with open(synset_path) as fi:
108122
synset = eval(fi.read())
109-
with open(output_path_str + '/image_net_labels.json', "w") as fo:
123+
with open(output_path_str + "/image_net_labels.json", "w") as fo:
110124
json.dump(synset, fo, indent=4)
111125
os.remove(synset_path)
112126

113-
if __name__ == '__main__':
114-
if environ.get('TVM_NDK_CC') is None:
127+
128+
if __name__ == "__main__":
129+
if environ.get("TVM_NDK_CC") is None:
115130
raise RuntimeError("Require environment variable TVM_NDK_CC")
116-
models_path = Path().absolute().parent.joinpath('app/src/main/assets/models/')
131+
models_path = Path().absolute().parent.joinpath("app/src/main/assets/models/")
117132
if not models_path.exists():
118133
models_path.mkdir()
119-
models = {'mobilenet_v2': models_path.joinpath('mobilenet_v2'),
120-
'resnet18_v1': models_path.joinpath('resnet18_v1')
121-
}
134+
models = {
135+
"mobilenet_v2": models_path.joinpath("mobilenet_v2"),
136+
"resnet18_v1": models_path.joinpath("resnet18_v1"),
137+
}
122138
for model, output_path in models.items():
123139
main(model, output_path)

apps/android_rpc/tests/android_rpc_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,18 @@
4343
# whether enable to execute test on Vulkan target
4444
test_vulkan = False
4545

46+
4647
def test_rpc_module():
4748
# graph
4849
n = tvm.runtime.convert(1024)
49-
A = te.placeholder((n,), name='A')
50-
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
50+
A = te.placeholder((n,), name="A")
51+
B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
5152
a_np = np.random.uniform(size=1024).astype(A.dtype)
5253
temp = util.tempdir()
5354

5455
# Establish remote connection with target hardware
5556
tracker = rpc.connect_tracker(tracker_host, tracker_port)
56-
remote = tracker.request(key, priority=0,
57-
session_timeout=60)
57+
remote = tracker.request(key, priority=0, session_timeout=60)
5858

5959
# Compile the Graph for CPU target
6060
s = te.create_schedule(B.op)
@@ -67,15 +67,15 @@ def test_rpc_module():
6767
f.export_library(path_dso_cpu, ndk.create_shared)
6868

6969
# Execute the portable graph on cpu target
70-
print('Run CPU test ...')
70+
print("Run CPU test ...")
7171
ctx = remote.cpu(0)
7272
remote.upload(path_dso_cpu)
7373
f2 = remote.load_module("cpu_lib.so")
7474
a = tvm.nd.array(a_np, ctx)
7575
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
7676
time_f = f2.time_evaluator(f2.entry_name, ctx, number=10)
7777
cost = time_f(a, b).mean
78-
print('%g secs/op\n' % cost)
78+
print("%g secs/op\n" % cost)
7979
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
8080

8181
# Compile the Graph for OpenCL target
@@ -90,15 +90,15 @@ def test_rpc_module():
9090
path_dso_cl = temp.relpath("dev_lib_cl.so")
9191
f.export_library(path_dso_cl, ndk.create_shared)
9292

93-
print('Run GPU(OpenCL Flavor) test ...')
93+
print("Run GPU(OpenCL Flavor) test ...")
9494
ctx = remote.cl(0)
9595
remote.upload(path_dso_cl)
9696
f1 = remote.load_module("dev_lib_cl.so")
9797
a = tvm.nd.array(a_np, ctx)
9898
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
9999
time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
100100
cost = time_f(a, b).mean
101-
print('%g secs/op\n' % cost)
101+
print("%g secs/op\n" % cost)
102102
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
103103

104104
# Compile the Graph for Vulkan target
@@ -113,15 +113,15 @@ def test_rpc_module():
113113
path_dso_vulkan = temp.relpath("dev_lib_vulkan.so")
114114
f.export_library(path_dso_vulkan, ndk.create_shared)
115115

116-
print('Run GPU(Vulkan Flavor) test ...')
116+
print("Run GPU(Vulkan Flavor) test ...")
117117
ctx = remote.vulkan(0)
118118
remote.upload(path_dso_vulkan)
119119
f1 = remote.load_module("dev_lib_vulkan.so")
120120
a = tvm.nd.array(a_np, ctx)
121121
b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
122122
time_f = f1.time_evaluator(f1.entry_name, ctx, number=10)
123123
cost = time_f(a, b).mean
124-
print('%g secs/op\n' % cost)
124+
print("%g secs/op\n" % cost)
125125
np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
126126

127127

apps/benchmark/arm_cpu_imagenet_bench.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ def evaluate_network(network, target, target_host, repeat):
4040

4141
print_progress("%-20s building..." % network)
4242
with tvm.transform.PassContext(opt_level=3):
43-
graph, lib, params = relay.build(
44-
net, target=target, target_host=target_host, params=params)
43+
graph, lib, params = relay.build(net, target=target, target_host=target_host, params=params)
4544

4645
tmp = tempdir()
47-
if 'android' in str(target):
46+
if "android" in str(target):
4847
from tvm.contrib import ndk
48+
4949
filename = "%s.so" % network
5050
lib.export_library(tmp.relpath(filename), ndk.create_shared)
5151
else:
@@ -60,38 +60,55 @@ def evaluate_network(network, target, target_host, repeat):
6060
rlib = remote.load_module(filename)
6161
module = runtime.create(graph, rlib, ctx)
6262
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
63-
module.set_input('data', data_tvm)
63+
module.set_input("data", data_tvm)
6464
module.set_input(**params)
6565

6666
# evaluate
6767
print_progress("%-20s evaluating..." % network)
6868
ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=repeat)
6969
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond
70-
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))
70+
print(
71+
"%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))
72+
)
7173

7274

7375
if __name__ == "__main__":
7476
parser = argparse.ArgumentParser()
75-
parser.add_argument("--network", type=str, choices=
76-
['resnet-18', 'resnet-34', 'resnet-50',
77-
'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3',
78-
'mobilenet', 'squeezenet_v1.0', 'squeezenet_v1.1'],
79-
help='The name of neural network')
80-
parser.add_argument("--model", type=str, choices=
81-
['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro',
82-
'pixel2', 'rasp3b', 'pynq'], default='rk3399',
83-
help="The model of the test device. If your device is not listed in "
84-
"the choices list, pick the most similar one as argument.")
85-
parser.add_argument("--host", type=str, default='localhost')
77+
parser.add_argument(
78+
"--network",
79+
type=str,
80+
choices=[
81+
"resnet-18",
82+
"resnet-34",
83+
"resnet-50",
84+
"vgg-16",
85+
"vgg-19",
86+
"densenet-121",
87+
"inception_v3",
88+
"mobilenet",
89+
"squeezenet_v1.0",
90+
"squeezenet_v1.1",
91+
],
92+
help="The name of neural network",
93+
)
94+
parser.add_argument(
95+
"--model",
96+
type=str,
97+
choices=["rk3399", "mate10", "mate10pro", "p20", "p20pro", "pixel2", "rasp3b", "pynq"],
98+
default="rk3399",
99+
help="The model of the test device. If your device is not listed in "
100+
"the choices list, pick the most similar one as argument.",
101+
)
102+
parser.add_argument("--host", type=str, default="localhost")
86103
parser.add_argument("--port", type=int, default=9190)
87104
parser.add_argument("--rpc-key", type=str, required=True)
88105
parser.add_argument("--repeat", type=int, default=10)
89106
args = parser.parse_args()
90107

91-
dtype = 'float32'
108+
dtype = "float32"
92109

93110
if args.network is None:
94-
networks = ['squeezenet_v1.1', 'mobilenet', 'resnet-18', 'vgg-16']
111+
networks = ["squeezenet_v1.1", "mobilenet", "resnet-18", "vgg-16"]
95112
else:
96113
networks = [args.network]
97114

@@ -103,4 +120,3 @@ def evaluate_network(network, target, target_host, repeat):
103120
print("--------------------------------------------------")
104121
for network in networks:
105122
evaluate_network(network, target, target_host, args.repeat)
106-

0 commit comments

Comments
 (0)