Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct the problem when fcn_mask_head takes in invalid bboxes with negative coordinates #42

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 101 additions & 79 deletions mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,45 +15,52 @@
BYTES_PER_FLOAT = 4
# TODO: This memory limit may be too much or too little. It would be better to
# determine it based on available resources.
GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit
GPU_MEM_LIMIT = 1024 ** 3 # 1 GB memory limit


@HEADS.register_module()
class FCNMaskHead(BaseModule):

def __init__(self,
num_convs=4,
roi_feat_size=14,
in_channels=256,
conv_kernel_size=3,
conv_out_channels=256,
num_classes=80,
class_agnostic=False,
upsample_cfg=dict(type='deconv', scale_factor=2),
conv_cfg=None,
norm_cfg=None,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
init_cfg=None):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
def __init__(
self,
num_convs=4,
roi_feat_size=14,
in_channels=256,
conv_kernel_size=3,
conv_out_channels=256,
num_classes=80,
class_agnostic=False,
upsample_cfg=dict(type="deconv", scale_factor=2),
conv_cfg=None,
norm_cfg=None,
loss_mask=dict(type="CrossEntropyLoss", use_mask=True, loss_weight=1.0),
init_cfg=None,
):
assert init_cfg is None, (
"To prevent abnormal initialization "
"behavior, init_cfg is not allowed to be set"
)
super(FCNMaskHead, self).__init__(init_cfg)
self.upsample_cfg = upsample_cfg.copy()
if self.upsample_cfg['type'] not in [
None, 'deconv', 'nearest', 'bilinear', 'carafe'
if self.upsample_cfg["type"] not in [
None,
"deconv",
"nearest",
"bilinear",
"carafe",
]:
raise ValueError(
f'Invalid upsample method {self.upsample_cfg["type"]}, '
'accepted methods are "deconv", "nearest", "bilinear", '
'"carafe"')
'"carafe"'
)
self.num_convs = num_convs
# WARN: roi_feat_size is reserved and not used
self.roi_feat_size = _pair(roi_feat_size)
self.in_channels = in_channels
self.conv_kernel_size = conv_kernel_size
self.conv_out_channels = conv_out_channels
self.upsample_method = self.upsample_cfg.get('type')
self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
self.upsample_method = self.upsample_cfg.get("type")
self.scale_factor = self.upsample_cfg.pop("scale_factor", None)
self.num_classes = num_classes
self.class_agnostic = class_agnostic
self.conv_cfg = conv_cfg
Expand All @@ -63,8 +70,7 @@ def __init__(self,

self.convs = ModuleList()
for i in range(self.num_convs):
in_channels = (
self.in_channels if i == 0 else self.conv_out_channels)
in_channels = self.in_channels if i == 0 else self.conv_out_channels
padding = (self.conv_kernel_size - 1) // 2
self.convs.append(
ConvModule(
Expand All @@ -74,37 +80,44 @@ def __init__(self,
padding=padding,
bias=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg))
norm_cfg=norm_cfg,
)
)
upsample_in_channels = (
self.conv_out_channels if self.num_convs > 0 else in_channels)
self.conv_out_channels if self.num_convs > 0 else in_channels
)
upsample_cfg_ = self.upsample_cfg.copy()
if self.upsample_method is None:
self.upsample = None
elif self.upsample_method == 'deconv':
elif self.upsample_method == "deconv":
upsample_cfg_.update(
in_channels=upsample_in_channels,
out_channels=self.conv_out_channels,
kernel_size=self.scale_factor,
stride=self.scale_factor)
stride=self.scale_factor,
)
self.upsample = build_upsample_layer(upsample_cfg_)
elif self.upsample_method == 'carafe':
elif self.upsample_method == "carafe":
upsample_cfg_.update(
channels=upsample_in_channels, scale_factor=self.scale_factor)
channels=upsample_in_channels, scale_factor=self.scale_factor
)
self.upsample = build_upsample_layer(upsample_cfg_)
else:
# suppress warnings
align_corners = (None
if self.upsample_method == 'nearest' else False)
align_corners = None if self.upsample_method == "nearest" else False
upsample_cfg_.update(
scale_factor=self.scale_factor,
mode=self.upsample_method,
align_corners=align_corners)
align_corners=align_corners,
)
self.upsample = build_upsample_layer(upsample_cfg_)

out_channels = 1 if self.class_agnostic else self.num_classes
logits_in_channel = (
self.conv_out_channels
if self.upsample_method == 'deconv' else upsample_in_channels)
if self.upsample_method == "deconv"
else upsample_in_channels
)
self.conv_logits = Conv2d(logits_in_channel, out_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.debug_imgs = None
Expand All @@ -117,8 +130,7 @@ def init_weights(self):
elif isinstance(m, CARAFEPack):
m.init_weights()
else:
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
nn.init.constant_(m.bias, 0)

@auto_fp16()
Expand All @@ -127,21 +139,20 @@ def forward(self, x):
x = conv(x)
if self.upsample is not None:
x = self.upsample(x)
if self.upsample_method == 'deconv':
if self.upsample_method == "deconv":
x = self.relu(x)
mask_pred = self.conv_logits(x)
return mask_pred

def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg):
pos_proposals = [res.pos_bboxes for res in sampling_results]
pos_assigned_gt_inds = [
res.pos_assigned_gt_inds for res in sampling_results
]
mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
gt_masks, rcnn_train_cfg)
pos_assigned_gt_inds = [res.pos_assigned_gt_inds for res in sampling_results]
mask_targets = mask_target(
pos_proposals, pos_assigned_gt_inds, gt_masks, rcnn_train_cfg
)
return mask_targets

@force_fp32(apply_to=('mask_pred', ))
@force_fp32(apply_to=("mask_pred",))
def loss(self, mask_pred, mask_targets, labels):
"""
Example:
Expand All @@ -166,15 +177,25 @@ def loss(self, mask_pred, mask_targets, labels):
loss_mask = mask_pred.sum()
else:
if self.class_agnostic:
loss_mask = self.loss_mask(mask_pred, mask_targets,
torch.zeros_like(labels))
loss_mask = self.loss_mask(
mask_pred, mask_targets, torch.zeros_like(labels)
)
else:
loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
loss['loss_mask'] = loss_mask
loss["loss_mask"] = loss_mask
return loss

def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
ori_shape, scale_factor, rescale, format=True):
def get_seg_masks(
self,
mask_pred,
det_bboxes,
det_labels,
rcnn_test_cfg,
ori_shape,
scale_factor,
rescale,
format=True,
):
"""Get segmentation masks from mask_pred and bboxes.

Args:
Expand Down Expand Up @@ -228,8 +249,9 @@ class label c.
mask_pred = det_bboxes.new_tensor(mask_pred)

device = mask_pred.device
cls_segms = [[] for _ in range(self.num_classes)
] # BG is not included in num_classes
cls_segms = [
[] for _ in range(self.num_classes)
] # BG is not included in num_classes
bboxes = det_bboxes[:, :4]
labels = det_labels
# No need to consider rescale and scale_factor while exporting to ONNX
Expand All @@ -240,16 +262,12 @@ class label c.
img_h, img_w = ori_shape[:2]
else:
if isinstance(scale_factor, float):
img_h = np.round(ori_shape[0] * scale_factor).astype(
np.int32)
img_w = np.round(ori_shape[1] * scale_factor).astype(
np.int32)
img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
else:
w_scale, h_scale = scale_factor[0], scale_factor[1]
img_h = np.round(ori_shape[0] * h_scale.item()).astype(
np.int32)
img_w = np.round(ori_shape[1] * w_scale.item()).astype(
np.int32)
img_h = np.round(ori_shape[0] * h_scale.item()).astype(np.int32)
img_w = np.round(ori_shape[1] * w_scale.item()).astype(np.int32)
scale_factor = 1.0

if not isinstance(scale_factor, (float, torch.Tensor)):
Expand All @@ -262,22 +280,20 @@ class label c.
if not self.class_agnostic:
box_inds = torch.arange(mask_pred.shape[0])
mask_pred = mask_pred[box_inds, labels][:, None]
masks, _ = _do_paste_mask(
mask_pred, bboxes, img_h, img_w, skip_empty=False)
masks, _ = _do_paste_mask(mask_pred, bboxes, img_h, img_w, skip_empty=False)
if threshold >= 0:
masks = (masks >= threshold).to(dtype=torch.bool)
else:
# The TensorRT backend has no uint8 data type
is_trt_backend = os.environ.get(
'ONNX_BACKEND') == 'MMCVTensorRT'
is_trt_backend = os.environ.get("ONNX_BACKEND") == "MMCVTensorRT"
target_dtype = torch.int32 if is_trt_backend else torch.uint8
masks = (masks * 255).to(dtype=target_dtype)
return masks

N = len(mask_pred)
# The actual implementation splits the input into chunks,
# and pastes them chunk by chunk.
if device.type == 'cpu':
if device.type == "cpu":
# CPU is most efficient when the masks are pasted one by one with
# skip_empty=True, so that it performs the minimal number of
# operations.
Expand All @@ -286,9 +302,11 @@ class label c.
# GPU benefits from parallelism for larger chunks,
# but may run into memory issues
num_chunks = int(
np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
assert (num_chunks <=
N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT)
)
assert (
num_chunks <= N
), "Default GPU_MEM_LIMIT is too small; try increasing it"
chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

threshold = rcnn_test_cfg.mask_thr_binary
Expand All @@ -297,7 +315,8 @@ class label c.
img_h,
img_w,
device=device,
dtype=torch.bool if threshold >= 0 else torch.uint8)
dtype=torch.bool if threshold >= 0 else torch.uint8,
)

if not self.class_agnostic:
mask_pred = mask_pred[range(N), labels][:, None]
Expand All @@ -308,15 +327,16 @@ class label c.
bboxes[inds],
img_h,
img_w,
skip_empty=device.type == 'cpu')
skip_empty=device.type == "cpu",
)

if threshold >= 0:
masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
else:
# for visualization and debugging
masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

im_mask[(inds, ) + spatial_inds] = masks_chunk
im_mask[(inds,) + spatial_inds] = masks_chunk

for i in range(N):
cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
Expand Down Expand Up @@ -353,13 +373,16 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
# this has more operations but is faster on COCO-scale datasets.
device = masks.device
if skip_empty:
x0_int, y0_int = torch.clamp(
boxes.min(dim=0).values.floor()[:2] - 1,
min=0).to(dtype=torch.int32)
x1_int = torch.clamp(
boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
y1_int = torch.clamp(
boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
dtype=torch.int32
)
x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w, min=1).to(
dtype=torch.int32
)
y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h, min=1).to(
dtype=torch.int32
)

else:
x0_int, y0_int = 0, 0
x1_int, y1_int = img_w, img_h
Expand All @@ -385,8 +408,7 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)

img_masks = F.grid_sample(
masks.to(dtype=torch.float32), grid, align_corners=False)
img_masks = F.grid_sample(masks.to(dtype=torch.float32), grid, align_corners=False)

if skip_empty:
return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
Expand Down