2
2
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
3
import argparse
4
4
import os
5
+ import onnx
5
6
import torch
6
7
7
8
from detectron2 .checkpoint import DetectionCheckpointer
8
9
from detectron2 .config import get_cfg
9
10
from detectron2 .data import build_detection_test_loader
10
11
from detectron2 .evaluation import COCOEvaluator , inference_on_dataset , print_csv_format
11
- from detectron2 .export import add_export_config , export_caffe2_model
12
+ from detectron2 .export import Caffe2Tracer , add_export_config
12
13
from detectron2 .modeling import build_model
13
14
from detectron2 .utils .logger import setup_logger
14
15
@@ -28,10 +29,16 @@ def setup_cfg(args):
28
29
29
30
30
31
if __name__ == "__main__" :
31
- parser = argparse .ArgumentParser (description = "Convert a model to Caffe2" )
32
+ parser = argparse .ArgumentParser (description = "Convert a model using caffe2 tracing." )
33
+ parser .add_argument (
34
+ "--format" ,
35
+ choices = ["caffe2" , "onnx" , "torchscript" ],
36
+ help = "output format" ,
37
+ default = "caffe2" ,
38
+ )
32
39
parser .add_argument ("--config-file" , default = "" , metavar = "FILE" , help = "path to config file" )
33
40
parser .add_argument ("--run-eval" , action = "store_true" )
34
- parser .add_argument ("--output" , help = "output directory for the converted caffe2 model" )
41
+ parser .add_argument ("--output" , help = "output directory for the converted model" )
35
42
parser .add_argument (
36
43
"opts" ,
37
44
help = "Modify config options using the command-line" ,
@@ -41,6 +48,7 @@ def setup_cfg(args):
41
48
args = parser .parse_args ()
42
49
logger = setup_logger ()
43
50
logger .info ("Command line arguments: " + str (args ))
51
+ os .makedirs (args .output , exist_ok = True )
44
52
45
53
cfg = setup_cfg (args )
46
54
@@ -53,13 +61,35 @@ def setup_cfg(args):
53
61
first_batch = next (iter (data_loader ))
54
62
55
63
# convert and save caffe2 model
56
- caffe2_model = export_caffe2_model (cfg , torch_model , first_batch )
57
- caffe2_model .save_protobuf (args .output )
58
- # draw the caffe2 graph
59
- caffe2_model .save_graph (os .path .join (args .output , "model.svg" ), inputs = first_batch )
64
+ tracer = Caffe2Tracer (cfg , torch_model , first_batch )
65
+ if args .format == "caffe2" :
66
+ caffe2_model = tracer .export_caffe2 ()
67
+ caffe2_model .save_protobuf (args .output )
68
+ # draw the caffe2 graph
69
+ caffe2_model .save_graph (os .path .join (args .output , "model.svg" ), inputs = first_batch )
70
+ elif args .format == "onnx" :
71
+ onnx_model = tracer .export_onnx ()
72
+ onnx .save (onnx_model , os .path .join (args .output , "model.onnx" ))
73
+ elif args .format == "torchscript" :
74
+ script_model = tracer .export_torchscript ()
75
+ script_model .save (os .path .join (args .output , "model.ts" ))
76
+
77
+ # Recursively print IR of all modules
78
+ with open (os .path .join (args .output , "model_ts_IR.txt" ), "w" ) as f :
79
+ try :
80
+ f .write (script_model ._actual_script_module ._c .dump_to_str (True , False , False ))
81
+ except AttributeError :
82
+ pass
83
+ # Print IR of the entire graph (all submodules inlined)
84
+ with open (os .path .join (args .output , "model_ts_IR_inlined.txt" ), "w" ) as f :
85
+ f .write (str (script_model .inlined_graph ))
86
+ # Print the model structure in pytorch style
87
+ with open (os .path .join (args .output , "model.txt" ), "w" ) as f :
88
+ f .write (str (script_model ))
60
89
61
90
# run evaluation with the converted model
62
91
if args .run_eval :
92
+ assert args .format == "caffe2" , "Python inference in other format is not yet supported."
63
93
dataset = cfg .DATASETS .TEST [0 ]
64
94
data_loader = build_detection_test_loader (cfg , dataset )
65
95
# NOTE: hard-coded evaluator. change to the evaluator for your dataset
0 commit comments