# onnxconverter.py
import warnings

from segment_anything.utils.onnx import SamOnnxModel
from segment_anything import sam_model_registry, SamPredictor
import torch
import os

import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

VITH_CHECKPOINT = os.environ.get("VITH_CHECKPOINT", "sam_vit_h_4b8939.pth")
def convert(checkpoint_path):
    # Place the exported models next to the checkpoint.
    onnx_model_path = os.path.join(os.path.dirname(checkpoint_path), "sam_onnx_example.onnx")
    onnx_model_quantized_path = os.path.join(os.path.dirname(checkpoint_path), "sam_onnx_quantized_example.onnx")

    if os.path.exists(onnx_model_path) and os.path.exists(onnx_model_quantized_path):
        print(f"ONNX model already exists at {onnx_model_path}, {onnx_model_quantized_path}, skipping conversion")
        return

    # Load the SAM ViT-H checkpoint and wrap it in the ONNX-exportable decoder model.
    sam = sam_model_registry["vit_h"](checkpoint=VITH_CHECKPOINT)
    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    # The number of prompt points is variable at inference time.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]

    # Dummy inputs are used only for tracing during export.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    # Export the mask decoder to ONNX, silencing tracer warnings.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=17,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    # Write a smaller, dynamically quantized (uint8 weights) copy of the model.
    quantize_dynamic(
        model_input=onnx_model_path,
        model_output=onnx_model_quantized_path,
        optimize_model=True,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )


convert(VITH_CHECKPOINT)
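

# Illustrative sketch (not part of the original converter): the exported model can be
# loaded with onnxruntime for inference. The input names mirror the dummy_inputs keys
# above; in real use, image_embeddings would come from SamPredictor.get_image_embedding()
# on an actual image rather than random data. The function name and the shapes below are
# assumptions for the ViT-H configuration (embed_dim=256, 64x64 image embedding).
def _example_onnx_inference(onnx_path):
    import numpy as np

    session = onnxruntime.InferenceSession(onnx_path)
    ort_inputs = {
        "image_embeddings": np.random.randn(1, 256, 64, 64).astype(np.float32),
        # One foreground click plus a padding point (label -1), as in SAM's ONNX demo.
        "point_coords": np.array([[[500.0, 375.0], [0.0, 0.0]]], dtype=np.float32),
        "point_labels": np.array([[1.0, -1.0]], dtype=np.float32),
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.zeros(1, dtype=np.float32),
        "orig_im_size": np.array([1500, 2250], dtype=np.float32),
    }
    masks, iou_predictions, low_res_masks = session.run(None, ort_inputs)
    return masks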