From f6d5bd05c29e93f178fc37b42030e5a7013f980b Mon Sep 17 00:00:00 2001
From: ggjy <348797244@qq.com>
Date: Sun, 4 Jun 2023 22:18:09 +0800
Subject: [PATCH] release ckpt & detection code
---
README.md | 21 +--
models/vanillanet.py | 24 ++-
object_detection/README.md | 36 ++++
...nn_vanillanet_13_mstrain_480-1024_adamw.py | 110 +++++++++++
...et_vanillanet_13_mstrain_480-1024_adamw.py | 111 +++++++++++
.../layer_decay_optimizer_constructor.py | 173 ++++++++++++++++++
.../mmdet/models/backbones/__init__.py | 20 ++
.../mmdet/models/backbones/vanillanet.py | 114 ++++++++++++
8 files changed, 586 insertions(+), 23 deletions(-)
create mode 100644 object_detection/README.md
create mode 100644 object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py
create mode 100644 object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py
create mode 100644 object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py
create mode 100644 object_detection/mmdet/models/backbones/__init__.py
create mode 100644 object_detection/mmdet/models/backbones/vanillanet.py
diff --git a/README.md b/README.md
index 354f849..8916f3d 100644
--- a/README.md
+++ b/README.md
@@ -46,13 +46,8 @@ VanillaNet achieves comparable performance to prevalent computer vision foundati
| **VanillaNet-13** | 58.6 | 11.9 | 4.26 |1.33|0.82|0.67| 82.05 |
## Downstream Tasks
-| Framework | Backbone | FLOPs(G) | #params(M) | FPS | APb | APm |
-|:---:|:---:|:---:|:---:| :---:|:---:|:---:|
-| RetinaNet | Swin-T | 245 | 38.5 | 27.5 | 41.5 | - |
-| | VanillaNet-13 | 397 | 74.6 | 29.8 | 41.8 | - |
-| Mask RCNN | [Swin-T](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | 267 | 47.8 | 28.2 | 42.7 | 39.3 |
-| | VanillaNet-13 | 421 | 76.3 | 32.6 | 42.9 | 39.6 |
+Please refer to [object_detection](https://github.com/huawei-noah/VanillaNet/tree/main/object_detection).
VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segmentation** tasks.
@@ -63,8 +58,8 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm
- [x] ImageNet-1K Training Code of VanillaNet-5 to VanillaNet-10
- [x] ImageNet-1K Pretrained Weights of VanillaNet-5 to VanillaNet-10
- [ ] ImageNet-1K Training Code of VanillaNet-11 to VanillaNet-13
-- [ ] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13
-- [ ] Downstream Transfer (Detection, Segmentation) Code
+- [x] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13
+- [x] Downstream Transfer (Detection, Segmentation) Code
## Results and Pre-trained Models
### ImageNet-1K trained models
@@ -77,11 +72,11 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm
| VanillaNet-8 | 37.1 | 7.7 | 2.56 | 79.13 | [model](https://drive.google.com/file/d/1XNhe2LcNMjNZqBysGNvZLSsbKTrqWTw7/view?usp=sharing) |
| VanillaNet-9 | 41.4 | 8.6 | 2.91 | 79.87 | [model](https://drive.google.com/file/d/1DKifDZR5FqrEr7ICLPzuniQzu03hnnF_/view?usp=sharing) |
| VanillaNet-10 | 45.7 | 9.4 | 3.24 | 80.57 | [model](https://drive.google.com/file/d/1JskZU6otH_6NVXJHNe-74pEaZRVPxXlP/view?usp=sharing) |
-| VanillaNet-11 | 50.0 | 10.3 | 3.59 | 81.08 | - |
-| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | - |
-| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | - |
-| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | - |
-| VanillaNet-13-1.5x† | 127.8 | 48.9 | 9.72 | 83.11 | - |
+| VanillaNet-11 | 50.0 | 10.3 | 3.59 | 81.08 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_11.pth) |
+| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_12.pth) |
+| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13.pth) |
+| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_x1_5.pth) |
+| VanillaNet-13-1.5x† | 127.8 | 48.9 | 9.72 | 83.11 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_x1_5_ada_pool.pth) |
## Installation
diff --git a/models/vanillanet.py b/models/vanillanet.py
index 6eef1ed..c6417a3 100644
--- a/models/vanillanet.py
+++ b/models/vanillanet.py
@@ -14,19 +14,22 @@
class activation(nn.ReLU):
def __init__(self, dim, act_num=3, deploy=False):
super(activation, self).__init__()
+ self.act_num = act_num
self.deploy = deploy
- self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1))
- self.bias = None
- self.bn = nn.BatchNorm2d(dim, eps=1e-6)
self.dim = dim
- self.act_num = act_num
+ self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1))
+ if deploy:
+ self.bias = torch.nn.Parameter(torch.zeros(dim))
+ else:
+ self.bias = None
+ self.bn = nn.BatchNorm2d(dim, eps=1e-6)
weight_init.trunc_normal_(self.weight, std=.02)
def forward(self, x):
if self.deploy:
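+            # (act_num*2 + 1)//2 == act_num, so the depthwise conv below uses
+            # "same" padding and keeps the spatial size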
return torch.nn.functional.conv2d(
super(activation, self).forward(x),
- self.weight, self.bias, padding=(self.act_num*2 + 1)//2, groups=self.dim)
+ self.weight, self.bias, padding=self.act_num, groups=self.dim)
else:
return self.bn(torch.nn.functional.conv2d(
super(activation, self).forward(x),
@@ -74,7 +77,7 @@ def __init__(self, dim, dim_out, act_num=3, stride=2, deploy=False, ada_pool=Non
else:
self.pool = nn.Identity() if stride == 1 else nn.AdaptiveMaxPool2d((ada_pool, ada_pool))
- self.act = activation(dim_out, act_num)
+ self.act = activation(dim_out, act_num, deploy=self.deploy)
def forward(self, x):
if self.deploy:
@@ -120,14 +123,15 @@ def __init__(self, in_chans=3, num_classes=1000, dims=[96, 192, 384, 768],
drop_rate=0, act_num=3, strides=[2,2,2,1], deploy=False, ada_pool=None, **kwargs):
super().__init__()
self.deploy = deploy
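+    # ada_pool variants use a stride-3, padding-1 stem; the default is a stride-4 patchify stem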
+ stride, padding = (4, 0) if not ada_pool else (3, 1)
if self.deploy:
self.stem = nn.Sequential(
- nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
- activation(dims[0], act_num)
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=stride, padding=padding),
+ activation(dims[0], act_num, deploy=self.deploy)
)
else:
self.stem1 = nn.Sequential(
- nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=stride, padding=padding),
nn.BatchNorm2d(dims[0], eps=1e-6),
)
self.stem2 = nn.Sequential(
@@ -304,6 +308,6 @@ def vanillanet_13_x1_5_ada_pool(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128*6, 128*6, 256*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 1024*6, 1024*6],
strides=[1,2,2,1,1,1,1,1,1,2,1],
- ada_pool=[0,40,20,0,0,0,0,0,0,10,0],
+ ada_pool=[0,38,19,0,0,0,0,0,0,10,0],
**kwargs)
return model
diff --git a/object_detection/README.md b/object_detection/README.md
new file mode 100644
index 0000000..9cf72b0
--- /dev/null
+++ b/object_detection/README.md
@@ -0,0 +1,36 @@
+# COCO Object detection with VanillaNet
+
+## Getting started
+
+We add the VanillaNet model and config files on top of [mmdetection-2.x](https://github.com/open-mmlab/mmdetection/tree/2.x). Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/2.x/docs/en/get_started.md) for mmdetection installation and dataset preparation instructions.
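+
+As a quick sanity check that the backbone is registered, the following minimal sketch builds it directly (an illustrative assumption: `nn.BatchNorm2d` is substituted for the default `nn.SyncBatchNorm` so the check runs in a single, non-distributed process):
+```python
+import torch
+import torch.nn as nn
+from mmdet.models import build_backbone
+
+backbone = build_backbone(dict(
+    type='Vanillanet',
+    act_num=4,
+    dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
+    out_indices=[0, 1, 8, 10],
+    strides=[1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1],
+    norm_layer=nn.BatchNorm2d)).eval()
+
+with torch.no_grad():
+    outs = backbone(torch.randn(1, 3, 224, 224))
+print([tuple(o.shape) for o in outs])  # feature maps at strides 4, 8, 16, 32
+```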
+
+## Results and Fine-tuned Models
+
+| Framework | Backbone | FLOPs(G) | Params(M) | FPS | APb | APm | Model |
+|:---:|:---:|:---:|:---:| :---:|:---:|:---:|:---:|
+| RetinaNet | Swin-T | 244.8 | 38.5 | 27.5 | 41.5 | - |-|
+| | VanillaNet-13 | 396.9 | 75.4 | 29.8 | 43.0 | - | [log](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/retinanet_vanillanet_13.log.json)/[model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/retinanet_vanillanet_13.pth) |
+| Mask R-CNN | [Swin-T](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/master) | 263.8 | 47.8 | 28.2 | 43.7 | 39.8 |-|
+| | ConvNeXtV2-Nano | 220.6 | 35.2 | 34.4 | 43.3 | 39.4 |-|
+| | VanillaNet-13 | 420.7 | 77.1 | 32.6 | 44.3 | 40.1 | [log](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/mask_rcnn_vanillanet_13.log.json)/[model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/mask_rcnn_vanillanet_13.pth) |
+
+
+### Training
+
+You can download the ImageNet pre-trained [checkpoint](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth) for VanillaNet-13 (act_num=4), which is trained via [knowledge distillation](https://arxiv.org/pdf/2305.15781.pdf).
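+
+The backbone's `init_weights` loads the `model_ema` entry of this checkpoint with `strict=False` (see `mmdet/models/backbones/vanillanet.py`). A quick way to inspect the file, assuming it has been downloaded locally:
+```python
+import torch
+
+ckpt = torch.load('vanillanet_13_act_num_4_kd_pretrain.pth', map_location='cpu')
+print(ckpt.keys())  # expect a 'model_ema' entry holding the backbone weights
+```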
+
+For example, to train a Mask R-CNN model with a VanillaNet-13 backbone on 8 GPUs, run:
+```
+python -m torch.distributed.launch --nproc_per_node=8 tools/train.py configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py --launcher pytorch --work-dir /your_path_to/work_dir
+```
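+
+Both configs use layer-wise learning-rate decay via `LearningRateDecayOptimizerConstructor` (see `mmdet/core/optimizers/layer_decay_optimizer_constructor.py`). A minimal sketch of the schedule it produces, using the Mask R-CNN config's values (`lr=1.3e-4`, `decay_rate=0.6`, `num_layers=6`; the constructor adds 2 internally for the stem and the non-backbone parameters):
+```python
+base_lr, decay_rate, num_layers = 1.3e-4, 0.6, 6 + 2
+
+for layer_id in range(num_layers):
+    scale = decay_rate ** (num_layers - layer_id - 1)
+    print(f'layer {layer_id}: lr_scale = {scale:.4f}, lr = {base_lr * scale:.2e}')
+```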
+
+### Inference
+
+For example, to test the Mask R-CNN model on a single GPU, run:
+```
+python -m torch.distributed.launch --nproc_per_node=1 tools/test.py configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py /your_path_to/mask_rcnn_vanillanet_13.pth --launcher pytorch --eval bbox segm
+```
+
+## Acknowledgment
+
+This code is built on top of the [mmdetection](https://github.com/open-mmlab/mmdetection) and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repositories.
\ No newline at end of file
diff --git a/object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py b/object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py
new file mode 100644
index 0000000..d8cbd10
--- /dev/null
+++ b/object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py
@@ -0,0 +1,110 @@
+#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
+
+#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License.
+
+#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details.
+
+
+_base_ = [
+ '../_base_/models/mask_rcnn_r50_fpn.py',
+ '../_base_/schedules/schedule_1x.py',
+ '../_base_/default_runtime.py'
+]
+
+# you can download ckpt from:
+# https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth
+checkpoint_file = '/your_path_to/vanillanet_13_act_num_4_kd_pretrain.pth'
+
+model = dict(
+ backbone=dict(
+ _delete_=True,
+ type='Vanillanet',
+ act_num=4, # enlarge act_num for better downstream performance
+ dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
+ out_indices=[0, 1, 8, 10],
+ strides=[1,2,2,1,1,1,1,1,1,2,1],
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
+ neck=dict(in_channels=[128*4, 256*4, 512*4, 1024*4]))
+
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(
+ type='Resize',
+ img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), (1333, 608), (1333, 640), (1333, 672), (1333, 704), (1333, 736), (1333, 768), (1333, 800), (1333, 832), (1333, 864), (1333, 896), (1333, 928), (1333, 960), (1333, 992), (1333, 1024)],
+ multiscale_mode='value',
+ keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    persistent_workers=True,
+    train=dict(
+        type=dataset_type,
+        ann_file=data_root + 'annotations/instances_train2017.json',
+        img_prefix=data_root + 'train2017/',
+        pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
+evaluation = dict(metric=['bbox', 'segm'])
+
+optimizer = dict(
+ _delete_=True,
+ constructor='LearningRateDecayOptimizerConstructor',
+ type='AdamW',
+ lr=1.3e-4,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.6,
+ 'decay_type': 'layer_wise',
+ 'num_layers': 6
+ })
+
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[10, 12])
+runner = dict(max_epochs=12)
+
+log_config = dict(
+ interval=200,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
diff --git a/object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py b/object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py
new file mode 100644
index 0000000..4431596
--- /dev/null
+++ b/object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py
@@ -0,0 +1,111 @@
+#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
+
+#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License.
+
+#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details.
+
+
+_base_ = [
+ '../_base_/models/retinanet_r50_fpn.py',
+ '../_base_/datasets/coco_detection.py',
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+# you can download ckpt from:
+# https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth
+checkpoint_file = '/your_path_to/vanillanet_13_act_num_4_kd_pretrain.pth'
+
+model = dict(
+ backbone=dict(
+ _delete_=True,
+ type='Vanillanet',
+ act_num=4, # enlarge act_num for better downstream performance
+ dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
+ out_indices=[1, 8, 10],
+ strides=[1,2,2,1,1,1,1,1,1,2,1],
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
+ neck=dict(in_channels=[256*4, 512*4, 1024*4], start_level=0, num_outs=5))
+
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='Resize',
+ img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), (1333, 608), (1333, 640), (1333, 672), (1333, 704), (1333, 736), (1333, 768), (1333, 800), (1333, 832), (1333, 864), (1333, 896), (1333, 928), (1333, 960), (1333, 992), (1333, 1024)],
+ multiscale_mode='value',
+ keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_train2017.json',
+ img_prefix=data_root + 'train2017/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
+evaluation = dict(interval=1, metric='bbox')
+
+optimizer = dict(
+ _delete_=True,
+ constructor='LearningRateDecayOptimizerConstructor',
+ type='AdamW',
+ lr=8e-5,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg={
+ 'decay_rate': 0.6,
+ 'decay_type': 'layer_wise',
+ 'num_layers': 6
+ })
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[10, 11])
+runner = dict(type='EpochBasedRunner', max_epochs=12)
+
+log_config = dict(
+ interval=200,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
diff --git a/object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py b/object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py
new file mode 100644
index 0000000..5a122b0
--- /dev/null
+++ b/object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py
@@ -0,0 +1,173 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Modified from mmdet.
+# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
+
+
+import json
+
+from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
+
+from mmdet.utils import get_root_logger
+from .builder import OPTIMIZER_BUILDERS
+
+
+def get_layer_id_for_vanillanet(var_name, max_layer_id):
+    """Get the layer id to set the different learning rates in ``layer_wise``
+    decay_type.
+
+    Args:
+        var_name (str): The key of the model.
+        max_layer_id (int): Maximum layer id.
+
+    Returns:
+        int: The id number corresponding to different learning rate in
+            ``LearningRateDecayOptimizerConstructor``.
+    """
+    if var_name.startswith("backbone.stem"):
+        return 0
+    elif var_name.startswith("backbone.stages"):
+        # two consecutive VanillaNet stages share one decay layer
+        stage_id = int(var_name.split('.')[2])
+        return stage_id // 2 + 1
+    else:
+        return max_layer_id + 1
+
+
+def get_layer_id_for_convnext(var_name, max_layer_id):
+ """Get the layer id to set the different learning rates in ``layer_wise``
+ decay_type.
+
+ Args:
+ var_name (str): The key of the model.
+ max_layer_id (int): Maximum layer id.
+
+ Returns:
+ int: The id number corresponding to different learning rate in
+ ``LearningRateDecayOptimizerConstructor``.
+ """
+
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.downsample_layers'):
+ stage_id = int(var_name.split('.')[2])
+ if stage_id == 0:
+ layer_id = 0
+ elif stage_id == 1:
+ layer_id = 2
+ elif stage_id == 2:
+ layer_id = 3
+ elif stage_id == 3:
+ layer_id = max_layer_id
+ return layer_id
+ elif var_name.startswith('backbone.stages'):
+ stage_id = int(var_name.split('.')[2])
+ block_id = int(var_name.split('.')[3])
+ if stage_id == 0:
+ layer_id = 1
+ elif stage_id == 1:
+ layer_id = 2
+ elif stage_id == 2:
+ layer_id = 3 + block_id // 3
+ elif stage_id == 3:
+ layer_id = max_layer_id
+ return layer_id
+ else:
+ return max_layer_id + 1
+
+
+def get_stage_id_for_convnext(var_name, max_stage_id):
+ """Get the stage id to set the different learning rates in ``stage_wise``
+ decay_type.
+
+ Args:
+ var_name (str): The key of the model.
+ max_stage_id (int): Maximum stage id.
+
+ Returns:
+ int: The id number corresponding to different learning rate in
+ ``LearningRateDecayOptimizerConstructor``.
+ """
+
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.downsample_layers'):
+ return 0
+ elif var_name.startswith('backbone.stages'):
+ stage_id = int(var_name.split('.')[2])
+ return stage_id + 1
+ else:
+ return max_stage_id - 1
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
+    # Different learning rates are set for different layers of backbone.
+    # Note: Currently, this optimizer constructor is built for ConvNeXt
+    # and VanillaNet backbones.
+
+ def add_params(self, params, module, **kwargs):
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ """
+ logger = get_root_logger()
+
+ parameter_groups = {}
+ logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
+ num_layers = self.paramwise_cfg.get('num_layers') + 2
+ decay_rate = self.paramwise_cfg.get('decay_rate')
+ decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
+ logger.info('Build LearningRateDecayOptimizerConstructor '
+ f'{decay_type} {decay_rate} - {num_layers}')
+ weight_decay = self.base_wd
+ for name, param in module.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith('.bias') or name in (
+ 'pos_embed', 'cls_token'):
+ group_name = 'no_decay'
+ this_weight_decay = 0.
+ else:
+ group_name = 'decay'
+ this_weight_decay = weight_decay
+ if 'layer_wise' in decay_type:
+ if 'ConvNeXt' in module.backbone.__class__.__name__:
+ layer_id = get_layer_id_for_convnext(
+ name, self.paramwise_cfg.get('num_layers'))
+ logger.info(f'set param {name} as id {layer_id}')
+ elif 'Vanillanet' in module.backbone.__class__.__name__:
+ layer_id = get_layer_id_for_vanillanet(
+ name, self.paramwise_cfg.get('num_layers'))
+ logger.info(f'set param {name} as id {layer_id}')
+ else:
+ raise NotImplementedError()
+ elif decay_type == 'stage_wise':
+ if 'ConvNeXt' in module.backbone.__class__.__name__:
+ layer_id = get_stage_id_for_convnext(name, num_layers)
+ logger.info(f'set param {name} as id {layer_id}')
+ else:
+ raise NotImplementedError()
+ group_name = f'layer_{layer_id}_{group_name}'
+
+ if group_name not in parameter_groups:
+ scale = decay_rate**(num_layers - layer_id - 1)
+
+ parameter_groups[group_name] = {
+ 'weight_decay': this_weight_decay,
+ 'params': [],
+ 'param_names': [],
+ 'lr_scale': scale,
+ 'group_name': group_name,
+ 'lr': scale * self.base_lr,
+ }
+
+ parameter_groups[group_name]['params'].append(param)
+ parameter_groups[group_name]['param_names'].append(name)
+ rank, _ = get_dist_info()
+ if rank == 0:
+ to_display = {}
+ for key in parameter_groups:
+ to_display[key] = {
+ 'param_names': parameter_groups[key]['param_names'],
+ 'lr_scale': parameter_groups[key]['lr_scale'],
+ 'lr': parameter_groups[key]['lr'],
+ 'weight_decay': parameter_groups[key]['weight_decay'],
+ }
+ logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
+ params.extend(parameter_groups.values())
diff --git a/object_detection/mmdet/models/backbones/__init__.py b/object_detection/mmdet/models/backbones/__init__.py
new file mode 100644
index 0000000..2ddaa79
--- /dev/null
+++ b/object_detection/mmdet/models/backbones/__init__.py
@@ -0,0 +1,20 @@
+from .darknet import Darknet
+from .detectors_resnet import DetectoRS_ResNet
+from .detectors_resnext import DetectoRS_ResNeXt
+from .hourglass import HourglassNet
+from .hrnet import HRNet
+from .regnet import RegNet
+from .res2net import Res2Net
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1d
+from .resnext import ResNeXt
+from .ssd_vgg import SSDVGG
+from .trident_resnet import TridentResNet
+from .swin_transformer import SwinTransformer
+from .vanillanet import Vanillanet
+
+__all__ = [
+ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
+ 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet',
+ 'ResNeSt', 'TridentResNet', 'SwinTransformer', 'Vanillanet'
+]
diff --git a/object_detection/mmdet/models/backbones/vanillanet.py b/object_detection/mmdet/models/backbones/vanillanet.py
new file mode 100644
index 0000000..f9c0950
--- /dev/null
+++ b/object_detection/mmdet/models/backbones/vanillanet.py
@@ -0,0 +1,114 @@
+#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.
+
+#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License.
+
+#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details.
+
+
+import math
+
+import torch
+import torch.nn as nn
+
+from mmcv.cnn import normal_init, trunc_normal_init
+from mmcv.cnn.utils.weight_init import trunc_normal_
+from mmcv.runner import BaseModule
+
+from ...utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class activation(nn.ReLU):
+ def __init__(self, dim, act_num=3, norm_layer=nn.SyncBatchNorm):
+ super(activation, self).__init__()
+ self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1))
+ self.bn = norm_layer(dim, eps=1e-6)
+ self.dim = dim
+ self.act_num = act_num
+ trunc_normal_(self.weight, std=.02)
+
+ def forward(self, x):
+ return self.bn(torch.nn.functional.conv2d(super(activation, self).forward(x), self.weight, padding=self.act_num, groups=self.dim))
+
+class Block(nn.Module):
+ def __init__(self, dim, dim_out, act_num=3, stride=2, norm_layer=nn.SyncBatchNorm):
+ super().__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(dim, dim, kernel_size=1),
+ norm_layer(dim, eps=1e-6),
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(dim, dim_out, kernel_size=1),
+ norm_layer(dim_out, eps=1e-6)
+ )
+ self.pool = nn.Identity() if stride == 1 else nn.MaxPool2d(stride)
+ self.act = activation(dim_out, act_num, norm_layer)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.pool(x)
+ x = self.act(x)
+ return x
+
+
+@BACKBONES.register_module()
+class Vanillanet(BaseModule):
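+    """VanillaNet backbone for detection.
+
+    Feature maps from the stages listed in ``out_indices`` are returned.
+    With the VanillaNet-13 settings in the provided configs (11 stages,
+    ``strides=[1,2,2,1,1,1,1,1,1,2,1]``), indices [0, 1, 8, 10] yield
+    strides 4, 8, 16 and 32 relative to the input.
+    """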
+ def __init__(self, in_chans=3, act_num=3, dims=[96, 192, 384, 768], out_indices=[2,4,6],
+ strides=[2,2,2,1], norm_layer=nn.SyncBatchNorm, init_cfg=None, **kwargs):
+ super().__init__()
+
+ self.out_indices = out_indices
+ self.init_cfg = init_cfg
+
+ self.stem1 = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ norm_layer(dims[0], eps=1e-6),
+ )
+ self.stem2 = nn.Sequential(
+ nn.Conv2d(dims[0], dims[0], kernel_size=1, stride=1),
+ norm_layer(dims[0], eps=1e-6),
+ activation(dims[0], act_num, norm_layer)
+ )
+
+ self.stages = nn.ModuleList()
+ for i in range(len(strides)):
+ stage = Block(dim=dims[i], dim_out=dims[i+1], act_num=act_num, stride=strides[i], norm_layer=norm_layer)
+ self.stages.append(stage)
+ self.depth = len(strides)
+
+ def init_weights(self):
+ if self.init_cfg is None:
+ logger = get_root_logger()
+            logger.warning(f'No pre-trained weights for '
+                           f'{self.__class__.__name__}, '
+                           f'training starts from scratch')
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[
+ 1] * m.out_channels
+ fan_out //= m.groups
+ normal_init(m, 0, math.sqrt(2.0 / fan_out))
+ else:
+            state_dict = torch.load(self.init_cfg.checkpoint, map_location='cpu')
+            msg = self.load_state_dict(state_dict['model_ema'], strict=False)
+            logger = get_root_logger()
+            logger.info(msg)
+            logger.info('Successfully loaded backbone checkpoint.')
+
+ def forward(self, x):
+ outs = []
+ x = self.stem1(x)
+ x = self.stem2(x)
+ for i in range(self.depth):
+ x = self.stages[i](x)
+
+ if i in self.out_indices:
+ outs.append(x)
+
+ return outs