plugin_k210.py
# coding=utf-8
'''
@ Summary: how to use k210
@ Update:
@ file: k210.py
@ version: 1.0.0
@ Author: [email protected]
@ Date: 2021/2/5 11:44
@ Update: Load rt_ai_<model_name>_model.c from Documents to projects/applications
          (this will be removed in the stable version)
@ Date: 2021/02/23
'''
import os
import sys
import re
import logging
import platform
import subprocess
from pathlib import Path
path = os.path.dirname(__file__)
sys.path.append(os.path.join(path, '../../'))
from platforms.plugin_k210 import generate_rt_ai_model_h


class Plugin(object):
    def __init__(self, opt):
        # aitools base parsers
        self.project = opt.project
        self.model_path = opt.model
        self.platform = opt.platform

        # plugin_k210 parser part1: nncase parser
        self.inference_type = opt.inference_type
        self.dataset = opt.dataset
        self.dataset_format = opt.dataset_format
        self.dump_weights_range = opt.dump_weights_range
        self.weights_quantize_threshold = opt.weights_quantize_threshold
        self.output_quantize_threshold = opt.output_quantize_threshold
        self.no_quantized_binary = opt.no_quantized_binary
        # To support the case where inference_type is uint8 but the input type needs to be float32.
        self.input_type = opt.input_type if opt.input_type else self.inference_type
        self.input_std = opt.input_std    # input std, default is 1.00000
        self.input_mean = opt.input_mean  # input mean, default is 0.00000

        # plugin_k210 parser part2
        self.embed_gcc = opt.embed_gcc
        self.ext_tools = opt.ext_tools
        self.rt_ai_example = opt.rt_ai_example
        self.convert_report = opt.convert_report
        self.model_types = opt.model_types
        self.network = opt.network
        self.clear = opt.clear

        kmodel_name = self.is_support_model_type(self.model_types, self.model_path)
        self.kmodel_name = opt.model_name if opt.model_name else kmodel_name
        self.kmodel_path = Path(__file__).parent / f"{self.kmodel_name}.kmodel"

    def is_support_model_type(self, model_types, model):
        """ Check the model suffix against the supported types; return the model stem. """
        supported_model = model_types.split()
        model = Path(model)
        assert model.suffix[1:] in supported_model, f"{model.name} is not supported yet!!!"
        logging.info(f"The model is {model.name}")
        return model.stem
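
    # e.g. is_support_model_type("tflite caffe onnx", "../../Model/facelandmark.tflite")
    # passes the suffix check and returns "facelandmark" (values from the __main__ harness below).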

    def excute_cmd(self, cmd, is_realtime=False):
        """ Return the command output after the command is executed """
        result = list()
        screenData = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
        if is_realtime:
            # stream stdout line by line until the process exits
            while True:
                line = screenData.stdout.readline()
                line_str = line.decode('utf-8').strip()
                result.append(line_str)
                print(f"\t{line_str}")
                if line == b'' or screenData.poll() == 0:
                    screenData.stdout.close()
                    break
        else:
            result.append(screenData.stdout.read())
            screenData.stdout.close()
        return result
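
    # Note: with is_realtime=False the single list element is raw bytes (hence the
    # .decode() in set_env below); with is_realtime=True each element is an
    # already-decoded, stripped line echoed to the console as it arrives.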

    def set_env(self, ncc_path):
        """ set ncc executable path """
        assert Path(ncc_path).exists(), "No {} here".format(ncc_path)
        # set nncase env
        sysstr = platform.system()
        if sysstr == "Windows":
            os.environ["PATH"] += (";" + ncc_path)
        elif sysstr == "Linux":
            self.excute_cmd("chmod +x platforms/plugin_k210/k_tools/ncc")
            # append rather than overwrite, so existing PATH entries stay usable
            os.environ["PATH"] += (":" + ncc_path)
        else:
            raise Exception("Unsupported system...")
        # validate
        ncc_info = self.excute_cmd("ncc --version")
        if not ncc_info:
            raise Exception("Failed to set the nncase env!!!")
        ncc_version_info = ncc_info[0].decode().strip()
        logging.info(f"ncc {ncc_version_info}...")
        return ncc_info

    def convert_kmodel(self, model, kmodel_path, inference_type, dataset, dataset_format,
                       convert_report):
        """ convert your model to kmodel """
        model = Path(model)
        assert model.exists(), FileNotFoundError("No model found, pls check the model path!!!")

        # nncase: convert the model to kmodel
        # base cmd
        base_cmd = f"ncc compile {model} {kmodel_path} -i {model.suffix[1:]} -t k210 " \
                   f"--inference-type {inference_type} --input-type {self.input_type}"
        # add --dump-weights-range
        base_cmd = f"{base_cmd} --dump-weights-range " if self.dump_weights_range else base_cmd

        if not self.no_quantized_binary and inference_type == "uint8":  # quantization
            convert_cmd = f"{base_cmd} --dataset {dataset} --dataset-format {dataset_format} " \
                          f"--weights-quantize-threshold {self.weights_quantize_threshold} " \
                          f"--output-quantize-threshold {self.output_quantize_threshold} " \
                          f"--input-std {self.input_std} --input-mean {self.input_mean}"
        else:  # no quantization
            convert_cmd = base_cmd

        cmd_out = self.excute_cmd(convert_cmd, is_realtime=True)
        report = "\n".join(cmd_out)
        with open(convert_report, "w+") as f:
            f.write(report)

        if not kmodel_path.exists():
            raise Exception("Model convert to kmodel failed!!!")
        logging.info("Convert model to kmodel successfully...")

    def hex_read_model(self, kmodel_path, project, model):
        """ save the kmodel as a hex C array """
        output_path = Path(project) / "applications"
        # length in bytes
        count_bytes = 0
        # file contents, as hexadecimal text
        result = list()
        with open(kmodel_path, "rb") as f:
            s = f.read(1)
            while s:
                byte = ord(s)
                count_bytes += 1
                result.append("0x%02x, " % byte)
                # 12 bytes per line
                if count_bytes % 12 == 0:
                    result.append("\n")
                s = f.read(1)

        head = f"unsigned char {kmodel_path.stem}_kmodel[] = {'{'}\n"
        tail = f"unsigned int {kmodel_path.stem}_{kmodel_path.suffix[1:]}_len = {count_bytes};\n"
        result = [head] + result + ["};\n"] + [tail]

        # hex kmodel
        model_c = output_path / f"{model}_kmodel.c"
        with model_c.open("w+") as fw:
            fw.write("".join(result))
        logging.info("Save hex kmodel successfully...")
        return result
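
    # The generated applications/<model>_kmodel.c has this shape (bytes and length
    # are illustrative):
    #   unsigned char facelandmark_kmodel[] = {
    #       0x4b, 0x6d, 0x6f, ...   /* 12 bytes per line */
    #   };
    #   unsigned int facelandmark_kmodel_len = 123456;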

    def update_network_name(self, info_file, new_example_file, default_name, model_name):
        """ replace old_name with new_name """
        # load file
        with info_file.open() as fr:
            lines = fr.read()

        if default_name != model_name:
            old_name_list = [default_name, default_name.upper()]
            new_name_list = [model_name, model_name.upper()]
            # replace both the lower-case and upper-case forms of the name
            for i in range(len(old_name_list)):
                lines = re.sub(old_name_list[i], new_name_list[i], lines)

        # save new example file
        with new_example_file.open("w") as fw:
            fw.write(lines)
        return new_example_file
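
    # e.g. update_network_name(k210_c_file, example_file, "facelandmark", "mymodel")
    # rewrites every "facelandmark"/"FACELANDMARK" occurrence to "mymodel"/"MYMODEL"
    # ("mymodel" is a hypothetical model name).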

    def load_rt_ai_example(self, rt_ai_example, project, old_name, new_name, platform):
        """ load rt_ai_<model_name>_model.c from Documents """
        rt_ai_example = Path(rt_ai_example)
        # k210.c
        k210_c_file = rt_ai_example / f"{platform}.c"
        # network info
        example_file = Path(project) / f"applications/rt_ai_{new_name}_model.c"
        if example_file.exists():
            example_file.unlink()
        self.update_network_name(k210_c_file, example_file, old_name, new_name)
        logging.info(f"Generate rt_ai_{new_name}_model.c successfully...")

    def set_gcc_path(self, project, embed_gcc):
        """ set GNU Compiler Toolchain """
        def clear_gcc_path(lines, index=0):
            # drop any previously injected RTT_EXEC_PATH line
            while index < len(lines):
                if "os.environ['RTT_EXEC_PATH']" in lines[index]:
                    lines.remove(lines[index])
                    break
                index += 1
            return lines

        rtconfig_py = os.path.join(project, "rtconfig.py")
        with open(rtconfig_py, "r+") as fr:
            lines = fr.readlines()
        lines = clear_gcc_path(lines)

        if embed_gcc:
            assert os.path.exists(embed_gcc), "No GNU Compiler Toolchain found???"
            set_embed_gcc_env = f"os.environ['RTT_EXEC_PATH'] = r'{embed_gcc}'"
            for index, line in enumerate(lines):
                if "RTT_EXEC_PATH" in line:
                    # insert the env setting just before the first line that
                    # mentions RTT_EXEC_PATH, without dropping any existing line
                    lines = lines[:index] + ["\n", set_embed_gcc_env, "\n"] + lines[index:]
                    break

        with open(rtconfig_py, "w+") as fw:
            fw.write("".join(lines))
        logging.info("Set GNU Compiler Toolchain successfully...")

    def run_plugin(self):
        # 1. set nncase env
        self.set_env(self.ext_tools)

        # 2.1 convert model to kmodel
        self.convert_kmodel(self.model_path, self.kmodel_path, self.inference_type, self.dataset,
                            self.dataset_format, self.convert_report)
        # 2.2 save kmodel with hex
        self.hex_read_model(self.kmodel_path, self.project, self.kmodel_name)

        # 3.1 generate rt_ai_<model_name>_model.h
        _ = generate_rt_ai_model_h.rt_ai_model_gen(self.convert_report, self.project, self.kmodel_name)
        # 3.2 load rt_ai_<model_name>_model.c
        self.load_rt_ai_example(self.rt_ai_example, self.project, self.network, self.kmodel_name, self.platform)

        # 4. set GNU Compiler Toolchain
        self.set_gcc_path(self.project, self.embed_gcc)

        # 5. optionally remove convert_report.txt and the intermediate kmodel
        if os.path.exists(self.convert_report) and self.clear:
            os.remove(self.convert_report)
            os.remove(self.kmodel_path)
        return True


if __name__ == "__main__":
    import shutil
    logging.getLogger().setLevel(logging.INFO)

    tmp_project = Path("tmp_cwd")
    app_path = tmp_project / "applications"
    config_path = r"D:\Project\K210_Demo\PersonDetection\k210-person-template/rtconfig.py"
    if not app_path.exists():
        app_path.mkdir(parents=True)
    shutil.copy(config_path, tmp_project)

    class Opt():
        def __init__(self):
            self.project = r"tmp_cwd"
            self.model = "../../Model/facelandmark.tflite"
            self.platform = "k210"
            self.rt_ai_example = "../../Documents"
            self.model_name = "facelandmark"
            # k210
            self.embed_gcc = r"D:\Project\k210_third_tools\xpack-riscv-none-embed-gcc-8.3.0-1.2\bin"
            self.ext_tools = r"./k_tools"
            self.inference_type = "uint8"
            self.model_types = "tflite caffe onnx"
            self.dataset_format = "image"
            self.convert_report = "./convert_report.txt"
            self.dataset = "./datasets/images"
            self.network = "facelandmark"
            self.clear = False
            # nncase options read by Plugin.__init__ but missing from the original
            # harness; the two threshold values are illustrative placeholders.
            self.dump_weights_range = False
            self.weights_quantize_threshold = 32.0
            self.output_quantize_threshold = 1024
            self.no_quantized_binary = False
            self.input_type = None  # falls back to inference_type
            self.input_std = 1.0    # default per the comments in Plugin.__init__
            self.input_mean = 0.0

    opt = Opt()
    k210 = Plugin(opt)
    _ = k210.run_plugin()