Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
HantingChen authored May 22, 2023
1 parent 4e77afb commit cf44e30
Show file tree
Hide file tree
Showing 9 changed files with 1,784 additions and 9 deletions.
12 changes: 12 additions & 0 deletions License
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Copyright (c) 2023. Huawei Technologies Co., Ltd.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
155 changes: 146 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
# VanillaNet: the Power of Minimalism in Deep Learning
<p align="left">
<a href="https://arxiv.org/abs/2303.16900" alt="arXiv">
<img src="https://img.shields.io/badge/arXiv-2203.16900-b31b1b.svg?style=flat" /></a>
</p>


Official PyTorch implementation of **VanillaNet**, from the following paper:\
**VanillaNet: the Power of Minimalism in Deep Learning**\
*Hanting chen, [Yunhe Wang](https://www.wangyunhe.site/), Jianyuan Guo and Dacheng Tao*
[VanillaNet: the Power of Minimalism in Deep Learning ](https://arxiv.org/abs/)\
Hanting chen, [Yunhe Wang](https://www.wangyunhe.site/), Jianyuan Guo and Dacheng Tao

<img src="pic/structure.PNG" width="800px"/>

VanillaNet is an innovative neural network architecture that focuses on **simplicity** and **efficiency**. Moving away from complex features such as **shortcuts** and **attention** mechanisms, VanillaNet uses a reduced number of layers while still **maintaining excellent performance**. This project showcases that it's possible to achieve effective results with a lean architecture, thereby setting a new path in the field of computer vision and challenging the status quo of foundation models.

## Comparison of Depth and Speed

<img src="pic/depth.PNG" width="360px"/> <img src="pic/speed.PNG" width="300px"/>
<img src="pic/depth.PNG" width="480px"/> <img src="pic/speed.PNG" width="400px"/>

VanillaNet, in its robust simplicity, offers comparable precision to prevalent computer vision foundation models, yet boasts a **reduced depth and enhanced processing speed** (test on Nvidia A100 GPU with batch size 1):
VanillaNet, in its robust simplicity, offers comparable precision to prevalent computer vision foundation models, yet boasts a **reduced depth and enhanced processing speed**:
- **9-layers'** VanillaNet achieves about **80%** Top-1 accuracy with **3.59ms**, over **100%** speed increase compared to ResNet-50 (**7.64ms**).
- **13 layers'** VanillaNet achieves about **83%** Top-1 accuracy with **9.72ms**, over **100%** speed increase compared to Swin-T (**20.25ms**).

## Downstream Tasks

| Framework | Backbone | FLOPs(G) | #params(M) | FPS | AP<sup>b</sup> | AP<sup>m</sup> |
|:---:|:---:|:---:|:---:| :---:|:---:|:---:|
| RetinaNet | Swin-T| 245 | 38.5 | 27.5 | 41.5 | - |
| ReTinaNet | Swin-T| 245 | 38.5 | 27.5 | 41.5 | - |
| | VanillaNet-11 | 386 | 67.0 | 30.8 | 41.8 | - |
| Mask RCNN | ConvNeXtV2-N | 221 | 35.2 | 31.7 | 42.7 | 38.9 |
| | [Swin-T](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | 267 | 47.8 | 28.2 | 42.7 | 39.3 |
Expand All @@ -32,9 +36,9 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm


## Catalog
- [ ] ImageNet-1K Testing Code
- [ ] ImageNet-1K Training Code of VanillaNet-5 to VanillaNet-10
- [ ] ImageNet-1K Pretrained Weights of VanillaNet-5 to VanillaNet-10
- [x] ImageNet-1K Testing Code
- [x] ImageNet-1K Training Code of VanillaNet-5 to VanillaNet-10
- [x] ImageNet-1K Pretrained Weights of VanillaNet-5 to VanillaNet-10
- [ ] ImageNet-1K Training Code of VanillaNet-11 to VanillaNet-13
- [ ] ImageNet-1K Pretrained Weights of VanillaNet-11 to VanillaNet-13
- [ ] Downstream Transfer (Detection, Segmentation) Code
Expand All @@ -54,7 +58,7 @@ VanillaNet achieves a higher Frames Per Second (FPS) in **detection** and **segm
| VanillaNet-12 | 54.3 | 11.1 | 3.82 | 81.55 | - |
| VanillaNet-13 | 58.6 | 11.9 | 4.26 | 82.05 | - |
| VanillaNet-13-1.5x | 127.8 | 26.5 | 7.83 | 82.53 | - |
| VanillaNet-13-1.5x&dagger; | 127.8 | 48.5 | 9.72 | 83.11 | - |
| VanillaNet-13-1.5x&dagger; | 127.8 | 9.72 | 198M | 83.11 | - |

## Installation

Expand All @@ -72,6 +76,139 @@ pip install tensorboardX
pip install terminaltables
```

## Testing

We give an example evaluation command for VanillaNet-5:

without deploy:

```
python -m torch.distributed.launch --nproc_per_node=1 main.py --model vanillanet_5 --data_path /path/to/imagenet-1k/ --real_labels /path/to/imagenet_real_labels.json --finetune /path/to/vanillanet_5.pth --eval True --model_key model_ema --crop_pct 0.875
```

with deploy:
```
python -m torch.distributed.launch --nproc_per_node=1 main.py --model vanillanet_5 --data_path /path/to/imagenet-1k/ --real_labels /path/to/imagenet_real_labels.json --finetune /path/to/vanillanet_5.pth --eval True --model_key model_ema --crop_pct 0.875 --switch_to_deploy /path/to/vanillanet_5_deploy.pth
```

## Training

You can use the following command to train VanillaNet-5 on a single machine with 8 GPUs:
```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model vanillanet_5 \
--data_path /path/to/imagenet-1k \
--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
--lr 3.5e-3 --weight_decay 0.35 --drop 0.05 \
--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.1 --bce_loss \
--output_dir /path/to/save_results \
--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
--use_amp true
```

- Here, the effective batch size = `--nproc_per_node` * `--batch_size` * `--update_freq`. In the example above, the effective batch size is `8*128*1 = 1024`.

To train other VanillaNet variants, `--model` need to be changed. Examples are given below.

<details>
<summary>
VanillaNet-6
</summary>

```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model vanillanet_6 \
--data_path /path/to/imagenet-1k \
--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
--lr 4.8e-3 --weight_decay 0.32 --drop 0.05 \
--layer_decay 0.8 --layer_decay_num_layers 4 \
--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.15 --bce_loss \
--output_dir /path/to/save_results \
--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
--use_amp true
```
</details>

<details>
<summary>
VanillaNet-7
</summary>

```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model vanillanet_7 \
--data_path /path/to/imagenet-1k \
--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
--lr 4.7e-3 --weight_decay 0.35 --drop 0.05 \
--layer_decay 0.8 --layer_decay_num_layers 5 \
--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
--output_dir /path/to/save_results \
--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
--use_amp true
```
</details>

<details>
<summary>
VanillaNet-8
</summary>

```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model vanillanet_8 \
--data_path /path/to/imagenet-1k \
--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
--lr 3.5e-3 --weight_decay 0.3 --drop 0.05 \
--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
--output_dir /path/to/save_results \
--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
--use_amp true
```
</details>

<details>
<summary>
VanillaNet-9
</summary>

```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model vanillanet_9 \
--data_path /path/to/imagenet-1k \
--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
--lr 3.5e-3 --weight_decay 0.3 --drop 0.05 \
--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
--output_dir /path/to/save_results \
--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
--use_amp true
```
</details>

<details>
<summary>
VanillaNet-10
</summary>

```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model vanillanet_10 \
--data_path /path/to/imagenet-1k \
--batch_size 128 --update_freq 1 --epochs 300 --decay_epochs 100 \
--lr 3.5e-3 --weight_decay 0.25 --drop 0.05 \
--opt lamb --aa rand-m7-mstd0.5-inc1 --mixup 0.4 --bce_loss \
--output_dir /path/to/save_results \
--model_ema true --model_ema_eval true --model_ema_decay 0.99996 \
--use_amp true
```
</details>


### Acknowledgement

This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library, [DeiT](https://github.com/facebookresearch/deit), [BEiT](https://github.com/microsoft/unilm/tree/master/beit), [RegVGG](https://github.com/DingXiaoH/RepVGG), and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repositories.

### License
This project is released under the MIT license. Please see the [LICENSE](License) file for more information.

### Citation
If our work is useful for your research, please consider citing:
31 changes: 31 additions & 0 deletions THIRD_PARTY_OPEN_SOURCE_SOFTWARE_NOTICE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
Please note we provide an open source software notice for the third party open source software along with this software and/or this software component contributed by Huawei (in the following just “this SOFTWARE”). The open source software licenses are granted by the respective right holders.

Warranty Disclaimer
THE OPEN SOURCE SOFTWARE IN THIS SOFTWARE IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.

Copyright Notice and License Texts
Software: ConvNeXt
Copyright notice:
Copyright (c) 2022

License: MIT License

Copyright (c) Meta Platforms, Inc. and affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
96 changes: 96 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
from torchvision import datasets, transforms

from timm.data.constants import \
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data import create_transform

def build_dataset(is_train, args):
transform = build_transform(is_train, args)

print("Transform = ")
if isinstance(transform, tuple):
for trans in transform:
print(" - - - - - - - - - - ")
for t in trans.transforms:
print(t)
else:
for t in transform.transforms:
print(t)
print("---------------------------")

if args.data_set == 'CIFAR':
dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
nb_classes = 100
elif args.data_set == 'IMNET':
print("reading from datapath", args.data_path)
root = os.path.join(args.data_path, 'train' if is_train else 'val')
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif args.data_set == "image_folder":
root = args.data_path if is_train else args.eval_data_path
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = args.nb_classes
assert len(dataset.class_to_idx) == nb_classes
else:
raise NotImplementedError()
print("Number of the class = %d" % nb_classes)

return dataset, nb_classes


def build_transform(is_train, args):
resize_im = args.input_size > 32
imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD

if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation=args.train_interpolation,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
mean=mean,
std=std,
)
if not resize_im:
transform.transforms[0] = transforms.RandomCrop(
args.input_size, padding=4)
return transform

t = []
if resize_im:
# warping (no cropping) when evaluated at 384 or larger
if args.input_size >= 384:
t.append(
transforms.Resize((args.input_size, args.input_size),
interpolation=transforms.InterpolationMode.BICUBIC),
)
print(f"Warping {args.input_size} size input images...")
else:
if args.crop_pct is None:
args.crop_pct = 224 / 256
size = int(args.input_size / args.crop_pct)
t.append(
# to maintain same ratio w.r.t. 224 images
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
)
t.append(transforms.CenterCrop(args.input_size))

t.append(transforms.ToTensor())
t.append(transforms.Normalize(mean, std))
return transforms.Compose(t)
Loading

0 comments on commit cf44e30

Please sign in to comment.