Filter Sketch for Network Pruning.
Pruning neural network models via filter sketching.
If you have any problems, feel free to contact the authors by email: [email protected] or [email protected]. Please avoid posting GitHub issues where possible: we may not receive GitHub's email notifications and could miss the issues you post.
If you find FilterSketch useful in your research, please consider citing:
@article{lin2020filter,
title={Filter Sketch for Network Pruning},
author={Lin, Mingbao and Ji, Rongrong and Li, Shaojie and Ye, Qixiang and Tian, Yonghong and Liu, Jianzhuang and Tian, Qi},
journal={arXiv preprint arXiv:2001.08514},
year={2020}
}
We provide the pre-trained models used in our paper.
Dataset | Pre-trained Models |
---|---|
CIFAR-10 | ResNet56, ResNet110, GoogLeNet |
ImageNet | ResNet50 |
We provide the pruned models from our experiments, along with their training logs and configurations.
Model | Dataset | Sketch Rate | FLOPs (Prune Rate) | Params (Prune Rate) | Top-1 Accuracy | Top-5 Accuracy | Download |
---|---|---|---|---|---|---|---|
ResNet56 | CIFAR-10 | [0.6]*27 | 73.36M(41.5%) | 0.50M(41.2%) | 93.19% | - | Link |
ResNet110 | CIFAR-10 | [0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3 | 92.84M(63.3%) | 0.69M(59.9%) | 93.44% | - | Link |
GoogLeNet | CIFAR-10 | [0.25]*9 | 0.59B(61.1%) | 2.61M(57.6%) | 94.88% | - | Link |
ResNet50 | ImageNet | [0.2]*16 | 0.93B(77.3%) | 7.18M(71.8%) | 69.43% | 89.23% | Link |
ResNet50 | ImageNet | [0.4]*16 | 1.51B(63.1%) | 10.40M(59.2%) | 73.04% | 91.18% | Link |
ResNet50 | ImageNet | [0.6]*16 | 2.23B(45.5%) | 14.53M(43.0%) | 74.68% | 92.17% | Link |
ResNet50 | ImageNet | [0.7]*16 | 2.64B(35.5%) | 16.95M(33.5%) | 75.22% | 92.41% | Link |
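The Sketch Rate column (and the `--sketch_rate` argument) uses Python list-repetition notation: `[0.6]*27` means 27 copies of 0.6, and `+` concatenates segments, so `[0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3` expands to 54 rates. The repository's own parsing code is not shown in this README; the sketch below is a hypothetical, `eval`-free parser (`parse_sketch_rate` is our own name) that expands such strings into a flat per-layer list.

```python
import re
from typing import List

# Matches one segment such as "[0.6]*27"; a missing "*N" means a count of 1.
_SEGMENT = re.compile(r"\[([0-9]*\.?[0-9]+)\](?:\*(\d+))?")


def parse_sketch_rate(spec: str) -> List[float]:
    """Expand a string like "[0.9]*3+[0.4]*24" into a flat list of rates.

    Hypothetical helper for illustration: it avoids eval() and only accepts
    "+"-separated segments of the form [rate]*count with a single rate each.
    """
    rates: List[float] = []
    for part in spec.replace(" ", "").split("+"):
        match = _SEGMENT.fullmatch(part)
        if match is None:
            raise ValueError(f"cannot parse segment {part!r} in {spec!r}")
        rate = float(match.group(1))
        count = int(match.group(2) or 1)
        if not 0.0 < rate <= 1.0:
            raise ValueError(f"sketch rate {rate} is outside (0, 1]")
        rates.extend([rate] * count)
    return rates


if __name__ == "__main__":
    resnet110 = parse_sketch_rate("[0.9]*3+[0.4]*24+[0.3]*24+[0.9]*3")
    print(len(resnet110))  # 54
    print(resnet110[:4])   # [0.9, 0.9, 0.9, 0.4]
```

Avoiding `eval` keeps a malformed or malicious command-line string from executing arbitrary code, at the cost of supporting only this simple grammar.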
Performance of FilterSketch using ResNet-56 under different compression rates.
Dataset | Sketch Rate | FLOPs (Prune Rate) | Params (Prune Rate) | Top-1 Accuracy | Download |
---|---|---|---|---|---|
CIFAR-10 | [0.1]*27 | 11.43M(91.0%) | 0.08M(90.45%) | 87.38% | Link |
CIFAR-10 | [0.2]*27 | 24.54M(80.6%) | 0.16M(81.0%) | 90.19% | Link |
CIFAR-10 | [0.3]*27 | 35.61M(71.9%) | 0.25M(70.6%) | 91.65% | Link |
CIFAR-10 | [0.4]*27 | 48.72M(61.5%) | 0.33M(61.1%) | 92.00% | Link |
CIFAR-10 | [0.5]*27 | 63.78M(49.6%) | 0.43M(49.8%) | 92.29% | Link |
CIFAR-10 | [0.6]*27 | 73.36M(41.5%) | 0.50M(41.2%) | 93.19% | Link |
CIFAR-10 | [0.7]*27 | 87.31M(31.0%) | 0.59M(31.1%) | 93.36% | Link |
CIFAR-10 | [0.8]*27 | 98.40M(22.3%) | 0.68M(20.8%) | 93.40% | Link |
CIFAR-10 | [0.9]*27 | 111.5M(11.9%) | 0.75M(11.3%) | 93.44% | Link |
CIFAR-10 | [0.9]*3+[0.1]*10+[0.1]*10+[0.6]*4 | 32.47M(74.4%) | 0.24M(71.8%) | 91.20% | Link |
CIFAR-10 | [0.7]*3+[0.4]*10+[0.4]*10+[0.9]*4 | 62.63M(50.5%) | 0.48M(43.3%) | 92.94% | Link |
CIFAR-10 | [0.8]*3+[0.5]*10+[0.8]*10+[0.9]*4 | 88.05M(30.4%) | 0.68M(20.6%) | 93.65% | Link |
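The percentage in parentheses is the fraction of FLOPs or parameters removed relative to the unpruned model, i.e. prune rate = 1 - pruned / baseline. The two small helpers below (our own names, for illustration) compute that rate and, conversely, back out the baseline implied by a table entry. Because the table values are rounded, the baselines implied by different rows agree only approximately.

```python
def prune_rate(baseline: float, pruned: float) -> float:
    """Fraction of FLOPs/params removed: 1 - pruned / baseline."""
    if baseline <= 0:
        raise ValueError("baseline must be positive")
    return 1.0 - pruned / baseline


def implied_baseline(pruned: float, rate: float) -> float:
    """Invert prune_rate: recover the unpruned count from a pruned count and its rate."""
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    return pruned / (1.0 - rate)


if __name__ == "__main__":
    # Row "[0.6]*27": 73.36M FLOPs at a 41.5% prune rate.
    base = implied_baseline(73.36, 0.415)
    print(f"implied baseline: {base:.1f}M FLOPs")  # implied baseline: 125.4M FLOPs
    print(f"{prune_rate(base, 73.36):.1%}")        # 41.5%
```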
The code has been tested with PyTorch 1.3 and CUDA 10.0 on Ubuntu 16.04.
Run the following command to sketch a model on CIFAR-10:
```shell
python sketch_cifar.py \
  --data_set cifar10 \
  --data_path ../data/cifar10/ \
  --sketch_model ./experiment/pretrain/resnet56.pt \
  --job_dir ./experiment/resnet56/sketch/ \
  --arch resnet \
  --cfg resnet56 \
  --lr 0.01 \
  --lr_decay_step 50 100 \
  --num_epochs 150 \
  --gpus 0 \
  --sketch_rate "[0.6]*27" \
  --weight_norm_method l2
```
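With `--lr 0.01 --lr_decay_step 50 100 --num_epochs 150`, the learning rate is decayed at epochs 50 and 100 (and at 30 and 60 in the ImageNet command). This README does not state the decay factor; the sketch below assumes a factor of 0.1 per milestone, the default `gamma` of PyTorch's `MultiStepLR`, and is only meant to show the shape of the schedule. `step_lr` is a hypothetical helper, not code from the repository.

```python
from bisect import bisect_right
from typing import Sequence


def step_lr(base_lr: float, epoch: int, milestones: Sequence[int], gamma: float = 0.1) -> float:
    """Learning rate at `epoch` under a multi-step schedule.

    gamma=0.1 is an assumption; the README does not state the decay factor.
    The rate is multiplied by gamma once for every milestone <= epoch.
    """
    passed = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    for epoch in (0, 49, 50, 99, 100, 149):
        print(epoch, step_lr(0.01, epoch, [50, 100]))
```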
Run the following command to sketch a model on ImageNet:
```shell
python sketch_imagenet.py \
  --data_set imagenet \
  --data_path ../data/imagenet/ \
  --sketch_model ./experiment/pretrain/resnet50.pth \
  --job_dir ./experiment/resnet50/sketch/ \
  --arch resnet \
  --cfg resnet50 \
  --lr 0.1 \
  --lr_decay_step 30 60 \
  --num_epochs 90 \
  --gpus 0 \
  --sketch_rate "[0.6]*16" \
  --weight_norm_method l2
```
Use the following command to verify our pruned models:
```shell
python test.py \
  --data_set cifar10 \
  --data_path ../data/cifar10 \
  --arch resnet \
  --cfg resnet56 \
  --sketch_model ./experiment/result/sketch_resnet56.pt \
  --sketch_rate "[0.6]*27" \
  --gpus 0
```
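The tables above report Top-1 accuracy (and Top-5 accuracy on ImageNet). As a reminder of what these metrics mean, here is a minimal pure-Python sketch, not the repository's implementation: a sample counts as correct under Top-k if its true label is among the k highest-scoring classes. `topk_accuracy` is our own name.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Fraction of samples whose true label is among the k highest scores."""
    if len(scores) != len(labels) or not labels:
        raise ValueError("scores and labels must be non-empty and the same length")
    correct = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample.
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top:
            correct += 1
    return correct / len(labels)


if __name__ == "__main__":
    scores = [
        [0.1, 0.7, 0.2],    # predicts class 1
        [0.5, 0.3, 0.2],    # predicts class 0
        [0.35, 0.25, 0.4],  # predicts class 2
    ]
    labels = [1, 1, 0]
    print(topk_accuracy(scores, labels, k=1))  # 0.3333333333333333
    print(topk_accuracy(scores, labels, k=2))  # 1.0
```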
To calculate the FLOPs and parameters of a model, install the thop Python package and run `get_flops_params.py`:
```shell
pip install thop
python get_flops_params.py \
  --data_set cifar10 \
  --input_image_size 32 \
  --arch resnet \
  --cfg resnet56 \
  --sketch_rate "[0.6]*27"
```
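As a rough illustration of why sketching reduces FLOPs, consider a CIFAR ResNet basic block made of two 3×3 convolutions. Counting multiply-accumulates (MACs), the convention thop's `profile` follows, a convolution costs C_in × C_out × k × k × H_out × W_out. If only the block's intermediate channels are reduced from C to a kept count m (an assumption made here for illustration; the residual input/output width is left unchanged), both convolutions shrink by the factor m / C. The helpers below are hypothetical hand counts, not `get_flops_params.py`.

```python
def conv_macs(c_in: int, c_out: int, k: int, h_out: int, w_out: int) -> int:
    """Multiply-accumulates of a k x k convolution (bias ignored)."""
    return c_in * c_out * k * k * h_out * w_out


def basic_block_macs(channels: int, kept: int, h: int, w: int) -> int:
    """MACs of a two-conv (3x3) basic block whose inner width is `kept`.

    Assumes stride 1 and 'same' padding, so both convs output h x w, and that
    only the intermediate channels are pruned (input/output width = channels).
    """
    first = conv_macs(channels, kept, 3, h, w)   # channels -> kept
    second = conv_macs(kept, channels, 3, h, w)  # kept -> channels
    return first + second


if __name__ == "__main__":
    full = basic_block_macs(16, 16, 32, 32)
    pruned = basic_block_macs(16, 9, 32, 32)  # e.g. 9 of 16 inner channels kept
    print(full, pruned, f"{1 - pruned / full:.2%}")  # 4718592 2654208 43.75%
```

This hand count covers a single block only; whole-network figures such as those in the tables also include the stem convolution, the classifier, and any layers left unpruned.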
The number of sketch rates required by each network is as follows:
Model | CIFAR-10 | ImageNet |
---|---|---|
ResNet56 | 27 | - |
ResNet110 | 54 | - |
GoogLeNet | 9 | - |
ResNet50 | - | 16 |
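These counts match the number of residual blocks or Inception modules in each network: a CIFAR ResNet of depth 6n + 2 has three stages of n basic blocks (3 × 9 = 27 for ResNet56, 3 × 18 = 54 for ResNet110), ResNet50 has 3 + 4 + 6 + 3 = 16 bottleneck blocks, and GoogLeNet has 9 Inception modules. This suggests one sketch rate per block, although that reading is ours rather than stated here. The hypothetical helper below (`check_sketch_rate` is our own name) validates the length of a rate list before launching a run.

```python
from typing import Dict, List

# Number of sketch rates expected per --cfg value (from the table above).
EXPECTED_RATES: Dict[str, int] = {
    "resnet56": 27,
    "resnet110": 54,
    "googlenet": 9,
    "resnet50": 16,
}


def cifar_resnet_blocks(depth: int) -> int:
    """Basic blocks in a CIFAR ResNet of depth 6n + 2 (three stages of n blocks)."""
    if (depth - 2) % 6 != 0:
        raise ValueError(f"depth {depth} is not of the form 6n + 2")
    return 3 * ((depth - 2) // 6)


def check_sketch_rate(cfg: str, rates: List[float]) -> None:
    """Raise if the number of rates does not match the architecture."""
    expected = EXPECTED_RATES.get(cfg)
    if expected is None:
        raise KeyError(f"unknown cfg {cfg!r}")
    if len(rates) != expected:
        raise ValueError(f"{cfg} needs {expected} sketch rates, got {len(rates)}")


if __name__ == "__main__":
    assert cifar_resnet_blocks(56) == EXPECTED_RATES["resnet56"]
    assert cifar_resnet_blocks(110) == EXPECTED_RATES["resnet110"]
    assert sum([3, 4, 6, 3]) == EXPECTED_RATES["resnet50"]  # ResNet50 bottleneck blocks
    check_sketch_rate("resnet110", [0.9] * 3 + [0.4] * 24 + [0.3] * 24 + [0.9] * 3)
    print("ok")
```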
```
optional arguments:
  -h, --help            show this help message and exit
  --gpus GPUS [GPUS ...]
                        Select gpu_id to use. default:[0]
  --data_set DATA_SET   Select the dataset to train on. default:cifar10
  --data_path DATA_PATH
                        The directory where the input data is stored.
                        default:/home/lishaojie/data/cifar10/
  --job_dir JOB_DIR     The directory where the summaries will be stored.
                        default:./experiments
  --arch ARCH           Architecture of the model. default:resnet
  --cfg CFG             Detailed architecture of the model. default:resnet56
  --num_epochs NUM_EPOCHS
                        The number of epochs to train. default:150
  --train_batch_size TRAIN_BATCH_SIZE
                        Batch size for training. default:128
  --eval_batch_size EVAL_BATCH_SIZE
                        Batch size for validation. default:100
  --momentum MOMENTUM   Momentum for MomentumOptimizer. default:0.9
  --lr LR               Learning rate for training. default:1e-2
  --lr_decay_step LR_DECAY_STEP [LR_DECAY_STEP ...]
                        The epochs at which the learning rate decays.
                        default:50, 100
  --weight_decay WEIGHT_DECAY
                        The weight decay of the loss. default:5e-4
  --start_conv START_CONV
                        The index of the conv layer at which to start
                        sketching; indices start from 0. default:1
  --sketch_rate SKETCH_RATE
                        The proportion of each convolution layer retained
                        after sketching. default:None
  --sketch_model SKETCH_MODEL
                        Path to the model to be sketched. default:None
  --weight_norm_method WEIGHT_NORM_METHOD
                        Select the weight norm method. default:None
                        Options: l2
```
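Per the help text, `--sketch_rate` gives the proportion of each convolution layer retained after sketching. As an illustration only, the sketch below converts a rate into a number of kept filters for the 16/32/64-channel stages of a CIFAR ResNet. The rounding rule (truncation via `int`, keeping at least one filter) is our assumption and may differ from what the repository does; `kept_filters` and `kept_per_stage` are hypothetical names.

```python
from typing import List


def kept_filters(num_filters: int, rate: float) -> int:
    """Filters retained after sketching a layer with `num_filters` filters.

    Truncation with a floor of 1 is an assumed rounding rule, for illustration.
    """
    if not 0.0 < rate <= 1.0:
        raise ValueError(f"rate {rate} is outside (0, 1]")
    return max(1, int(num_filters * rate))


def kept_per_stage(widths: List[int], rate: float) -> List[int]:
    """Apply one rate to several layer widths."""
    return [kept_filters(w, rate) for w in widths]


if __name__ == "__main__":
    print(kept_per_stage([16, 32, 64], 0.6))  # [9, 19, 38]
    print(kept_per_stage([16, 32, 64], 0.1))  # [1, 3, 6]
```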