Skip to content

Commit 60205f8

Browse files
ppwwyyxx and facebook-github-bot
authored and committed
allow converter to dump torchscript; support torchscript GPU inference
Reviewed By: rbgirshick
Differential Revision: D21364163
fbshipit-source-id: 6d83968b483f91df976939d8682031a5c60dd271
1 parent 1e21fa4 commit 60205f8

File tree

5 files changed

+66
-28
lines changed

5 files changed

+66
-28
lines changed

detectron2/export/api.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,13 @@
1515
from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
1616
from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
1717

18-
__all__ = ["add_export_config", "export_caffe2_model", "Caffe2Model", "export_onnx_model"]
18+
__all__ = [
19+
"add_export_config",
20+
"export_caffe2_model",
21+
"Caffe2Model",
22+
"export_onnx_model",
23+
"Caffe2Tracer",
24+
]
1925

2026

2127
def add_export_config(cfg):
@@ -47,7 +53,8 @@ class Caffe2Tracer:
4753
3. complicated pre/post processing
4854
4955
This class provides a traceable version of a detectron2 model by:
50-
1. Rewrite parts of the model using ops in caffe2
56+
1. Rewrite parts of the model using ops in caffe2. Note that some ops do
57+
not have GPU implementation.
5158
2. Define the inputs "after pre-processing" as inputs to the model
5259
3. Remove post-processing and produce raw layer outputs
5360
@@ -59,8 +66,6 @@ class Caffe2Tracer:
5966
model to different deployment formats.
6067
6168
The class currently only supports models using builtin meta architectures.
62-
63-
Experimental. Don't use.
6469
"""
6570

6671
def __init__(self, cfg, model, inputs):
@@ -127,7 +132,7 @@ def export_torchscript(self):
127132
logger = logging.getLogger(__name__)
128133
logger.info("Tracing the model with torch.jit.trace ...")
129134
with torch.no_grad():
130-
return torch.jit.trace(model, (inputs,))
135+
return torch.jit.trace(model, (inputs,), optimize=True)
131136

132137

133138
def export_caffe2_model(cfg, model, inputs):

detectron2/export/c10.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -164,9 +164,6 @@ def forward(self, images, features, gt_instances=None):
164164
features = [features[f] for f in self.in_features]
165165
objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
166166

167-
# TODO is the needed?
168-
# objectness_logits_pred = [t.sigmoid() for t in objectness_logits_pred]
169-
170167
assert isinstance(images, ImageList)
171168
if self.tensor_mode:
172169
im_info = images.image_sizes

tools/deploy/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,3 +12,10 @@ target_link_libraries(
1212
caffe2_mask_rcnn
1313
"${TORCH_LIBRARIES}" gflags glog ${OpenCV_LIBS})
1414
set_property(TARGET caffe2_mask_rcnn PROPERTY CXX_STANDARD 14)
15+
16+
17+
add_executable(torchscript_traced_mask_rcnn torchscript_traced_mask_rcnn.cpp)
18+
target_link_libraries(
19+
torchscript_traced_mask_rcnn
20+
"${TORCH_LIBRARIES}" ${OpenCV_LIBS})
21+
set_property(TARGET torchscript_traced_mask_rcnn PROPERTY CXX_STANDARD 14)

tools/deploy/caffe2_converter.py

Lines changed: 37 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,14 @@
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
33
import argparse
44
import os
5+
import onnx
56
import torch
67

78
from detectron2.checkpoint import DetectionCheckpointer
89
from detectron2.config import get_cfg
910
from detectron2.data import build_detection_test_loader
1011
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
11-
from detectron2.export import add_export_config, export_caffe2_model
12+
from detectron2.export import Caffe2Tracer, add_export_config
1213
from detectron2.modeling import build_model
1314
from detectron2.utils.logger import setup_logger
1415

@@ -28,10 +29,16 @@ def setup_cfg(args):
2829

2930

3031
if __name__ == "__main__":
31-
parser = argparse.ArgumentParser(description="Convert a model to Caffe2")
32+
parser = argparse.ArgumentParser(description="Convert a model using caffe2 tracing.")
33+
parser.add_argument(
34+
"--format",
35+
choices=["caffe2", "onnx", "torchscript"],
36+
help="output format",
37+
default="caffe2",
38+
)
3239
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
3340
parser.add_argument("--run-eval", action="store_true")
34-
parser.add_argument("--output", help="output directory for the converted caffe2 model")
41+
parser.add_argument("--output", help="output directory for the converted model")
3542
parser.add_argument(
3643
"opts",
3744
help="Modify config options using the command-line",
@@ -41,6 +48,7 @@ def setup_cfg(args):
4148
args = parser.parse_args()
4249
logger = setup_logger()
4350
logger.info("Command line arguments: " + str(args))
51+
os.makedirs(args.output, exist_ok=True)
4452

4553
cfg = setup_cfg(args)
4654

@@ -53,13 +61,35 @@ def setup_cfg(args):
5361
first_batch = next(iter(data_loader))
5462

5563
# convert and save caffe2 model
56-
caffe2_model = export_caffe2_model(cfg, torch_model, first_batch)
57-
caffe2_model.save_protobuf(args.output)
58-
# draw the caffe2 graph
59-
caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=first_batch)
64+
tracer = Caffe2Tracer(cfg, torch_model, first_batch)
65+
if args.format == "caffe2":
66+
caffe2_model = tracer.export_caffe2()
67+
caffe2_model.save_protobuf(args.output)
68+
# draw the caffe2 graph
69+
caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=first_batch)
70+
elif args.format == "onnx":
71+
onnx_model = tracer.export_onnx()
72+
onnx.save(onnx_model, os.path.join(args.output, "model.onnx"))
73+
elif args.format == "torchscript":
74+
script_model = tracer.export_torchscript()
75+
script_model.save(os.path.join(args.output, "model.ts"))
76+
77+
# Recursively print IR of all modules
78+
with open(os.path.join(args.output, "model_ts_IR.txt"), "w") as f:
79+
try:
80+
f.write(script_model._actual_script_module._c.dump_to_str(True, False, False))
81+
except AttributeError:
82+
pass
83+
# Print IR of the entire graph (all submodules inlined)
84+
with open(os.path.join(args.output, "model_ts_IR_inlined.txt"), "w") as f:
85+
f.write(str(script_model.inlined_graph))
86+
# Print the model structure in pytorch style
87+
with open(os.path.join(args.output, "model.txt"), "w") as f:
88+
f.write(str(script_model))
6089

6190
# run evaluation with the converted model
6291
if args.run_eval:
92+
assert args.format == "caffe2", "Python inference in other format is not yet supported."
6393
dataset = cfg.DATASETS.TEST[0]
6494
data_loader = build_detection_test_loader(cfg, dataset)
6595
# NOTE: hard-coded evaluator. change to the evaluator for your dataset

tools/deploy/torchscript_traced_mask_rcnn.cpp

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
using namespace std;
1111

12-
// Experimental. Don't use.
12+
// experimental. don't use
1313
int main(int argc, const char* argv[]) {
1414
if (argc != 3) {
1515
return 1;
@@ -19,26 +19,25 @@ int main(int argc, const char* argv[]) {
1919
torch::autograd::AutoGradMode guard(false);
2020
auto module = torch::jit::load(argv[1]);
2121

22+
assert(module.buffers().size() > 0);
23+
// Assume that the entire model is on the same device.
24+
// We just put input to this device.
25+
auto device = (*begin(module.buffers())).device();
26+
2227
cv::Mat input_img = cv::imread(image_file, cv::IMREAD_COLOR);
2328
const int height = input_img.rows;
2429
const int width = input_img.cols;
2530
// FPN models require divisibility of 32
2631
assert(height % 32 == 0 && width % 32 == 0);
27-
const int batch = 1;
2832
const int channels = 3;
2933

30-
auto input = torch::empty({1, channels, height, width});
31-
float* ptr = input.data_ptr<float>();
32-
// HWC to CHW
33-
for (int c = 0; c < 3; ++c) {
34-
for (int i = 0; i < height * width; ++i) {
35-
ptr[c * height * width + i] =
36-
static_cast<float>(input_img.data[3 * i + c]);
37-
}
38-
}
34+
auto input = torch::from_blob(
35+
input_img.data, {1, height, width, channels}, torch::kUInt8);
36+
// NHWC to NCHW
37+
input = input.to(device, torch::kFloat).permute({0, 3, 1, 2}).contiguous();
3938

40-
float im_info_data[] = {height * 1.0f, width * 1.0f, 1.0f};
41-
auto im_info = torch::from_blob(im_info_data, {1, 3});
39+
std::array<float, 3> im_info_data{height * 1.0f, width * 1.0f, 1.0f};
40+
auto im_info = torch::from_blob(im_info_data.data(), {1, 3}).to(device);
4241

4342
// run the network
4443
auto output = module.forward({std::make_tuple(input, im_info)});

0 commit comments

Comments
 (0)