Skip to content

Commit

Permalink
Code Upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianlong-Chen committed Apr 15, 2021
1 parent 2e6d3a4 commit e4e0a23
Show file tree
Hide file tree
Showing 23 changed files with 2,858 additions and 1 deletion.
Binary file added Figs/nfbnn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added Figs/res.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
58 changes: 57 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,62 @@

[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)

Codes for this paper [BNN - BN = ? Training Binary Neural Networks without Batch Normalization](). [CVPRW 2021]
Codes for this paper [BNN - BN = ? Training Binary Neural Networks without Batch Normalization](). [CVPR BiVision Workshop 2021]

Tianlong Chen, Zhenyu Zhang, Xu Ouyang, Zechun Liu, Zhiqiang Shen, Zhangyang Wang.



## Overview

Batch normalization (BN) is a key facilitator and considered essential for state-of-the-art binary neural networks (BNN). However, the BN layer is costly to calculate and is typically implemented with non-binary parameters, leaving a hurdle for the efficient implementation of BNN training. It also introduces undesirable dependence between samples within each batch.

Inspired by the latest advances in Batch Normalization Free (BN-Free) training, we extend their framework to training BNNs, and for the first time demonstrate that BNs can be completely removed from BNN training and inference regimes. By plugging in and customizing techniques including adaptive gradient clipping, scaled weight standardization, and a specialized bottleneck block, a **BN-free BNN** is capable of maintaining competitive accuracy compared to its BN-based counterpart. Experimental results can be found in [our paper]().

<img src = "Figs/res.png" align = "center" width="60%" height="60%">



## BN-Free Binary Neural Networks

<img src = "Figs/nfbnn.png" align = "center" width="100%" height="60%">



## Reproduce

### Environment

```
pytorch == 1.5.0
torchvision == 0.6.0
timm == 0.4.5
```

### Training on ImageNet

```
./script/imagenet_reactnet_A_bf.sh (BN-Free ReActNet-A)
./script/imagenet_reactnet_A_bn.sh (with BN ReActNet-A)
./script/imagenet_reactnet_A_none.sh (without BN ReActNet-A)
```



## Citation

```
TBD
```



## Acknowledgement

https://github.com/liuzechun/ReActNet

https://github.com/liuzechun/Bi-Real-net

https://github.com/vballoli/nfnets-pytorch

https://github.com/deepmind/deepmind-research/tree/master/nfnets
24 changes: 24 additions & 0 deletions agc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch

def unitwise_norm(x, norm_type=2.0):
    """Return the per-unit norm of a parameter or gradient tensor.

    Tensors with at most one dimension are reduced to a single norm value.
    Higher-rank tensors are reduced over every dimension except the first,
    and the reduced dimensions are kept (size 1) so the result broadcasts
    against ``x``.

    Args:
        x: Tensor to measure.
        norm_type: Order of the norm passed to ``Tensor.norm``.

    Returns:
        A 0-dim tensor when ``x.ndim <= 1``; otherwise a tensor shaped
        ``(x.shape[0], 1, ..., 1)``.
    """
    if x.ndim > 1:
        # Dim 0 is treated as the output-unit axis, which holds for nn.ConvNd
        # and nn.Linear weights. Other layouts (possibly MHA) may need
        # special handling because their output dim is not necessarily first.
        reduced_dims = tuple(range(1, x.ndim))
        return x.norm(norm_type, dim=reduced_dims, keepdim=True)
    return x.norm(norm_type)


def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    """Clip each gradient in place relative to the norm of its parameter.

    For every parameter that has a gradient, the unit-wise gradient norm is
    compared with ``clip_factor`` times the unit-wise parameter norm (the
    parameter norm is floored at ``eps``). Units whose gradient norm is not
    below that limit are rescaled so their norm equals the limit; the rest
    are left untouched.

    Args:
        parameters: A single tensor or an iterable of tensors.
        clip_factor: Allowed ratio of gradient norm to parameter norm.
        eps: Lower bound applied to the parameter norm before scaling.
        norm_type: Norm order forwarded to ``unitwise_norm``.

    Returns:
        None. ``p.grad`` is overwritten in place.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    for param in params:
        if param.grad is None:
            # Nothing was back-propagated to this parameter.
            continue
        weights = param.detach()
        grads = param.grad.detach()
        limit = unitwise_norm(weights, norm_type=norm_type).clamp(min=eps) * clip_factor
        grad_norm = unitwise_norm(grads, norm_type=norm_type)
        # The 1e-6 floor guards the division when a gradient unit is all zeros.
        rescaled = grads * (limit / grad_norm.clamp(min=1e-6))
        chosen = torch.where(grad_norm < limit, grads, rescaled)
        param.grad.detach().copy_(chosen)
139 changes: 139 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from torch.utils.data import DataLoader

__all__ = ['cifar10_dataloaders', 'cifar100_dataloaders', 'imagenet_dataloaders']

# Lighting data augmentation: colour statistics used as the default
# `eigval` / `eigvec` arguments of the Lighting transform below.
# NOTE(review): presumably the eigenvalues and eigenvectors of the ImageNet
# RGB pixel covariance (PCA colour noise) — they are not derived in this file,
# so confirm their source and the pixel scale they assume.
imagenet_pca = {
    # Three values, one per component (Lighting asserts shape (3,)).
    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
    # 3x3 matrix (Lighting asserts shape (3, 3)).
    'eigvec': np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}

class Lighting(object):
    """Random colour-offset transform driven by fixed eigenvalue/eigenvector data.

    Each call draws three Gaussian coefficients scaled by ``alphastd``,
    weights them by ``eigval``, projects them through ``eigvec`` and adds the
    resulting 3-vector to the image. The result is returned as an RGB PIL
    image.

    NOTE(review): with the default eigenvalues the offset is far below one
    intensity level on a 0-255 image; confirm whether the statistics were
    meant for pixels scaled to [0, 1].
    """

    def __init__(self, alphastd,
                 eigval=imagenet_pca['eigval'],
                 eigvec=imagenet_pca['eigvec']):
        """Store the noise strength and validate the shapes of the statistics."""
        self.alphastd = alphastd
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self.eigval, self.eigvec = eigval, eigvec

    def __call__(self, img):
        """Return ``img`` with a random colour offset added (unchanged if alphastd is 0)."""
        if self.alphastd == 0.:
            return img
        alpha = (np.random.randn(3) * self.alphastd).astype('float32')
        # Remember the input dtype so the output can be cast back to it.
        source_dtype = np.asarray(img).dtype
        weighted = (alpha * self.eigval).reshape((3, 1))
        offset = np.dot(self.eigvec, weighted).reshape((3,))
        # The 3-vector broadcasts over the trailing axis of the image array
        # (assumes an H x W x 3 layout — TODO confirm against callers).
        shifted = np.add(img, offset)
        if source_dtype == np.uint8:
            # Keep values representable before casting back to uint8.
            shifted = np.clip(shifted, 0, 255)
        return Image.fromarray(shifted.astype(source_dtype), 'RGB')

    def __repr__(self):
        return f"{self.__class__.__name__}()"


def cifar10_dataloaders(batch_size=128, data_dir='datasets/cifar10', worker=4):
    """Build the CIFAR-10 training and test data loaders.

    Training images get a padded random 32x32 crop and a random horizontal
    flip before tensor conversion and normalization; test images are only
    converted and normalized. Both datasets are created with
    ``download=True``.

    Args:
        batch_size: Batch size used by both loaders.
        data_dir: Root directory handed to ``CIFAR10``.
        worker: Number of loader worker processes.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    channel_norm = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616],
    )

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_norm,
    ])
    test_pipeline = transforms.Compose([transforms.ToTensor(), channel_norm])

    train_set = CIFAR10(data_dir, train=True, transform=train_pipeline, download=True)
    test_set = CIFAR10(data_dir, train=False, transform=test_pipeline, download=True)

    # Settings shared by both loaders; only the shuffle flag differs.
    common = dict(batch_size=batch_size, num_workers=worker, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **common)
    test_loader = DataLoader(test_set, shuffle=False, **common)
    return train_loader, test_loader

def cifar100_dataloaders(batch_size=128, data_dir='datasets/cifar100', worker=4):
    """Build the CIFAR-100 training and test data loaders.

    Training images get a padded random 32x32 crop and a random horizontal
    flip before tensor conversion and normalization; test images are only
    converted and normalized. Both datasets are created with
    ``download=True``.

    Args:
        batch_size: Batch size used by both loaders.
        data_dir: Root directory handed to ``CIFAR100``.
        worker: Number of loader worker processes.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    channel_norm = transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2673, 0.2564, 0.2762],
    )

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_norm,
    ])
    test_pipeline = transforms.Compose([transforms.ToTensor(), channel_norm])

    train_set = CIFAR100(data_dir, train=True, transform=train_pipeline, download=True)
    test_set = CIFAR100(data_dir, train=False, transform=test_pipeline, download=True)

    # Settings shared by both loaders; only the shuffle flag differs.
    common = dict(batch_size=batch_size, num_workers=worker, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **common)
    test_loader = DataLoader(test_set, shuffle=False, **common)
    return train_loader, test_loader

def imagenet_dataloaders(batch_size=128, data_dir='datasets/cifar100', worker=4):
    """Build the ImageNet training and validation data loaders.

    Images are read with ``ImageFolder`` from ``<data_dir>/train`` and
    ``<data_dir>/val``. Training uses a random resized 224 crop, the
    ``Lighting`` transform and a random horizontal flip; validation resizes
    to 256 and center-crops to 224. Both end with tensor conversion and
    normalization.

    NOTE(review): the default ``data_dir`` is 'datasets/cifar100', which
    looks like a copy-paste from the CIFAR-100 loader; callers should pass
    the ImageNet root explicitly.

    Args:
        batch_size: Batch size used by both loaders.
        data_dir: Directory containing the ``train`` and ``val`` folders.
        worker: Number of loader worker processes.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_root = os.path.join(data_dir, 'train')
    val_root = os.path.join(data_dir, 'val')
    channel_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Data augmentation settings for the training pipeline.
    crop_scale = 0.08
    lighting_param = 0.1
    train_pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
        Lighting(lighting_param),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_norm,
    ])
    train_dataset = ImageFolder(train_root, transform=train_pipeline)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=worker,
        pin_memory=True,
    )

    # Validation data: deterministic resize + center crop.
    val_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        channel_norm,
    ])
    val_loader = DataLoader(
        ImageFolder(val_root, val_pipeline),
        batch_size=batch_size,
        shuffle=False,
        num_workers=worker,
        pin_memory=True,
    )

    return train_loader, val_loader



141 changes: 141 additions & 0 deletions layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


__all__ = ['LambdaLayer', 'ScaledStdConv2d', 'HardBinaryScaledStdConv2d', 'LearnableBias','BinaryActivation', 'HardBinaryConv']

def get_weight(module):
    """Return ``module.weight`` standardized per leading-dimension slice.

    The mean and the biased (population) standard deviation are computed
    over dims 1-3 of the weight, so the weight must be 4-D. ``module.eps``
    is added to the standard deviation to keep the division finite.

    Args:
        module: Object exposing a 4-D ``weight`` tensor and a float ``eps``.

    Returns:
        Tensor with the same shape as ``module.weight``.
    """
    kernel = module.weight
    spread, center = torch.std_mean(kernel, dim=[1, 2, 3], keepdim=True, unbiased=False)
    return (kernel - center) / (spread + module.eps)

# Symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Compute the symmetric padding for a convolution.

    Args:
        kernel_size: Kernel extent along one spatial axis.
        stride: Convolution stride.
        dilation: Spacing between kernel taps.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
    """
    dilated_span = dilation * (kernel_size - 1)
    return (stride - 1 + dilated_span) // 2

class LambdaLayer(nn.Module):
    """Module wrapper that applies an arbitrary callable in ``forward``."""

    def __init__(self, lambd):
        """Store ``lambd``, the callable applied to every input."""
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return ``self.lambd(x)``."""
        fn = self.lambd
        return fn(x)


class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    The stored kernel is standardized per output channel on every forward
    pass, multiplied by ``gamma / sqrt(fan_in)`` and by a learnable
    per-channel ``gain``.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692
    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
            bias=False, gamma=1.0, eps=1e-5, use_layernorm=False):
        """Create the convolution; ``padding=None`` selects symmetric padding via ``get_padding``."""
        resolved_padding = get_padding(kernel_size, stride, dilation) if padding is None else padding
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=resolved_padding,
            dilation=dilation, groups=groups, bias=bias)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        fan_in = self.weight[0].numel()
        self.scale = gamma * fan_in ** -0.5  # gamma * 1 / sqrt(fan-in)
        # experimental, slightly faster/less GPU memory to hijack LN kernel
        self.use_layernorm = use_layernorm
        # The layer-norm path receives the squared epsilon.
        self.eps = eps ** 2 if use_layernorm else eps

    def get_weight(self):
        """Return the standardized, scaled and gained kernel used by ``forward``."""
        if self.use_layernorm:
            normalized = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
            scaled = self.scale * normalized
        else:
            # Biased std/mean over everything except the output-channel axis.
            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
            scaled = self.scale * (self.weight - mean) / (std + self.eps)
        return self.gain * scaled

    def forward(self, x):
        """Convolve ``x`` with the standardized kernel."""
        kernel = self.get_weight()
        return F.conv2d(x, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups)

class HardBinaryScaledStdConv2d(nn.Module):
    """Binary-weight 2-D convolution built on scaled weight standardization.

    The latent real-valued kernel is standardized per output channel and
    scaled by ``gamma / sqrt(fan_in)``. The forward value of each filter is
    ``sign(w) * mean(|w|)``, while the gradient is routed through
    ``clamp(w, -1, 1)`` (straight-through estimator). A learnable
    per-channel ``gain`` multiplies the result.
    """

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1, gamma=1.0, eps=1e-5, use_layernorm=False):
        """Allocate the latent kernel (uniform in [0, 0.001)) and the per-channel gain."""
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True)

        self.gain = nn.Parameter(torch.ones(out_chn, 1, 1, 1))
        fan_in = self.weight[0].numel()
        self.scale = gamma * fan_in ** -0.5
        self.use_layernorm = use_layernorm
        # The layer-norm path receives the squared epsilon.
        self.eps = eps ** 2 if use_layernorm else eps

    def get_weight(self):
        """Return the binarized kernel (with straight-through gradient) times the gain."""
        if self.use_layernorm:
            normalized = F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
            weight = self.scale * normalized
        else:
            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
            weight = self.scale * (self.weight - mean) / (std + self.eps)

        # Per-filter magnitude: average |w| over width, height, then input
        # channels. It is detached so it acts as a constant in the backward pass.
        magnitude = abs(weight)
        for axis in (3, 2, 1):
            magnitude = torch.mean(magnitude, dim=axis, keepdim=True)
        magnitude = magnitude.detach()

        hard = magnitude * torch.sign(weight)
        soft = torch.clamp(weight, -1.0, 1.0)
        # Forward value equals `hard`; the gradient flows only through `soft`.
        binary_weights = hard.detach() - soft.detach() + soft

        return self.gain * binary_weights

    def forward(self, x):
        """Convolve ``x`` with the binarized kernel using the stored stride and padding."""
        kernel = self.get_weight()
        return F.conv2d(x, kernel, stride=self.stride, padding=self.padding)

class LearnableBias(nn.Module):
    """Learnable per-channel offset added to a feature map."""

    def __init__(self, out_chn):
        """Create a zero-initialized bias of shape ``(1, out_chn, 1, 1)``."""
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, out_chn, 1, 1), requires_grad=True)

    def forward(self, x):
        """Return ``x`` plus the bias expanded to the shape of ``x``."""
        offset = self.bias.expand_as(x)
        return x + offset

class BinaryActivation(nn.Module):
    """Sign activation with a piecewise-quadratic surrogate for the backward pass.

    The forward value is ``torch.sign(x)``. Gradients are taken from a
    smooth approximation of the sign function:

        -1            for x < -1
        x*x + 2*x     for -1 <= x < 0
        -x*x + 2*x    for 0 <= x < 1
        1             for x >= 1

    NOTE(review): the region masks are always cast to float32; with
    lower-precision inputs, type promotion presumably yields a float32
    output — confirm if mixed precision matters.
    """

    def __init__(self):
        super(BinaryActivation, self).__init__()

    def forward(self, x):
        """Return ``sign(x)`` whose gradient is that of the surrogate above."""
        out_forward = torch.sign(x)
        # Cast each region mask to float once; every mask is used twice below.
        below_neg_one = (x < -1).type(torch.float32)
        below_zero = (x < 0).type(torch.float32)
        below_one = (x < 1).type(torch.float32)
        # Build the surrogate by successively overriding regions from left to right.
        out1 = (-1) * below_neg_one + (x * x + 2 * x) * (1 - below_neg_one)
        out2 = out1 * below_zero + (-x * x + 2 * x) * (1 - below_zero)
        out3 = out2 * below_one + 1 * (1 - below_one)
        # Forward value equals sign(x); only out3 carries gradient.
        out = out_forward.detach() - out3.detach() + out3

        return out

class HardBinaryConv(nn.Module):
    """2-D convolution with binarized weights and a straight-through gradient.

    The forward value of each filter is ``sign(w) * mean(|w|)``; the gradient
    with respect to the latent real-valued kernel is routed through
    ``clamp(w, -1, 1)``. The layer has no bias term.
    """

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
        """Allocate the latent kernel, initialized uniformly in [0, 0.001)."""
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.rand((self.shape)) * 0.001, requires_grad=True)

    def forward(self, x):
        """Convolve ``x`` with the binarized kernel using the stored stride and padding."""
        latent = self.weight
        # Per-filter magnitude: average |w| over width, height, then input
        # channels. It is detached so it acts as a constant in the backward pass.
        magnitude = abs(latent)
        for axis in (3, 2, 1):
            magnitude = torch.mean(magnitude, dim=axis, keepdim=True)
        magnitude = magnitude.detach()

        hard = magnitude * torch.sign(latent)
        soft = torch.clamp(latent, -1.0, 1.0)
        # Forward value equals `hard`; the gradient flows only through `soft`.
        binary_weights = hard.detach() - soft.detach() + soft
        return F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding)

Loading

0 comments on commit e4e0a23

Please sign in to comment.