This repository contains code for analysis of various pruning methods.
The code is written in PyTorch, supports distributed computing, and has CLI support that allows for easy execution from batch files.
Configuration is flexible (and can still be expanded), and configuration sweeps (Hydra multirun) are also possible.
Experiments can be analyzed through robust logging with Hydra and Wandb.
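For example, a configuration sweep can be launched with Hydra's standard multirun flag. This is only a sketch: `optimizer=sgd` and `pruning.finetune_epochs` appear in the example command at the end of this README, while `adamw` as a config name and the swept values are assumptions.

```bash
# Hydra multirun (-m): launches one run per combination of the swept values
python pruning_entry.py -m optimizer=sgd,adamw pruning.finetune_epochs=50,100
```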
To install the required packages, run the following command:
```bash
conda env create -f environment.yml
```
Conda/Anaconda is required to run the above command. If you don't have it installed, you can get it from the Anaconda or Miniconda distribution.
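Once the environment is created, activate it before running any experiments. The environment name is defined by the `name:` field in `environment.yml`; `pruning` below is only a placeholder.

```bash
conda activate pruning  # replace "pruning" with the name from environment.yml
```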
The pruning code is located in the `pruning` directory.
The entry point of the program is `pruning_entry.py`, the main file for running the pruning analysis.
You can configure the code from the CLI or modify the configs in the `config` directory.
The configuration uses Hydra, a configuration system for Python apps.
To run a single pruning experiment, you need to provide the necessary parameters, which depend on the method/scheduler you use.
You can check the configuration files in the `config` directory for guidance.
You need to provide the model, the dataset, the path to the dataset, and the checkpoint to use for the model (a minimal example command follows the list below):
- The checkpoint path is provided through the `_checkpoint_path` argument, for example `_checkpoint_path=checkpoints/resnet18_cifar100.pth`.
- The dataset path is provided through the `dataset._path` argument.
- The model is chosen through the `model` argument. Supported models and the code for them are defined in the `models` directory.
- The dataset is chosen through the `dataset` argument. Supported datasets and the code for them are defined in `construct_dataset.py`.
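A minimal run could then look like the sketch below. The model, dataset, and checkpoint values are taken from the full example at the end of this README; the dataset path is a placeholder for your local data directory, and depending on your defaults you may also need to select a pruning method and scheduler as described in the following sections.

```bash
python pruning_entry.py model=resnet18_cifar dataset=cifar100 \
    dataset._path=data/cifar100 \
    _checkpoint_path=checkpoints/resnet18_cifar100.pth
```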
The configuration of the project is managed by Hydra and is divided into several parts, each corresponding to a Python file in the `config` directory.
This is the main configuration file; it defines the main configuration class `MainConfig` and registers it with Hydra.
It also registers the configurations for optimizers, datasets, pruning schedulers, pruning methods, and metrics.
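Since the whole configuration is composed by Hydra, you can inspect the resolved config without running an experiment using Hydra's standard `--cfg job` flag (generic Hydra behavior, not something specific to this project):

```bash
python pruning_entry.py --cfg job
```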
Currently only magnitude-based methods are supported.
- `LnStructured` - structured pruning method, with a modifiable norm and the dimension along which to prune.
- `GlobalL1Unstructured` - unstructured global pruning; prunes all weights according to the L1 norm.
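As an illustration, the relevant overrides could look like the snippet below, appended to a base `pruning_entry.py` command. The `pruning.method=ln_structured` and `pruning.method.norm` overrides appear in the full example at the bottom of this README; the config name for the global unstructured method is an assumption, so check the `config` directory for the exact key.

```bash
# structured pruning with the L2 norm along the pruned dimension
pruning.method=ln_structured pruning.method.norm=2
# global unstructured L1 pruning (assumed key, verify in the config directory)
pruning.method=global_l1_unstructured
```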
This file defines the optimizers used in the project. The available optimizers are:
- `AdamW`
- `SGD`
You can specify the optimizer to use in the `optimizer` field of the main configuration.
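For example, as overrides appended to a `pruning_entry.py` command (the SGD values below appear in the full example at the bottom of this README; `adamw` as a config name is an assumption based on the optimizer list above):

```bash
optimizer=sgd optimizer.learning_rate=0.001
optimizer=adamw   # assumed config name for the AdamW optimizer
```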
This file defines the pruning schedulers used in the project.
Schedulers are objects that decide how the steps for the pruning iterations are calculated.
E.g. `OneShotStepScheduler` with `step=0.6` will prune 60% of the weights and run a single fine-tuning iteration.
The available schedulers are:
- `IterativeStepScheduler`
- `OneShotStepScheduler`
- `LogarithmicStepScheduler`
- `ConstantStepScheduler`
- `ManualScheduler`
You can specify the scheduler to use in the `pruning.scheduler` field of the main configuration (an example override follows below).
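For instance, a one-shot run pruning 60% of the weights could be requested with overrides along these lines. The scheduler config name `one_shot` and the `step` field are assumptions based on the class names and description above (`pruning.scheduler=manual` in the full example at the bottom shows the naming pattern), so check the `config` directory for the exact keys.

```bash
# assumed keys for OneShotStepScheduler; step=0.6 prunes 60% of the weights in one go
pruning.scheduler=one_shot pruning.scheduler.step=0.6
```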
The manual scheduler allows user-provided pruning values to be set separately for every layer.
They are passed through the `pruning.scheduler.pruning_steps` argument.
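For example (override copied and shortened from the full example at the bottom of this README; the inner list appears to hold one pruning fraction per prunable layer for a single pruning iteration, so its length has to match your model):

```bash
pruning.scheduler=manual 'pruning.scheduler.pruning_steps=[[0.0, 0.2, 0.5]]'
```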
Configuration for the datasets, such as path, name, and number of classes (used for dynamic creation of the models).
The arguments `resize_value` and `crop_value` are also provided, which allow resizing and cropping the images.
A common case for `imagenet1k` is `dataset.resize_value=256` and `dataset.crop_value=224` (a combined example follows the dataset list below).
Currently supported datasets include:
- `cifar10`
- `cifar100`
- `imagenet1k`
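Put together, the dataset-related overrides might look like this (the resize and crop values are the ImageNet-1k ones mentioned above; the dataset path is a placeholder for your local copy):

```bash
dataset=imagenet1k dataset._path=/data/imagenet \
    dataset.resize_value=256 dataset.crop_value=224
```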
Currently, logging is handled by Hydra, but Wandb is also supported and recommended.
Our checkpoints are available on Dropbox.
Structured manual pruning with 3 repeats, a single pruning iteration, and early stopping with a patience of 10 epochs:
```bash
python pruning_entry.py model=resnet18_cifar dataset=cifar100 _repeat=3 optimizer=sgd optimizer.learning_rate=0.001 _checkpoint_path=checkpoints/resnet18_cifar100.pth pruning.scheduler=manual pruning.finetune_epochs=100 dataloaders.batch_size=256 pruning.method=ln_structured 'pruning.scheduler.pruning_steps=[[0.0, 0.2, 0.2, 0.2, 0.2, 0.5, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.0]]' pruning.method.norm=2 early_stopper.enabled=True early_stopper.patience=10
```