A simple profiler for PyTorch models that can profile: FLOPs for the forward pass, FLOPs for the backward pass and peak memory consumption.

gslama12/pytorch-model-profiler

PyTorch Model Profiler

A simple profiler for PyTorch (vision) models.

This combines the following implementations:

Profiling capabilities:

  • Floating-point operations (FLOPs) for the forward pass
  • FLOPs for the backward pass
  • Peak memory consumption for one complete training step (forward + backward pass)
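For orientation on the peak-memory figure, the sketch below computes a simple lower bound for one training step: weights, gradients, and optimizer state. It is an illustrative back-of-the-envelope model, not the profiler's own method. The function name `static_training_memory_bytes` and its parameters are hypothetical, and activation memory (often the dominant term) is deliberately left out.

```python
def static_training_memory_bytes(num_params, num_trainable,
                                 bytes_per_value=4,
                                 optimizer_states_per_param=0):
    """Lower bound on training-step memory in bytes (hypothetical helper).

    Counts weights, gradients and optimizer state only. Activations are
    excluded, so a measured peak will be larger than this value.
    """
    weights = num_params * bytes_per_value
    # Gradients are only stored for trainable parameters.
    grads = num_trainable * bytes_per_value
    # E.g. 0 buffers for plain SGD, 1 for SGD with momentum, 2 for Adam.
    opt_state = num_trainable * optimizer_states_per_param * bytes_per_value
    return weights + grads + opt_state


# Assumed toy numbers: a 1M-parameter model in fp32, SGD with momentum.
full_finetune = static_training_memory_bytes(
    1_000_000, 1_000_000, optimizer_states_per_param=1)
head_only = static_training_memory_bytes(
    1_000_000, 10_000, optimizer_states_per_param=1)
print(full_finetune, head_only)
```

The gap between the two results shows why freezing most parameters lowers the static part of the memory footprint; the profiler's reported peak additionally includes activations.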

Supported Layers:

  • nn.Linear
  • nn.Conv1d, nn.Conv2d, nn.Conv3d
  • nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm
  • nn.ReLU, nn.ReLU6, nn.LeakyReLU
  • nn.Sigmoid, nn.Tanh, Hswish, Hsigmoid
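To give a feel for what a per-layer forward FLOP count involves, here is a minimal sketch for `nn.Linear` and `nn.Conv2d` shaped layers. It assumes one multiply-accumulate counts as 2 FLOPs and a square kernel; the profiler may use a different convention. The helper names (`linear_flops`, `conv2d_output_size`, `conv2d_flops`) are hypothetical and not part of this package.

```python
def linear_flops(in_features, out_features, batch_size=1, bias=True):
    """Forward FLOPs of a fully connected layer (1 MAC = 2 FLOPs assumed)."""
    flops = 2 * batch_size * in_features * out_features
    if bias:
        flops += batch_size * out_features  # one add per output value
    return flops


def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size along one dimension (standard convolution formula)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv2d_flops(in_channels, out_channels, kernel_size, in_h, in_w,
                 stride=1, padding=0, groups=1, batch_size=1, bias=False):
    """Forward FLOPs of a 2D convolution with a square kernel."""
    out_h = conv2d_output_size(in_h, kernel_size, stride, padding)
    out_w = conv2d_output_size(in_w, kernel_size, stride, padding)
    # Each output value needs (in_channels / groups) * k * k MACs.
    macs_per_output = (in_channels // groups) * kernel_size * kernel_size
    num_outputs = batch_size * out_channels * out_h * out_w
    flops = 2 * macs_per_output * num_outputs
    if bias:
        flops += num_outputs
    return flops


# A 7x7, stride-2 stem convolution (3 -> 64 channels) on a 244x244 input,
# followed by a 512 -> 1000 classifier head.
print(conv2d_flops(3, 64, 7, 244, 244, stride=2, padding=3))
print(linear_flops(512, 1000))
```

Setting `groups=in_channels` covers depthwise convolutions as used in the MobileNet family.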

Usage

Installation

pip install git+https://github.com/gslama12/pytorch-model-profiler

Example

import torch
from model_profiler import Profiler

resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18')
optimizer = torch.optim.SGD(params=resnet.parameters(), lr=0.01)  # optimizer is optional; older PyTorch releases require an explicit lr
p = Profiler(resnet, optimizer=optimizer, flops_per_layer=True)
p.profile(torch.rand(1, 3, 244, 244))  # specify the model input: (batch, channels, height, width)

Tested Models:

  • MobileNetV1
  • MobileNetV2
  • MobileNetV3
  • ResNet
  • WideResNet
  • GoogLeNet
  • AlexNet
  • VGG-Nets

Tested PEFT methods:

  • LoRA
  • DoRA
  • GaLore
  • Batch normalization + head-only fine-tuning
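Why backward FLOPs differ between these fine-tuning setups can be sketched with a common rule of thumb: for matmul-like layers (linear, convolution), the weight gradient and the input gradient each cost roughly as much as the forward pass. The function below applies that rule to a chain of layers. It is an assumption-laden illustration, not the profiler's algorithm; `estimate_backward_flops` is a hypothetical name, and normalization and activation layers do not follow the rule exactly.

```python
def estimate_backward_flops(layers):
    """Rough backward-FLOP estimate for a sequential model.

    layers: list of (forward_flops, trainable) tuples in forward order.
    Rule of thumb (assumption): weight gradient ~ forward FLOPs,
    input gradient ~ forward FLOPs.
    """
    total = 0
    earlier_trainable = False  # has any layer closer to the input been trainable?
    for forward_flops, trainable in layers:
        if earlier_trainable:
            # Gradient w.r.t. this layer's input must flow back to an
            # earlier trainable layer.
            total += forward_flops
        if trainable:
            # Gradient w.r.t. this layer's own weights.
            total += forward_flops
            earlier_trainable = True
    return total


# Toy three-layer model with assumed forward costs of 100, 200 and 50 FLOPs.
full = [(100, True), (200, True), (50, True)]
head_only = [(100, False), (200, False), (50, True)]
print(estimate_backward_flops(full), estimate_backward_flops(head_only))
```

Under this model, head-only fine-tuning skips almost all backward work, while training layers spread throughout the network (such as batch normalization layers) still forces input gradients to propagate through the frozen layers in between.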
