Skip to content
This repository has been archived by the owner on Jan 26, 2022. It is now read-only.

RetinaNet support #183

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions BENCHMARK.md
Original file line number Diff line number Diff line change
@@ -29,6 +29,72 @@
<tr><th align="left" bgcolor=#f8f8f8>ARl</th><td bgcolor=white> AR for large objects: area > 96<sup>2</sup></sup></td></tr>
</table></tbody>

## RetinaNet
### retinanet-R-50-FPN_1x

- Training command:

```
python tools/train_net_step.py \
--dataset coco2017 --cfg configs/baselines/retinanet_R-50-FPN_1x.yaml \
--bs 8 --iter_size 1 --use_tfboard
```
on four V100 GPUs.

<table><tbody>
<tr><th colspan="13" bgcolor=#f8f8f8>Box</th></tr>
<tr>
<th>source</th>
<th>AP50:95</th>
<th>AP50</th>
<th>AP75</th>
<th>APs</th>
<th>APm</th>
<th>APl</th>
<th>AR1</th>
<th>AR10</th>
<th>AR100</th>
<th>ARs</th>
<th>ARm</th>
<th>ARl</th>
</tr>
<tr>
<th bgcolor=white>PyTorch</th>
<td align="right" bgcolor=white>35.3</td>
<td align="right" bgcolor=white>54.6</td>
<td align="right" bgcolor=white>37.9</td>
<td align="right" bgcolor=white>19.4</td>
<td align="right" bgcolor=white>39.1</td>
<td align="right" bgcolor=white>47.5</td>
<td align="right" bgcolor=white>30.7</td>
<td align="right" bgcolor=white>48.9</td>
<td align="right" bgcolor=white>51.8</td>
<td align="right" bgcolor=white>32.4</td>
<td align="right" bgcolor=white>56.3</td>
<td align="right" bgcolor=white>67.4</td>
</tr>
<tr>
<th bgcolor=white>Detectron</th>
<td align="right", bgcolor=white>35.7</td>
<td align="right", bgcolor=white>54.7</td>
<td align="right", bgcolor=white>38.5</td>
<td align="right", bgcolor=white>19.5</td>
<td align="right", bgcolor=white>39.9</td>
<td align="right", bgcolor=white>47.5</td>
<td align="right", bgcolor=white>30.7</td>
<td align="right", bgcolor=white>49.1</td>
<td align="right", bgcolor=white>52.0</td>
<td align="right", bgcolor=white>32.0</td>
<td align="right", bgcolor=white>56.9</td>
<td align="right", bgcolor=white>68.0</td>
</tr>
</table></tbody>

- Total loss comparison:

![img](demo/loss_retinanet_R-50-FPN_1x.jpg)


## Faster-RCNN
### e2e_faster_rcnn-R-50-FPN_1x

41 changes: 41 additions & 0 deletions configs/baselines/retinanet_R-50-FPN_1x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
DEBUG: False
MODEL:
TYPE: retinanet
CONV_BODY: FPN.fpn_ResNet50_conv5_body
NUM_CLASSES: 81
RESNETS:
IMAGENET_PRETRAINED_WEIGHTS: 'data/pretrained_networks/resnet50_caffe.pth'
NUM_GPUS: 8
SOLVER:
WEIGHT_DECAY: 0.0001
LR_POLICY: steps_with_decay
BASE_LR: 0.01
GAMMA: 0.1
MAX_ITER: 90000
STEPS: [0, 60000, 80000]
FPN:
FPN_ON: True
MULTILEVEL_RPN: True
RPN_MAX_LEVEL: 7
RPN_MIN_LEVEL: 3
COARSEST_STRIDE: 128
EXTRA_CONV_LEVELS: True
RETINANET:
RETINANET_ON: True
NUM_CONVS: 4
ASPECT_RATIOS: (1.0, 2.0, 0.5)
SCALES_PER_OCTAVE: 3
ANCHOR_SCALE: 4
LOSS_GAMMA: 2.0
LOSS_ALPHA: 0.25
TRAIN:
SCALES: (800,)
MAX_SIZE: 1333
RPN_STRADDLE_THRESH: -1 # default 0
TEST:
SCALE: 800
MAX_SIZE: 1333
NMS: 0.5
RPN_PRE_NMS_TOP_N: 10000 # Per FPN level
RPN_POST_NMS_TOP_N: 2000
OUTPUT_DIR: .
Binary file added demo/loss_retinanet_R-50-FPN_1x.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions lib/core/test.py
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@
import utils.fpn as fpn_utils
import utils.image as image_utils
import utils.keypoints as keypoint_utils
import core.test_retinanet as test_retinanet


def im_detect_all(model, im, box_proposals=None, timers=None):
@@ -62,6 +63,13 @@ def im_detect_all(model, im, box_proposals=None, timers=None):
if timers is None:
timers = defaultdict(Timer)

# Handle RetinaNet testing separately for now
if cfg.RETINANET.RETINANET_ON:
timers['im_detect_bbox'].tic()
cls_boxes = test_retinanet.im_detect_bbox(model, im, timers)
timers['im_detect_bbox'].toc()
return cls_boxes, None, None

timers['im_detect_bbox'].tic()
if cfg.TEST.BBOX_AUG.ENABLED:
scores, boxes, im_scale, blob_conv = im_detect_bbox_aug(
196 changes: 196 additions & 0 deletions lib/core/test_retinanet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

"""Test a RetinaNet network on an image database"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import logging
from collections import defaultdict

from torch.autograd import Variable
import torch

from core.config import cfg
from modeling.generate_anchors import generate_anchors
from utils.timer import Timer
import utils.blob as blob_utils
import utils.boxes as box_utils
import roi_data.data_utils as data_utils

logger = logging.getLogger(__name__)


def _create_cell_anchors():
"""
Generate all types of anchors for all fpn levels/scales/aspect ratios.
This function is called only once at the beginning of inference.
"""
k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE
aspect_ratios = cfg.RETINANET.ASPECT_RATIOS
anchor_scale = cfg.RETINANET.ANCHOR_SCALE
A = scales_per_octave * len(aspect_ratios)

anchors = {}
for lvl in range(k_min, k_max + 1):
# create cell anchors array
stride = 2. ** lvl
cell_anchors = np.zeros((A, 4))
a = 0
for octave in range(scales_per_octave):
octave_scale = 2 ** (octave / float(scales_per_octave))
for aspect in aspect_ratios:
anchor_sizes = (stride * octave_scale * anchor_scale, )
anchor_aspect_ratios = (aspect, )
cell_anchors[a, :] = generate_anchors(
stride=stride, sizes=anchor_sizes,
aspect_ratios=anchor_aspect_ratios)
a += 1
anchors[lvl] = cell_anchors
return anchors


def im_detect_bbox(model, im, timers=None):
"""Generate RetinaNet detections on a single image."""
if timers is None:
timers = defaultdict(Timer)
# Although anchors are input independent and could be precomputed,
# recomputing them per image only brings a small overhead
anchors = _create_cell_anchors()
timers['im_detect_bbox'].tic()
k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL
A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)
inputs = {}
inputs['data'], im_scale, inputs['im_info'] = \
blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)

if cfg.PYTORCH_VERSION_LESS_THAN_040:
inputs['data'] = [
Variable(torch.from_numpy(inputs['data']), volatile=True)]
inputs['im_info'] = [
Variable(torch.from_numpy(inputs['im_info']), volatile=True)]
else:
inputs['data'] = [torch.from_numpy(inputs['data'])]
inputs['im_info'] = [torch.from_numpy(inputs['im_info'])]

return_dict = model(**inputs)
cls_probs = return_dict['cls_score']
box_preds = return_dict['bbox_pred']

# here the boxes_all are [x0, y0, x1, y1, score]
boxes_all = defaultdict(list)

cnt = 0
for lvl in range(k_min, k_max + 1):
# create cell anchors array
stride = 2. ** lvl
cell_anchors = anchors[lvl]

# fetch per level probability
cls_prob = cls_probs[cnt].data.cpu().numpy()
box_pred = box_preds[cnt].data.cpu().numpy()
cls_prob = cls_prob.reshape((
cls_prob.shape[0], A, int(cls_prob.shape[1] / A),
cls_prob.shape[2], cls_prob.shape[3]))
box_pred = box_pred.reshape((
box_pred.shape[0], A, 4, box_pred.shape[2], box_pred.shape[3]))
cnt += 1

if cfg.RETINANET.SOFTMAX:
cls_prob = cls_prob[:, :, 1::, :, :]

cls_prob_ravel = cls_prob.ravel()
# In some cases [especially for very small img sizes], it's possible that
# candidate_ind is empty if we impose threshold 0.05 at all levels. This
# will lead to errors since no detections are found for this image. Hence,
# for lvl 7 which has small spatial resolution, we take the threshold 0.0
th = cfg.RETINANET.INFERENCE_TH if lvl < k_max else 0.0
candidate_inds = np.where(cls_prob_ravel > th)[0]
if (len(candidate_inds) == 0):
continue

pre_nms_topn = min(cfg.RETINANET.PRE_NMS_TOP_N, len(candidate_inds))
inds = np.argpartition(
cls_prob_ravel[candidate_inds], -pre_nms_topn)[-pre_nms_topn:]
inds = candidate_inds[inds]

inds_5d = np.array(np.unravel_index(inds, cls_prob.shape)).transpose()
classes = inds_5d[:, 2]
anchor_ids, y, x = inds_5d[:, 1], inds_5d[:, 3], inds_5d[:, 4]
scores = cls_prob[:, anchor_ids, classes, y, x]

boxes = np.column_stack((x, y, x, y)).astype(dtype=np.float32)
boxes *= stride
boxes += cell_anchors[anchor_ids, :]

if not cfg.RETINANET.CLASS_SPECIFIC_BBOX:
box_deltas = box_pred[0, anchor_ids, :, y, x]
else:
box_cls_inds = classes * 4
box_deltas = np.vstack(
[box_pred[0, ind:ind + 4, yi, xi]
for ind, yi, xi in zip(box_cls_inds, y, x)]
)
pred_boxes = (
box_utils.bbox_transform(boxes, box_deltas)
if cfg.TEST.BBOX_REG else boxes)
pred_boxes /= im_scale
pred_boxes = box_utils.clip_tiled_boxes(pred_boxes, im.shape)
box_scores = np.zeros((pred_boxes.shape[0], 5))
box_scores[:, 0:4] = pred_boxes
box_scores[:, 4] = scores

for cls in range(1, cfg.MODEL.NUM_CLASSES):
inds = np.where(classes == cls - 1)[0]
if len(inds) > 0:
boxes_all[cls].extend(box_scores[inds, :])
timers['im_detect_bbox'].toc()

# Combine predictions across all levels and retain the top scoring by class
timers['misc_bbox'].tic()
detections = []
for cls, boxes in boxes_all.items():
cls_dets = np.vstack(boxes).astype(dtype=np.float32)
# do class specific nms here
keep = box_utils.nms(cls_dets, cfg.TEST.NMS)
cls_dets = cls_dets[keep, :]
out = np.zeros((len(keep), 6))
out[:, 0:5] = cls_dets
out[:, 5].fill(cls)
detections.append(out)

# detections (N, 6) format:
# detections[:, :4] - boxes
# detections[:, 4] - scores
# detections[:, 5] - classes
detections = np.vstack(detections)
# sort all again
inds = np.argsort(-detections[:, 4])
detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :]

# Convert the detections to image cls_ format (see core/test_engine.py)
num_classes = cfg.MODEL.NUM_CLASSES
cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
for c in range(1, num_classes):
inds = np.where(detections[:, 5] == c)[0]
cls_boxes[c] = detections[inds, :5]
timers['misc_bbox'].toc()

return cls_boxes
10 changes: 5 additions & 5 deletions lib/modeling/FPN.py
Original file line number Diff line number Diff line change
@@ -140,7 +140,7 @@ def __init__(self, conv_body_func, fpn_level_info, P2only=False):
self.extra_pyramid_modules = nn.ModuleList()
dim_in = fpn_level_info.dims[0]
for i in range(HIGHEST_BACKBONE_LVL + 1, max_level + 1):
self.extra_pyramid_modules(
self.extra_pyramid_modules.append(
nn.Conv2d(dim_in, fpn_dim, 3, 2, 1)
)
dim_in = fpn_dim
@@ -214,7 +214,7 @@ def detectron_weight_mapping(self):
})

if hasattr(self, 'extra_pyramid_modules'):
for i in len(self.extra_pyramid_modules):
for i in range(len(self.extra_pyramid_modules)):
p_prefix = 'extra_pyramid_modules.%d' % i
d_prefix = 'fpn_%d' % (HIGHEST_BACKBONE_LVL + 1 + i)
mapping_to_detectron.update({
@@ -246,9 +246,9 @@ def forward(self, x):

if hasattr(self, 'extra_pyramid_modules'):
blob_in = conv_body_blobs[-1]
fpn_output_blobs.insert(0, self.extra_pyramid_modules(blob_in))
fpn_output_blobs.insert(0, self.extra_pyramid_modules[0](blob_in))
for module in self.extra_pyramid_modules[1:]:
fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0], inplace=True)))
fpn_output_blobs.insert(0, module(F.relu(fpn_output_blobs[0])))

if self.P2only:
# use only the finest level
@@ -294,7 +294,7 @@ def forward(self, top_blob, lateral_blob):
lat = self.conv_lateral(lateral_blob)
# Top-down 2x upsampling
# td = F.upsample(top_blob, size=lat.size()[2:], mode='bilinear')
td = F.upsample(top_blob, scale_factor=2, mode='nearest')
td = F.interpolate(top_blob, scale_factor=2, mode='nearest')
# Sum lateral and top-down
return lat + td

Loading