From f6d5bd05c29e93f178fc37b42030e5a7013f980b Mon Sep 17 00:00:00 2001 From: ggjy <348797244@qq.com> Date: Sun, 4 Jun 2023 22:18:09 +0800 Subject: [PATCH] release ckpt & detection code --- README.md | 21 +-- models/vanillanet.py | 24 ++- object_detection/README.md | 36 ++++ ...nn_vanillanet_13_mstrain_480-1024_adamw.py | 110 +++++++++++ ...et_vanillanet_13_mstrain_480-1024_adamw.py | 111 +++++++++++ .../layer_decay_optimizer_constructor.py | 173 ++++++++++++++++++ .../mmdet/models/backbones/__init__.py | 20 ++ .../mmdet/models/backbones/vanillanet.py | 114 ++++++++++++ 8 files changed, 586 insertions(+), 23 deletions(-) create mode 100644 object_detection/README.md create mode 100644 object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py create mode 100644 object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py create mode 100644 object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py create mode 100644 object_detection/mmdet/models/backbones/__init__.py create mode 100644 object_detection/mmdet/models/backbones/vanillanet.py diff --git a/README.md b/README.md index 354f849..8916f3d 100644 --- a/README.md +++ b/README.md @@ -46,13 +46,8 @@ VanillaNet achieves comparable performance to prevalent computer vision foundati | **VanillaNet-13** | 58.6 | 11.9 | 4.26 |1.33|0.82|0.67| 82.05 | ## Downstream Tasks -| Framework | Backbone | FLOPs(G) | #params(M) | FPS | APb | APm | -|:---:|:---:|:---:|:---:| :---:|:---:|:---:| -| RetinaNet | Swin-T | 245 | 38.5 | 27.5 | 41.5 | - | -| | VanillaNet-13 | 397 | 74.6 | 29.8 | 41.8 | - | -| Mask RCNN | [Swin-T](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | 267 | 47.8 | 28.2 | 42.7 | 39.3 | -| | VanillaNet-13 | 421 | 76.3 | 32.6 | 42.9 | 39.6 | +Please refer to [this page](https://github.com/huawei-noah/VanillaNet/object_detection). VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segmentation** tasks. @@ -63,8 +58,8 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm - [x] ImageNet-1K Training Code of VanillaNet-5 to VanillaNet-10 - [x] ImageNet-1K Pretrained Weights of VanillaNet-5 to VanillaNet-10 - [ ] ImageNet-1K Training Code of VanillaNet-11 to VanillaNet-13 -- [ ] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13 -- [ ] Downstream Transfer (Detection, Segmentation) Code +- [x] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13 +- [x] Downstream Transfer (Detection, Segmentation) Code ## Results and Pre-trained Models ### ImageNet-1K trained models @@ -77,11 +72,11 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm | VanillaNet-8 | 37.1 | 7.7 | 2.56 | 79.13 | [model](https://drive.google.com/file/d/1XNhe2LcNMjNZqBysGNvZLSsbKTrqWTw7/view?usp=sharing) | | VanillaNet-9 | 41.4 | 8.6 | 2.91 | 79.87 | [model](https://drive.google.com/file/d/1DKifDZR5FqrEr7ICLPzuniQzu03hnnF_/view?usp=sharing) | | VanillaNet-10 | 45.7 | 9.4 | 3.24 | 80.57 | [model](https://drive.google.com/file/d/1JskZU6otH_6NVXJHNe-74pEaZRVPxXlP/view?usp=sharing) | -| VanillaNet-11 | 50.0 | 10.3 | 3.59 | 81.08 | - | -| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | - | -| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | - | -| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | - | -| VanillaNet-13-1.5x† | 127.8 | 48.9 | 9.72 | 83.11 | - | +| VanillaNet-11 | 50.0 | 10.3 | 3.59 | 81.08 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_11.pth) | +| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_12.pth) | +| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13.pth) | +| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_x1_5.pth) | +| VanillaNet-13-1.5x† | 127.8 | 48.9 | 9.72 | 83.11 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_x1_5_ada_pool.pth) | ## Installation diff --git a/models/vanillanet.py b/models/vanillanet.py index 6eef1ed..c6417a3 100644 --- a/models/vanillanet.py +++ b/models/vanillanet.py @@ -14,19 +14,22 @@ class activation(nn.ReLU): def __init__(self, dim, act_num=3, deploy=False): super(activation, self).__init__() + self.act_num = act_num self.deploy = deploy - self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1)) - self.bias = None - self.bn = nn.BatchNorm2d(dim, eps=1e-6) self.dim = dim - self.act_num = act_num + self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1)) + if deploy: + self.bias = torch.nn.Parameter(torch.zeros(dim)) + else: + self.bias = None + self.bn = nn.BatchNorm2d(dim, eps=1e-6) weight_init.trunc_normal_(self.weight, std=.02) def forward(self, x): if self.deploy: return torch.nn.functional.conv2d( super(activation, self).forward(x), - self.weight, self.bias, padding=(self.act_num*2 + 1)//2, groups=self.dim) + self.weight, self.bias, padding=self.act_num, groups=self.dim) else: return self.bn(torch.nn.functional.conv2d( super(activation, self).forward(x), @@ -74,7 +77,7 @@ def __init__(self, dim, dim_out, act_num=3, stride=2, deploy=False, ada_pool=Non else: self.pool = nn.Identity() if stride == 1 else nn.AdaptiveMaxPool2d((ada_pool, ada_pool)) - self.act = activation(dim_out, act_num) + self.act = activation(dim_out, act_num, deploy=self.deploy) def forward(self, x): if self.deploy: @@ -120,14 +123,15 @@ def __init__(self, in_chans=3, num_classes=1000, dims=[96, 192, 384, 768], drop_rate=0, act_num=3, strides=[2,2,2,1], deploy=False, ada_pool=None, **kwargs): super().__init__() self.deploy = deploy + stride, padding = (4, 0) if not ada_pool else (3, 1) if self.deploy: self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), - activation(dims[0], act_num) + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=stride, padding=padding), + activation(dims[0], act_num, deploy=self.deploy) ) else: self.stem1 = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=stride, padding=padding), nn.BatchNorm2d(dims[0], eps=1e-6), ) self.stem2 = nn.Sequential( @@ -304,6 +308,6 @@ def vanillanet_13_x1_5_ada_pool(pretrained=False, in_22k=False, **kwargs): model = VanillaNet( dims=[128*6, 128*6, 256*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 1024*6, 1024*6], strides=[1,2,2,1,1,1,1,1,1,2,1], - ada_pool=[0,40,20,0,0,0,0,0,0,10,0], + ada_pool=[0,38,19,0,0,0,0,0,0,10,0], **kwargs) return model diff --git a/object_detection/README.md b/object_detection/README.md new file mode 100644 index 0000000..9cf72b0 --- /dev/null +++ b/object_detection/README.md @@ -0,0 +1,36 @@ +# COCO Object detection with VanillaNet + +## Getting started + +We add VanillaNet model and config files based on [mmdetection-2.x](https://github.com/open-mmlab/mmdetection/tree/2.x). Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/2.x/docs/en/get_started.md) for mmdetection installation and dataset preparation instructions. + +## Results and Fine-tuned Models + +| Framework | Backbone | FLOPs(G) | Params(M) | FPS | APb | APm | Model | +|:---:|:---:|:---:|:---:| :---:|:---:|:---:|:---:| +| RetinaNet | Swin-T | 244.8 | 38.5 | 27.5 | 41.5 | - |-| +| | VanillaNet-13 | 396.9 | 75.4 | 29.8 | 43.0 | - | [log](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/retinanet_vanillanet_13.log.json)/[model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/retinanet_vanillanet_13.pth) | +| Mask RCNN | [Swin-T](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/master) | 263.8 | 47.8 | 28.2 | 43.7 | 39.8 |-| +| | ConvNeXtV2-Nano | 220.6 | 35.2 | 34.4 | 43.3 | 39.4 |-| +| | VanillaNet-13 | 420.7 | 77.1 | 32.6 | 44.3 | 40.1 | [log](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/mask_rcnn_vanillanet_13.log.json)/[model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/mask_rcnn_vanillanet_13.pth) | + + +### Training + +You can download the ImageNet pre-trained [checkpoint](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth) for VanillaNet-13(act_num=4), which is trained via [knowledge distillation(this paper)](https://arxiv.org/pdf/2305.15781.pdf). + +For example, to train a Mask R-CNN model with VanillaNet backbone and 8 gpus, run: +``` +python -m torch.distributed.launch --nproc_per_node=8 tools/train.py configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py --gpus 8 --launcher pytorch --work-dir +``` + +### Inference + +For example, test with single-gpu, run: +``` +python -m torch.distributed.launch --nproc_per_node=1 tools/test.py configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py --launcher pytorch --eval bbox segm +``` + +## Acknowledgment + +This code is built based on [mmdetection](https://github.com/open-mmlab/mmdetection), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repositories. \ No newline at end of file diff --git a/object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py b/object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py new file mode 100644 index 0000000..d8cbd10 --- /dev/null +++ b/object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py @@ -0,0 +1,110 @@ +#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. + +#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License. + +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details. + + +_base_ = [ + '../_base_/models/mask_rcnn_r50_fpn.py', + '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' +] + +# you can download ckpt from: +# https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth +checkpoint_file = '/your_path_to/vanillanet_13_act_num_4_kd_pretrain.pth' + +model = dict( + backbone=dict( + _delete_=True, + type='Vanillanet', + act_num=4, # enlarge act_num for better downstream performance + dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4], + out_indices=[0, 1, 8, 10], + strides=[1,2,2,1,1,1,1,1,1,2,1], + init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)), + neck=dict(in_channels=[128*4, 256*4, 512*4, 1024*4])) + +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='Resize', + img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), (1333, 608), (1333, 640), (1333, 672), (1333, 704), (1333, 736), (1333, 768), (1333, 800), (1333, 832), (1333, 864), (1333, 896), (1333, 928), (1333, 960), (1333, 992), (1333, 1024)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + persistent_workers=True, + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +evaluation = dict(metric=['bbox', 'segm']) + +optimizer = dict( + _delete_=True, + constructor='LearningRateDecayOptimizerConstructor', + type='AdamW', + lr=1.3e-4, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg={ + 'decay_rate': 0.6, + 'decay_type': 'layer_wise', + 'num_layers': 6 + }) + +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[10, 12]) +runner = dict(max_epochs=12) + +log_config = dict( + interval=200, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) diff --git a/object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py b/object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py new file mode 100644 index 0000000..4431596 --- /dev/null +++ b/object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py @@ -0,0 +1,111 @@ +#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. + +#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License. + +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details. + + +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +# you can download ckpt from: +# https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth +checkpoint_file = '/your_path_to/vanillanet_13_act_num_4_kd_pretrain.pth' + +model = dict( + backbone=dict( + _delete_=True, + type='Vanillanet', + act_num=4, # enlarge act_num for better downstream performance + dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4], + out_indices=[1, 8, 10], + strides=[1,2,2,1,1,1,1,1,1,2,1], + init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)), + neck=dict(in_channels=[256*4, 512*4, 1024*4], start_level=0, num_outs=5)) + +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), (1333, 608), (1333, 640), (1333, 672), (1333, 704), (1333, 736), (1333, 768), (1333, 800), (1333, 832), (1333, 864), (1333, 896), (1333, 928), (1333, 960), (1333, 992), (1333, 1024)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +evaluation = dict(interval=1, metric='bbox') + +optimizer = dict( + _delete_=True, + constructor='LearningRateDecayOptimizerConstructor', + type='AdamW', + lr=8e-5, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg={ + 'decay_rate': 0.6, + 'decay_type': 'layer_wise', + 'num_layers': 6 + }) + +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[10, 11]) +runner = dict(type='EpochBasedRunner', max_epochs=12) + +log_config = dict( + interval=200, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) diff --git a/object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py b/object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py new file mode 100644 index 0000000..5a122b0 --- /dev/null +++ b/object_detection/mmdet/core/optimizers/layer_decay_optimizer_constructor.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# Modified from mmdet. +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. + + +import json + +from mmcv.runner import DefaultOptimizerConstructor, get_dist_info + +from mmdet.utils import get_root_logger +from .builder import OPTIMIZER_BUILDERS + + +def get_layer_id_for_vanillanet(var_name, max_layer_id): + if var_name.startswith("backbone.stem"): + return 0 + elif var_name.startswith("backbone.stages"): + stage_id = int(var_name.split('.')[2]) + return stage_id // 2 + 1 + else: + return max_layer_id + 1 + + +def get_layer_id_for_convnext(var_name, max_layer_id): + """Get the layer id to set the different learning rates in ``layer_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_layer_id (int): Maximum layer id. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + stage_id = int(var_name.split('.')[2]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + block_id = int(var_name.split('.')[3]) + if stage_id == 0: + layer_id = 1 + elif stage_id == 1: + layer_id = 2 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = max_layer_id + return layer_id + else: + return max_layer_id + 1 + + +def get_stage_id_for_convnext(var_name, max_stage_id): + """Get the stage id to set the different learning rates in ``stage_wise`` + decay_type. + + Args: + var_name (str): The key of the model. + max_stage_id (int): Maximum stage id. + + Returns: + int: The id number corresponding to different learning rate in + ``LearningRateDecayOptimizerConstructor``. + """ + + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.downsample_layers'): + return 0 + elif var_name.startswith('backbone.stages'): + stage_id = int(var_name.split('.')[2]) + return stage_id + 1 + else: + return max_stage_id - 1 + + +@OPTIMIZER_BUILDERS.register_module() +class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): + # Different learning rates are set for different layers of backbone. + # Note: Currently, this optimizer constructor is built for ConvNeXt. + + def add_params(self, params, module, **kwargs): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + """ + logger = get_root_logger() + + parameter_groups = {} + logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}') + num_layers = self.paramwise_cfg.get('num_layers') + 2 + decay_rate = self.paramwise_cfg.get('decay_rate') + decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') + logger.info('Build LearningRateDecayOptimizerConstructor ' + f'{decay_type} {decay_rate} - {num_layers}') + weight_decay = self.base_wd + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias') or name in ( + 'pos_embed', 'cls_token'): + group_name = 'no_decay' + this_weight_decay = 0. + else: + group_name = 'decay' + this_weight_decay = weight_decay + if 'layer_wise' in decay_type: + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_convnext( + name, self.paramwise_cfg.get('num_layers')) + logger.info(f'set param {name} as id {layer_id}') + elif 'Vanillanet' in module.backbone.__class__.__name__: + layer_id = get_layer_id_for_vanillanet( + name, self.paramwise_cfg.get('num_layers')) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + elif decay_type == 'stage_wise': + if 'ConvNeXt' in module.backbone.__class__.__name__: + layer_id = get_stage_id_for_convnext(name, num_layers) + logger.info(f'set param {name} as id {layer_id}') + else: + raise NotImplementedError() + group_name = f'layer_{layer_id}_{group_name}' + + if group_name not in parameter_groups: + scale = decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + logger.info(f'Param groups = {json.dumps(to_display, indent=2)}') + params.extend(parameter_groups.values()) diff --git a/object_detection/mmdet/models/backbones/__init__.py b/object_detection/mmdet/models/backbones/__init__.py new file mode 100644 index 0000000..2ddaa79 --- /dev/null +++ b/object_detection/mmdet/models/backbones/__init__.py @@ -0,0 +1,20 @@ +from .darknet import Darknet +from .detectors_resnet import DetectoRS_ResNet +from .detectors_resnext import DetectoRS_ResNeXt +from .hourglass import HourglassNet +from .hrnet import HRNet +from .regnet import RegNet +from .res2net import Res2Net +from .resnest import ResNeSt +from .resnet import ResNet, ResNetV1d +from .resnext import ResNeXt +from .ssd_vgg import SSDVGG +from .trident_resnet import TridentResNet +from .swin_transformer import SwinTransformer +from .vanillanet import Vanillanet + +__all__ = [ + 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net', + 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet', + 'ResNeSt', 'TridentResNet', 'SwinTransformer', 'Vanillanet' +] diff --git a/object_detection/mmdet/models/backbones/vanillanet.py b/object_detection/mmdet/models/backbones/vanillanet.py new file mode 100644 index 0000000..f9c0950 --- /dev/null +++ b/object_detection/mmdet/models/backbones/vanillanet.py @@ -0,0 +1,114 @@ +#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. + +#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License. + +#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details. + + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer, + constant_init, normal_init, trunc_normal_init) +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.utils.weight_init import trunc_normal_ +from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint, + load_state_dict) + +from ...utils import get_root_logger +from ..builder import BACKBONES + + +class activation(nn.ReLU): + def __init__(self, dim, act_num=3, norm_layer=nn.SyncBatchNorm): + super(activation, self).__init__() + self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1)) + self.bn = norm_layer(dim, eps=1e-6) + self.dim = dim + self.act_num = act_num + trunc_normal_(self.weight, std=.02) + + def forward(self, x): + return self.bn(torch.nn.functional.conv2d(super(activation, self).forward(x), self.weight, padding=self.act_num, groups=self.dim)) + +class Block(nn.Module): + def __init__(self, dim, dim_out, act_num=3, stride=2, norm_layer=nn.SyncBatchNorm): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=1), + norm_layer(dim, eps=1e-6), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(dim, dim_out, kernel_size=1), + norm_layer(dim_out, eps=1e-6) + ) + self.pool = nn.Identity() if stride == 1 else nn.MaxPool2d(stride) + self.act = activation(dim_out, act_num, norm_layer) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.pool(x) + x = self.act(x) + return x + + +@BACKBONES.register_module() +class Vanillanet(BaseModule): + def __init__(self, in_chans=3, act_num=3, dims=[96, 192, 384, 768], out_indices=[2,4,6], + strides=[2,2,2,1], norm_layer=nn.SyncBatchNorm, init_cfg=None, **kwargs): + super().__init__() + + self.out_indices = out_indices + self.init_cfg = init_cfg + + self.stem1 = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + norm_layer(dims[0], eps=1e-6), + ) + self.stem2 = nn.Sequential( + nn.Conv2d(dims[0], dims[0], kernel_size=1, stride=1), + norm_layer(dims[0], eps=1e-6), + activation(dims[0], act_num, norm_layer) + ) + + self.stages = nn.ModuleList() + for i in range(len(strides)): + stage = Block(dim=dims[i], dim_out=dims[i+1], act_num=act_num, stride=strides[i], norm_layer=norm_layer) + self.stages.append(stage) + self.depth = len(strides) + + def init_weights(self): + if self.init_cfg is None: + logger = get_root_logger() + logger.warn(f'No pre-trained weights for ' + f'{self.__class__.__name__}, ' + f'training start from scratch') + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init(m, 0, math.sqrt(2.0 / fan_out)) + else: + state_dict = torch.load(self.init_cfg.checkpoint, map_location='cpu') + msg = self.load_state_dict(state_dict['model_ema'], strict=False) + print(msg) + print('Successfully load backbone ckpt.') + + def forward(self, x): + outs = [] + x = self.stem1(x) + x = self.stem2(x) + for i in range(self.depth): + x = self.stages[i](x) + + if i in self.out_indices: + outs.append(x) + + return outs