Skip to content

Commit be792b9

Browse files
chenbohua3facebook-github-bot
authored andcommitted
Make 'ROIAlign' & 'ROIAlignV2' version of ROIPooler scriptable.
Summary: Pull Request resolved: #1835 Reviewed By: rbgirshick Differential Revision: D22819550 Pulled By: ppwwyyxx fbshipit-source-id: 85cd2198676289e0ab02678f221b97887e543395
1 parent af866c4 commit be792b9

File tree

2 files changed

+71
-15
lines changed

2 files changed

+71
-15
lines changed

detectron2/modeling/poolers.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,25 @@
77
from torchvision.ops import RoIPool
88

99
from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple
10+
from detectron2.structures import Boxes
11+
12+
"""
13+
To export ROIPooler to torchscript, in this file, variables that should be annotated with
14+
`Union[List[Boxes], List[RotatedBoxes]]` are only annotated with `List[Boxes]`.
15+
16+
TODO: Correct these annotations when torchscript support `Union`.
17+
https://github.com/pytorch/pytorch/issues/41412
18+
"""
1019

1120
__all__ = ["ROIPooler"]
1221

1322

1423
def assign_boxes_to_levels(
15-
box_lists, min_level: int, max_level: int, canonical_box_size: int, canonical_level: int
24+
box_lists: List[Boxes],
25+
min_level: int,
26+
max_level: int,
27+
canonical_box_size: int,
28+
canonical_level: int,
1629
):
1730
"""
1831
Map each box in `box_lists` to a feature map level index and return the assignment
@@ -35,19 +48,25 @@ def assign_boxes_to_levels(
3548
`self.min_level`, for the corresponding box (so value i means the box is at
3649
`self.min_level + i`).
3750
"""
38-
eps = sys.float_info.epsilon
3951
box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists]))
4052
# Eqn.(1) in FPN paper
4153
level_assignments = torch.floor(
42-
canonical_level + torch.log2(box_sizes / canonical_box_size + eps)
54+
canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8)
4355
)
4456
# clamp level to (min, max), in case the box size is too large or too small
4557
# for the available feature maps
4658
level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level)
4759
return level_assignments.to(torch.int64) - min_level
4860

4961

50-
def convert_boxes_to_pooler_format(box_lists):
62+
def _fmt_box_list(box_tensor, batch_index: int):
63+
repeated_index = torch.full(
64+
(len(box_tensor), 1), batch_index, dtype=box_tensor.dtype, device=box_tensor.device
65+
)
66+
return cat((repeated_index, box_tensor), dim=1)
67+
68+
69+
def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
5170
"""
5271
Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops
5372
(see description under Returns).
@@ -70,15 +89,8 @@ def convert_boxes_to_pooler_format(box_lists):
7089
where batch index is the index in [0, N) identifying which batch image the
7190
rotated box (x_ctr, y_ctr, width, height, angle_degrees) comes from.
7291
"""
73-
74-
def fmt_box_list(box_tensor, batch_index):
75-
repeated_index = torch.full(
76-
(len(box_tensor), 1), batch_index, dtype=box_tensor.dtype, device=box_tensor.device
77-
)
78-
return cat((repeated_index, box_tensor), dim=1)
79-
8092
pooler_fmt_boxes = cat(
81-
[fmt_box_list(box_list.tensor, i) for i, box_list in enumerate(box_lists)], dim=0
93+
[_fmt_box_list(box_list.tensor, i) for i, box_list in enumerate(box_lists)], dim=0
8294
)
8395

8496
return pooler_fmt_boxes
@@ -176,7 +188,7 @@ def __init__(
176188
assert canonical_box_size > 0
177189
self.canonical_box_size = canonical_box_size
178190

179-
def forward(self, x: List[torch.Tensor], box_lists):
191+
def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
180192
"""
181193
Args:
182194
x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
@@ -226,9 +238,9 @@ def forward(self, x: List[torch.Tensor], box_lists):
226238
(num_boxes, num_channels, output_size, output_size), dtype=dtype, device=device
227239
)
228240

229-
for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
241+
for level, pooler in enumerate(self.level_poolers):
230242
inds = nonzero_tuple(level_assignments == level)[0]
231243
pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
232-
output[inds] = pooler(x_level, pooler_fmt_boxes_level)
244+
output[inds] = pooler(x[level], pooler_fmt_boxes_level)
233245

234246
return output

tests/modeling/test_roi_pooler.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from detectron2.modeling.poolers import ROIPooler
77
from detectron2.structures import Boxes, RotatedBoxes
8+
from detectron2.utils.env import TORCH_VERSION
89

910
logger = logging.getLogger(__name__)
1011

@@ -80,6 +81,49 @@ def test_roialignv2_roialignrotated_match_cpu(self):
8081
def test_roialignv2_roialignrotated_match_cuda(self):
8182
self._test_roialignv2_roialignrotated_match(device="cuda")
8283

84+
def _test_scriptability(self, device):
85+
pooler_resolution = 14
86+
canonical_level = 4
87+
canonical_scale_factor = 2 ** canonical_level
88+
pooler_scales = (1.0 / canonical_scale_factor,)
89+
sampling_ratio = 0
90+
91+
N, C, H, W = 2, 4, 10, 8
92+
N_rois = 10
93+
std = 11
94+
mean = 0
95+
feature = (torch.rand(N, C, H, W) - 0.5) * 2 * std + mean
96+
97+
features = [feature.to(device)]
98+
99+
rois = []
100+
for _ in range(N):
101+
boxes = self._rand_boxes(
102+
num_boxes=N_rois, x_max=W * canonical_scale_factor, y_max=H * canonical_scale_factor
103+
)
104+
105+
rois.append(Boxes(boxes).to(device))
106+
107+
roialignv2_pooler = ROIPooler(
108+
output_size=pooler_resolution,
109+
scales=pooler_scales,
110+
sampling_ratio=sampling_ratio,
111+
pooler_type="ROIAlignV2",
112+
)
113+
114+
roialignv2_out = roialignv2_pooler(features, rois)
115+
scripted_roialignv2_out = torch.jit.script(roialignv2_pooler)(features, rois)
116+
self.assertTrue(torch.equal(roialignv2_out, scripted_roialignv2_out))
117+
118+
@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
119+
def test_scriptability_cpu(self):
120+
self._test_scriptability(device="cpu")
121+
122+
@unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
123+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
124+
def test_scriptability_gpu(self):
125+
self._test_scriptability(device="cuda")
126+
83127

84128
if __name__ == "__main__":
85129
unittest.main()

0 commit comments

Comments
 (0)