release ckpt & detection code
ggjy committed Jun 4, 2023
1 parent 7b6b8ef commit f6d5bd0
Showing 8 changed files with 586 additions and 23 deletions.
21 changes: 8 additions & 13 deletions README.md
@@ -46,13 +46,8 @@ VanillaNet achieves comparable performance to prevalent computer vision foundation models
| **VanillaNet-13** | 58.6 | 11.9 | 4.26 |1.33|0.82|0.67| 82.05 |

## Downstream Tasks
| Framework | Backbone | FLOPs(G) | #params(M) | FPS | AP<sup>b</sup> | AP<sup>m</sup> |
|:---:|:---:|:---:|:---:| :---:|:---:|:---:|
| RetinaNet | Swin-T | 245 | 38.5 | 27.5 | 41.5 | - |
| | VanillaNet-13 | 397 | 74.6 | 29.8 | 41.8 | - |
| Mask R-CNN | [Swin-T](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | 267 | 47.8 | 28.2 | 42.7 | 39.3 |
| | VanillaNet-13 | 421 | 76.3 | 32.6 | 42.9 | 39.6 |

Please refer to [this page](https://github.com/huawei-noah/VanillaNet/tree/main/object_detection) for the detection and segmentation code.

VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segmentation** tasks.

@@ -63,8 +58,8 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segmentation** tasks.
- [x] ImageNet-1K Training Code of VanillaNet-5 to VanillaNet-10
- [x] ImageNet-1K Pretrained Weights of VanillaNet-5 to VanillaNet-10
- [ ] ImageNet-1K Training Code of VanillaNet-11 to VanillaNet-13
- [ ] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13
- [ ] Downstream Transfer (Detection, Segmentation) Code
- [x] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13
- [x] Downstream Transfer (Detection, Segmentation) Code

## Results and Pre-trained Models
### ImageNet-1K trained models
@@ -77,11 +72,11 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segmentation** tasks.
| VanillaNet-8 | 37.1 | 7.7 | 2.56 | 79.13 | [model](https://drive.google.com/file/d/1XNhe2LcNMjNZqBysGNvZLSsbKTrqWTw7/view?usp=sharing) |
| VanillaNet-9 | 41.4 | 8.6 | 2.91 | 79.87 | [model](https://drive.google.com/file/d/1DKifDZR5FqrEr7ICLPzuniQzu03hnnF_/view?usp=sharing) |
| VanillaNet-10 | 45.7 | 9.4 | 3.24 | 80.57 | [model](https://drive.google.com/file/d/1JskZU6otH_6NVXJHNe-74pEaZRVPxXlP/view?usp=sharing) |
| VanillaNet-11 | 50.0 | 10.3 | 3.59 | 81.08 | - |
| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | - |
| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | - |
| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | - |
| VanillaNet-13-1.5x&dagger; | 127.8 | 48.9 | 9.72 | 83.11 | - |
| VanillaNet-11 | 50.0 | 10.3 | 3.59 | 81.08 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_11.pth) |
| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_12.pth) |
| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13.pth) |
| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_x1_5.pth) |
| VanillaNet-13-1.5x&dagger; | 127.8 | 48.9 | 9.72 | 83.11 | [model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_x1_5_ada_pool.pth) |
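
The per-image latency figures in tables like the one above can be approximated with a bare timing loop. A hedged sketch (batch size 1 and CUDA timing are assumptions; this excerpt does not state the measurement protocol):

```python
import time
import torch

@torch.no_grad()
def latency_ms(model, size=224, warmup=50, iters=200):
    """Average single-image forward latency in milliseconds."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.eval().to(device)
    x = torch.randn(1, 3, size, size, device=device)  # batch size 1 assumed
    for _ in range(warmup):                           # stabilize clocks/caches
        model(x)
    if device == 'cuda':
        torch.cuda.synchronize()                      # drain queued kernels
    t0 = time.perf_counter()
    for _ in range(iters):
        model(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3
```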

## Installation

24 changes: 14 additions & 10 deletions models/vanillanet.py
@@ -14,19 +14,22 @@
class activation(nn.ReLU):
def __init__(self, dim, act_num=3, deploy=False):
super(activation, self).__init__()
self.act_num = act_num
self.deploy = deploy
self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1))
self.bias = None
self.bn = nn.BatchNorm2d(dim, eps=1e-6)
self.dim = dim
self.act_num = act_num
self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1))
if deploy:
self.bias = torch.nn.Parameter(torch.zeros(dim))
else:
self.bias = None
self.bn = nn.BatchNorm2d(dim, eps=1e-6)
weight_init.trunc_normal_(self.weight, std=.02)

def forward(self, x):
if self.deploy:
return torch.nn.functional.conv2d(
super(activation, self).forward(x),
self.weight, self.bias, padding=(self.act_num*2 + 1)//2, groups=self.dim)
self.weight, self.bias, padding=self.act_num, groups=self.dim)
else:
return self.bn(torch.nn.functional.conv2d(
super(activation, self).forward(x),
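
An editorial aside on the two changes above: for a kernel of size `act_num*2 + 1`, `(act_num*2 + 1)//2` equals `act_num`, so the padding rewrite is behavior-preserving; the substantive change is that deploy mode now owns a bias parameter, which is where the folded BatchNorm statistics land. A minimal sketch of the standard BN-into-conv fold (the repo's own deploy-switching code may differ in detail):

```python
import torch
import torch.nn as nn

def fuse_bn(weight: torch.Tensor, bn: nn.BatchNorm2d):
    """Fold BN into a conv so conv(x, w, b) == bn(conv(x, weight, bias=None))."""
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                    # gamma / sigma, per output channel
    w = weight * scale.reshape(-1, 1, 1, 1)    # rescale each output channel
    b = bn.bias - bn.running_mean * scale      # beta - mu * gamma / sigma
    return w, b
```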
@@ -74,7 +77,7 @@ def __init__(self, dim, dim_out, act_num=3, stride=2, deploy=False, ada_pool=None):
else:
self.pool = nn.Identity() if stride == 1 else nn.AdaptiveMaxPool2d((ada_pool, ada_pool))

self.act = activation(dim_out, act_num)
self.act = activation(dim_out, act_num, deploy=self.deploy)

def forward(self, x):
if self.deploy:
@@ -120,14 +123,15 @@ def __init__(self, in_chans=3, num_classes=1000, dims=[96, 192, 384, 768],
drop_rate=0, act_num=3, strides=[2,2,2,1], deploy=False, ada_pool=None, **kwargs):
super().__init__()
self.deploy = deploy
stride, padding = (4, 0) if not ada_pool else (3, 1)
if self.deploy:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
activation(dims[0], act_num)
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=stride, padding=padding),
activation(dims[0], act_num, deploy=self.deploy)
)
else:
self.stem1 = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=stride, padding=padding),
nn.BatchNorm2d(dims[0], eps=1e-6),
)
self.stem2 = nn.Sequential(
@@ -304,6 +308,6 @@ def vanillanet_13_x1_5_ada_pool(pretrained=False, in_22k=False, **kwargs):
model = VanillaNet(
dims=[128*6, 128*6, 256*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 512*6, 1024*6, 1024*6],
strides=[1,2,2,1,1,1,1,1,1,2,1],
ada_pool=[0,40,20,0,0,0,0,0,0,10,0],
ada_pool=[0,38,19,0,0,0,0,0,0,10,0],
**kwargs)
return model
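
Editorial shape check for the corrected `ada_pool` sizes, assuming the usual 224x224 ImageNet input:

```python
# The ada_pool stem uses kernel 4, stride 3, padding 1 (see the stem hunk above).
h = (224 + 2 * 1 - 4) // 3 + 1       # 75
halve = lambda s: (s + 1) // 2       # the corrected targets follow ceil(s / 2)
print(h, halve(h), halve(halve(h)), halve(halve(halve(h))))  # 75 38 19 10
# 38/19/10 match ada_pool=[0,38,19,...,10,0]; the previous 40/20 did not.
```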
36 changes: 36 additions & 0 deletions object_detection/README.md
@@ -0,0 +1,36 @@
# COCO Object detection with VanillaNet

## Getting started

We add the VanillaNet model and config files on top of [mmdetection-2.x](https://github.com/open-mmlab/mmdetection/tree/2.x). Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/2.x/docs/en/get_started.md) for mmdetection installation and dataset preparation instructions.
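
The configs below select the backbone with the string `type='Vanillanet'`, which mmdetection-2.x resolves through its `BACKBONES` registry. A minimal sketch of that registration pattern (interface stub only; the class body is not the repo's actual wrapper):

```python
import torch.nn as nn
from mmdet.models.builder import BACKBONES

@BACKBONES.register_module()
class Vanillanet(nn.Module):
    """Stub: configs reach this class via type='Vanillanet'."""

    def __init__(self, dims, act_num=3, strides=None,
                 out_indices=(0, 1, 8, 10), init_cfg=None, **kwargs):
        super().__init__()
        self.stages = nn.ModuleList()   # real stages would be built from dims/strides
        self.out_indices = out_indices

    def forward(self, x):
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)              # FPN-style necks expect a tuple of maps
```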

## Results and Fine-tuned Models

| Framework | Backbone | FLOPs(G) | Params(M) | FPS | AP<sup>b</sup> | AP<sup>m</sup> | Model |
|:---:|:---:|:---:|:---:| :---:|:---:|:---:|:---:|
| RetinaNet | Swin-T | 244.8 | 38.5 | 27.5 | 41.5 | - |-|
| | VanillaNet-13 | 396.9 | 75.4 | 29.8 | 43.0 | - | [log](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/retinanet_vanillanet_13.log.json)/[model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/retinanet_vanillanet_13.pth) |
| Mask R-CNN | [Swin-T](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/tree/master) | 263.8 | 47.8 | 28.2 | 43.7 | 39.8 |-|
| | ConvNeXtV2-Nano | 220.6 | 35.2 | 34.4 | 43.3 | 39.4 |-|
| | VanillaNet-13 | 420.7 | 77.1 | 32.6 | 44.3 | 40.1 | [log](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/mask_rcnn_vanillanet_13.log.json)/[model](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/mask_rcnn_vanillanet_13.pth) |


### Training

You can download the ImageNet pre-trained [checkpoint](https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth) for VanillaNet-13 (act_num=4), which is trained via [knowledge distillation (this paper)](https://arxiv.org/pdf/2305.15781.pdf).
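
Before fine-tuning, it is worth sanity-checking that the checkpoint keys match the backbone you build. A hedged sketch (the top-level key layout is an assumption, not documented here):

```python
import torch

state = torch.load('vanillanet_13_act_num_4_kd_pretrain.pth', map_location='cpu')
state = state.get('model', state)  # some releases nest weights under a 'model' key
print(f'{len(state)} tensors, e.g. {next(iter(state))}')
```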

For example, to train a Mask R-CNN model with a VanillaNet backbone on 8 GPUs, run:
```
python -m torch.distributed.launch --nproc_per_node=8 tools/train.py configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py --gpus 8 --launcher pytorch --work-dir <WORK_DIR>
```

### Inference

For example, to test with a single GPU, run:
```
python -m torch.distributed.launch --nproc_per_node=1 tools/test.py configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py <CHECKPOINT_FILE> --launcher pytorch --eval bbox segm
```

## Acknowledgment

This code is built on the [mmdetection](https://github.com/open-mmlab/mmdetection) and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repositories.
110 changes: 110 additions & 0 deletions object_detection/configs/vanillanet/mask_rcnn_vanillanet_13_mstrain_480-1024_adamw.py
@@ -0,0 +1,110 @@
#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.

#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License.

#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details.


_base_ = [
'../_base_/models/mask_rcnn_r50_fpn.py',
'../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]

# you can download ckpt from:
# https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth
checkpoint_file = '/your_path_to/vanillanet_13_act_num_4_kd_pretrain.pth'

model = dict(
backbone=dict(
_delete_=True,
type='Vanillanet',
act_num=4, # enlarge act_num for better downstream performance
dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
out_indices=[0, 1, 8, 10],
strides=[1,2,2,1,1,1,1,1,1,2,1],
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
neck=dict(in_channels=[128*4, 256*4, 512*4, 1024*4]))
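# Editorial note, not in the original config: out_indices appears to select the
# stage outputs whose channel widths are 128*4, 256*4, 512*4 and 1024*4,
# exactly the four FPN in_channels passed to the neck above.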

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='Resize',
img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), (1333, 608), (1333, 640), (1333, 672), (1333, 704), (1333, 736), (1333, 768), (1333, 800), (1333, 832), (1333, 864), (1333, 896), (1333, 928), (1333, 960), (1333, 992), (1333, 1024)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
persistent_workers=True,
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])

optimizer = dict(
_delete_=True,
constructor='LearningRateDecayOptimizerConstructor',
type='AdamW',
lr=1.3e-4,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.6,
'decay_type': 'layer_wise',
'num_layers': 6
})
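# Editorial note, not in the original config: layer-wise decay constructors in
# the ConvNeXt/mmdet lineage typically scale layer i by
# decay_rate ** (num_layers + 1 - i), so with decay_rate=0.6 and num_layers=6
# the earliest layers train at roughly 0.6**7 ~= 0.028x the base lr of 1.3e-4
# while the heads keep the full rate.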

lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[10, 12])
runner = dict(max_epochs=12)

log_config = dict(
interval=200,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
111 changes: 111 additions & 0 deletions object_detection/configs/vanillanet/retinanet_vanillanet_13_mstrain_480-1024_adamw.py
@@ -0,0 +1,111 @@
#Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved.

#This program is free software; you can redistribute it and/or modify it under the terms of the MIT License.

#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MIT License for more details.


_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

# you can download ckpt from:
# https://github.com/huawei-noah/VanillaNet/releases/download/ckpt/vanillanet_13_act_num_4_kd_pretrain.pth
checkpoint_file = '/your_path_to/vanillanet_13_act_num_4_kd_pretrain.pth'

model = dict(
backbone=dict(
_delete_=True,
type='Vanillanet',
act_num=4, # enlarge act_num for better downstream performance
dims=[128*4, 128*4, 256*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 512*4, 1024*4, 1024*4],
out_indices=[1, 8, 10],
strides=[1,2,2,1,1,1,1,1,1,2,1],
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
neck=dict(in_channels=[256*4, 512*4, 1024*4], start_level=0, num_outs=5))

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), (1333, 608), (1333, 640), (1333, 672), (1333, 704), (1333, 736), (1333, 768), (1333, 800), (1333, 832), (1333, 864), (1333, 896), (1333, 928), (1333, 960), (1333, 992), (1333, 1024)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')

optimizer = dict(
_delete_=True,
constructor='LearningRateDecayOptimizerConstructor',
type='AdamW',
lr=8e-5,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg={
'decay_rate': 0.6,
'decay_type': 'layer_wise',
'num_layers': 6
})

optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[10, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)

log_config = dict(
interval=200,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])