diff --git a/Figs/nfbnn.png b/Figs/nfbnn.png
new file mode 100644
index 0000000..474737e
Binary files /dev/null and b/Figs/nfbnn.png differ
diff --git a/Figs/res.png b/Figs/res.png
new file mode 100644
index 0000000..fb4777f
Binary files /dev/null and b/Figs/res.png differ
diff --git a/README.md b/README.md
index 2a008fe..29f1201 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,62 @@
 [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)

-Codes for this paper [BNN - BN = ? Training Binary Neural Networks without Batch Normalization](). [CVPRW 2021]
+Code for the paper [BNN - BN = ? Training Binary Neural Networks without Batch Normalization](). [CVPR BiVision Workshop 2021]

 Tianlong Chen, Zhenyu Zhang, Xu Ouyang, Zechun Liu, Zhiqiang Shen, Zhangyang Wang.
+
+
+
+## Overview
+
+Batch normalization (BN) is a key facilitator of, and widely considered essential for, state-of-the-art binary neural networks (BNNs). However, the BN layer is costly to compute and is typically implemented with non-binary parameters, leaving a hurdle for the efficient implementation of BNN training. It also introduces an undesirable dependence between the samples within each batch.
+
+Inspired by recent advances in Batch Normalization Free (BN-Free) training, we extend that framework to BNNs and demonstrate, for the first time, that BN can be completely removed from both BNN training and inference. By plugging in and customizing techniques including adaptive gradient clipping, scaled weight standardization, and specialized bottleneck blocks, a **BN-free BNN** maintains competitive accuracy compared with its BN-based counterpart. Experimental results can be found in [our paper]().
+
+
+
+
+
+## BN-Free Binary Neural Networks
+
+
+
+
+
+## Reproduce
+
+### Environment
+
+```
+pytorch == 1.5.0
+torchvision == 0.6.0
+timm == 0.4.5
+```
+
+### Training on ImageNet
+
+```
+./script/imagenet_reactnet_A_bf.sh (BN-Free ReActNet-A)
+./script/imagenet_reactnet_A_bn.sh (ReActNet-A with BN)
+./script/imagenet_reactnet_A_none.sh (ReActNet-A without BN)
+```
+
+
+
+## Citation
+
+```
+TBD
+```
+
+
+
+## Acknowledgement
+
+https://github.com/liuzechun/ReActNet
+
+https://github.com/liuzechun/Bi-Real-net
+
+https://github.com/vballoli/nfnets-pytorch
+
+https://github.com/deepmind/deepmind-research/tree/master/nfnets
\ No newline at end of file
diff --git a/agc.py b/agc.py
new file mode 100644
index 0000000..04e2099
--- /dev/null
+++ b/agc.py
@@ -0,0 +1,24 @@
+import torch
+
+def unitwise_norm(x, norm_type=2.0):
+    if x.ndim <= 1:
+        return x.norm(norm_type)
+    else:
+        # works for nn.ConvNd and nn.Linear, where the output dim is first in the kernel/weight tensor
+        # might need special cases for other weights (possibly MHA) where this may not be true
+        return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
+
+
+def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    for p in parameters:
+        if p.grad is None:
+            continue
+        p_data = p.detach()
+        g_data = p.grad.detach()
+        # unit-wise clipping threshold: clip_factor * max(||w||, eps)
+        max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
+        grad_norm = unitwise_norm(g_data, norm_type=norm_type)
+        clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
+        new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
+        p.grad.detach().copy_(new_grads)
\ No newline at end of file
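For orientation, `adaptive_clip_grad` above rewrites the gradients in place, so it slots between `loss.backward()` and `optimizer.step()`. A minimal sketch of that wiring (the `model`, `criterion`, and `optimizer` here are illustrative placeholders, not objects defined in this repo):

```
loss = criterion(model(images), target)
optimizer.zero_grad()
loss.backward()
# unit-wise adaptive gradient clipping; 0.02 matches --clip_value in the training scripts
adaptive_clip_grad(model.parameters(), clip_factor=0.02)
optimizer.step()
```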
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..89651ca
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,139 @@
+import os
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
+from torch.utils.data import DataLoader
+
+__all__ = ['cifar10_dataloaders', 'cifar100_dataloaders', 'imagenet_dataloaders']
+
+# lighting (PCA-based) data augmentation
+imagenet_pca = {
+    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
+    'eigvec': np.asarray([
+        [-0.5675, 0.7192, 0.4009],
+        [-0.5808, -0.0045, -0.8140],
+        [-0.5836, -0.6948, 0.4203],
+    ])
+}
+
+class Lighting(object):
+    def __init__(self, alphastd,
+                 eigval=imagenet_pca['eigval'],
+                 eigvec=imagenet_pca['eigvec']):
+        self.alphastd = alphastd
+        assert eigval.shape == (3,)
+        assert eigvec.shape == (3, 3)
+        self.eigval = eigval
+        self.eigvec = eigvec
+
+    def __call__(self, img):
+        if self.alphastd == 0.:
+            return img
+        rnd = np.random.randn(3) * self.alphastd
+        rnd = rnd.astype('float32')
+        v = rnd
+        old_dtype = np.asarray(img).dtype
+        v = v * self.eigval
+        v = v.reshape((3, 1))
+        inc = np.dot(self.eigvec, v).reshape((3,))
+        img = np.add(img, inc)
+        if old_dtype == np.uint8:
+            img = np.clip(img, 0, 255)
+        img = Image.fromarray(img.astype(old_dtype), 'RGB')
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+
+def cifar10_dataloaders(batch_size=128, data_dir='datasets/cifar10', worker=4):
+
+    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
+                                     std=[0.2470, 0.2435, 0.2616])
+
+    train_transform = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        normalize
+    ])
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        normalize
+    ])
+
+    train_set = CIFAR10(data_dir, train=True, transform=train_transform, download=True)
+    test_set = CIFAR10(data_dir, train=False, transform=test_transform, download=True)
+
+    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=worker, pin_memory=True)
+    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=worker, pin_memory=True)
+
+    return train_loader, test_loader
+
+def cifar100_dataloaders(batch_size=128, data_dir='datasets/cifar100', worker=4):
+
+    normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
+                                     std=[0.2673, 0.2564, 0.2762])
+
+    train_transform = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        normalize
+    ])
+
+    test_transform = transforms.Compose([
+        transforms.ToTensor(),
+        normalize
+    ])
+
+    train_set = CIFAR100(data_dir, train=True, transform=train_transform, download=True)
+    test_set = CIFAR100(data_dir, train=False, transform=test_transform, download=True)
+
+    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=worker, pin_memory=True)
+    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=worker, pin_memory=True)
+
+    return train_loader, test_loader
+
+def imagenet_dataloaders(batch_size=128, data_dir='datasets/imagenet', worker=4):
+
+    traindir = os.path.join(data_dir, 'train')
+    valdir = os.path.join(data_dir, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
+    # data augmentation
+    crop_scale = 0.08
+    lighting_param = 0.1
+    train_transforms = transforms.Compose([
+        transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
+        Lighting(lighting_param),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        normalize])
+
+    train_dataset = ImageFolder(
+        traindir,
+        transform=train_transforms)
+
+    train_loader = DataLoader(train_dataset,
+                              batch_size=batch_size, shuffle=True,
+                              num_workers=worker, pin_memory=True)
+
+    # load validation data
+    val_loader = DataLoader(
+        ImageFolder(valdir, transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            normalize,
+        ])),
+        batch_size=batch_size, shuffle=False,
+        num_workers=worker, pin_memory=True)
+
+    return train_loader, val_loader
+
+
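A quick usage sketch for these loaders (the path is an assumption; point `data_dir` at wherever your ImageNet `train/` and `val/` folders live):

```
from dataset import imagenet_dataloaders

# expects data/imagenet/train and data/imagenet/val in the standard ImageFolder layout
train_loader, val_loader = imagenet_dataloaders(batch_size=256, data_dir='data/imagenet', worker=8)
```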
+ """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1, + bias=False, gamma=1.0, eps=1e-5, use_layernorm=False): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) + self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) + self.eps = eps ** 2 if use_layernorm else eps + self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory to hijack LN kernel + + def get_weight(self): + if self.use_layernorm: + weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) + else: + std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) + weight = self.scale * (self.weight - mean) / (std + self.eps) + return self.gain * weight + + def forward(self, x): + return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups) + +class HardBinaryScaledStdConv2d(nn.Module): + + def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1, gamma=1.0, eps=1e-5, use_layernorm=False): + super(HardBinaryScaledStdConv2d, self).__init__() + self.stride = stride + self.padding = padding + self.shape = (out_chn, in_chn, kernel_size, kernel_size) + self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True) + + self.gain = nn.Parameter(torch.ones(out_chn, 1, 1, 1)) + self.scale = gamma * self.weight[0].numel() ** -0.5 + self.eps = eps ** 2 if use_layernorm else eps + self.use_layernorm = use_layernorm + + def get_weight(self): + if self.use_layernorm: + weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps) + else: + std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False) + weight = self.scale * (self.weight - mean) / (std + self.eps) + + scaling_factor = torch.mean(torch.mean(torch.mean(abs(weight),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True) + scaling_factor = scaling_factor.detach() + binary_weights_no_grad = scaling_factor * torch.sign(weight) + cliped_weights = torch.clamp(weight, -1.0, 1.0) + binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights + + return self.gain * binary_weights + + def forward(self, x): + + return F.conv2d(x, self.get_weight(), stride=self.stride, padding=self.padding) + +class LearnableBias(nn.Module): + def __init__(self, out_chn): + super(LearnableBias, self).__init__() + self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) + + def forward(self, x): + out = x + self.bias.expand_as(x) + return out + +class BinaryActivation(nn.Module): + def __init__(self): + super(BinaryActivation, self).__init__() + + def forward(self, x): + out_forward = torch.sign(x) + #out_e1 = (x^2 + 2*x) + #out_e2 = (-x^2 + 2*x) + out_e_total = 0 + mask1 = x < -1 + mask2 = x < 0 + mask3 = x < 1 + out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) + out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) + out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) + out = out_forward.detach() - out3.detach() + out3 + + return out + +class HardBinaryConv(nn.Module): + def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1): + 
diff --git a/models/Qa_reactnet_18_bf.py b/models/Qa_reactnet_18_bf.py
new file mode 100644
index 0000000..6e2b362
--- /dev/null
+++ b/models/Qa_reactnet_18_bf.py
@@ -0,0 +1,137 @@
+'''
+ReAct-BiRealNet-18 (modified from ResNet)
+
+BN setting: remove all BatchNorm layers
+Conv setting: replace Conv2d with ScaledStdConv2d (add alpha and beta to each block)
+Binary setting: only activations are binarized
+
+'''
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+from layers import *
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return ScaledStdConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return ScaledStdConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryScaledStdConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryScaledStdConv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, alpha, beta, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+
+        self.alpha = alpha
+        self.beta = beta
+
+        self.move0 = LearnableBias(inplanes)
+        self.binary_activation = BinaryActivation()
+        self.binary_conv = conv3x3(inplanes, planes, stride=stride)
+        self.move1 = LearnableBias(planes)
+        self.prelu = nn.PReLU(planes)
+        self.move2 = LearnableBias(planes)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+
+        residual = x
+        x_in = x * self.beta
+
+        out = self.move0(x_in)
+        out = self.binary_activation(out)
+        out = self.binary_conv(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x_in)
+
+        out = out * self.alpha + residual
+        out = self.move1(out)
+        out = self.prelu(out)
+        out = self.move2(out)
+
+        return out
+
+class BiRealNet(nn.Module):
+
+    def __init__(self, block, layers, imagenet=True, alpha=0.2, num_classes=1000):
+        super(BiRealNet, self).__init__()
+        self.inplanes = 64
+
+        if imagenet:
+            self.conv1 = ScaledStdConv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        else:
+            self.conv1 = ScaledStdConv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+            self.maxpool = nn.Identity()
+
+        expected_var = 1.0
+        self.layer1, expected_var = self._make_layer(block, 64, layers[0], alpha, expected_var)
+        self.layer2, expected_var = self._make_layer(block, 128, layers[1], alpha, expected_var, stride=2)
+        self.layer3, expected_var = self._make_layer(block, 256, layers[2], alpha, expected_var, stride=2)
+        self.layer4, expected_var = self._make_layer(block, 512, layers[3], alpha, expected_var, stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, alpha, expected_var, stride=1):
+
+        beta = 1. / expected_var ** 0.5
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.AvgPool2d(kernel_size=2, stride=stride),
+                conv1x1(self.inplanes, planes * block.expansion)
+            )
+            # Reset the expected variance at a transition block
+            expected_var = 1.0
+
+        layers = []
+        layers.append(block(self.inplanes, planes, alpha, beta, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            beta = 1. / expected_var ** 0.5
+            layers.append(block(self.inplanes, planes, alpha, beta))
+            expected_var += alpha ** 2
+
+        return nn.Sequential(*layers), expected_var
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def birealnet18(pretrained=False, **kwargs):
+    """Constructs a BiRealNet-18 model. """
+    model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)
+    return model
+
+
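A note on the `expected_var` bookkeeping above, which follows the NF-ResNet signal-propagation recipe: each residual branch sees `x * beta` with `beta = 1 / sqrt(expected_var)` so its input is roughly unit-variance, every `out * alpha + residual` sum then grows the running variance by `alpha ** 2`, and a transition (downsample) block resets it to 1. A short trace of the resulting schedule, mirroring the loop in `_make_layer` (alpha = 0.2, one four-block stage, purely illustrative):

```
alpha, expected_var = 0.2, 1.0
for i in range(4):
    beta = 1. / expected_var ** 0.5
    print(f'block {i}: beta = {beta:.3f}')  # 1.000, 0.981, 0.962, 0.945
    expected_var += alpha ** 2
```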
diff --git a/models/Qa_reactnet_18_bn.py b/models/Qa_reactnet_18_bn.py
new file mode 100644
index 0000000..23164bf
--- /dev/null
+++ b/models/Qa_reactnet_18_bn.py
@@ -0,0 +1,130 @@
+'''
+ReAct-BiRealNet-18 (modified from ResNet)
+
+BN setting: original, all BN layers kept
+Conv setting: original Conv2d
+Binary setting: only activations are binarized
+
+'''
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+from layers import *
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+
+        self.move0 = LearnableBias(inplanes)
+        self.binary_activation = BinaryActivation()
+        self.binary_conv = conv3x3(inplanes, planes, stride=stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.move1 = LearnableBias(planes)
+        self.prelu = nn.PReLU(planes)
+        self.move2 = LearnableBias(planes)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.move0(x)
+        out = self.binary_activation(out)
+        out = self.binary_conv(out)
+        out = self.bn1(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.move1(out)
+        out = self.prelu(out)
+        out = self.move2(out)
+
+        return out
+
+class BiRealNet(nn.Module):
+
+    def __init__(self, block, layers, imagenet=True, num_classes=1000):
+        super(BiRealNet, self).__init__()
+        self.inplanes = 64
+
+        if imagenet:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+            self.bn1 = nn.BatchNorm2d(64)
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        else:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+            self.bn1 = nn.BatchNorm2d(64)
+            self.maxpool = nn.Identity()
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.AvgPool2d(kernel_size=2, stride=stride),
+                conv1x1(self.inplanes, planes * block.expansion),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def birealnet18(pretrained=False, **kwargs):
+    """Constructs a BiRealNet-18 model. """
+    model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)
+    return model
+
+
diff --git a/models/Qa_reactnet_18_none.py b/models/Qa_reactnet_18_none.py
new file mode 100644
index 0000000..1c70178
--- /dev/null
+++ b/models/Qa_reactnet_18_none.py
@@ -0,0 +1,126 @@
+'''
+ReAct-BiRealNet-18 (modified from ResNet)
+
+BN setting: remove all BatchNorm layers
+Conv setting: original Conv2d
+Binary setting: only activations are binarized
+
+'''
+
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+from layers import *
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+
+        self.move0 = LearnableBias(inplanes)
+        self.binary_activation = BinaryActivation()
+        self.binary_conv = conv3x3(inplanes, planes, stride=stride)
+        self.move1 = LearnableBias(planes)
+        self.prelu = nn.PReLU(planes)
+        self.move2 = LearnableBias(planes)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.move0(x)
+        out = self.binary_activation(out)
+        out = self.binary_conv(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.move1(out)
+        out = self.prelu(out)
+        out = self.move2(out)
+
+        return out
+
+class BiRealNet(nn.Module):
+
+    def __init__(self, block, layers, imagenet=True, num_classes=1000):
+        super(BiRealNet, self).__init__()
+        self.inplanes = 64
+
+        if imagenet:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        else:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+            self.maxpool = nn.Identity()
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.AvgPool2d(kernel_size=2, stride=stride),
+                conv1x1(self.inplanes, planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def birealnet18(pretrained=False, **kwargs):
+    """Constructs a BiRealNet-18 model. """
+    model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)
+    return model
+
+
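The ReActNet-A variants below swap the ResNet basic block for a MobileNetV1-style pair of sub-blocks: a 3x3 convolution at constant width, then a 1x1 pointwise convolution. When a stage doubles the width, two parallel pointwise branches share one input and are concatenated, so the channel count doubles without any real-valued projection. A shape-only sketch of that concatenation (the random tensors are stand-ins for the branch outputs):

```
import torch

out1 = torch.randn(1, 128, 14, 14)            # output of the 3x3 sub-block (the shared shortcut)
branch1 = torch.randn(1, 128, 14, 14) + out1  # binary_pw_down1(...) + shortcut
branch2 = torch.randn(1, 128, 14, 14) + out1  # binary_pw_down2(...) + shortcut
out2 = torch.cat([branch1, branch2], dim=1)
print(out2.shape)                             # torch.Size([1, 256, 14, 14])
```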
diff --git a/models/Qa_reactnet_A_bf.py b/models/Qa_reactnet_A_bf.py
new file mode 100644
index 0000000..8bc8bb1
--- /dev/null
+++ b/models/Qa_reactnet_A_bf.py
@@ -0,0 +1,172 @@
+'''
+ReActNet (modified from MobileNetV1)
+
+BN setting: remove all BatchNorm layers
+Conv setting: replace Conv2d with ScaledStdConv2d (add alpha and beta to each block)
+Binary setting: only activations are binarized
+
+'''
+
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+import numpy as np
+
+from layers import *
+
+stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return ScaledStdConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return ScaledStdConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryScaledStdConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryScaledStdConv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class firstconv3x3(nn.Module):
+    def __init__(self, inp, oup, stride):
+        super(firstconv3x3, self).__init__()
+        self.conv1 = ScaledStdConv2d(inp, oup, 3, stride, 1, bias=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        return out
+
+class BasicBlock(nn.Module):
+    def __init__(self, inplanes, planes, alpha, beta1, beta2, stride=1):
+        super(BasicBlock, self).__init__()
+
+        self.alpha = alpha
+        self.beta1 = beta1
+        self.beta2 = beta2
+
+        self.move11 = LearnableBias(inplanes)
+        self.binary_3x3 = conv3x3(inplanes, inplanes, stride=stride)
+
+        self.move12 = LearnableBias(inplanes)
+        self.prelu1 = nn.PReLU(inplanes)
+        self.move13 = LearnableBias(inplanes)
+
+        self.move21 = LearnableBias(inplanes)
+
+        if inplanes == planes:
+            self.binary_pw = conv1x1(inplanes, planes)
+        else:
+            self.binary_pw_down1 = conv1x1(inplanes, inplanes)
+            self.binary_pw_down2 = conv1x1(inplanes, inplanes)
+
+        self.move22 = LearnableBias(planes)
+        self.prelu2 = nn.PReLU(planes)
+        self.move23 = LearnableBias(planes)
+
+        self.binary_activation = BinaryActivation()
+        self.stride = stride
+        self.inplanes = inplanes
+        self.planes = planes
+
+        if self.inplanes != self.planes:
+            self.pooling = nn.AvgPool2d(2, 2)
+
+    def forward(self, x):
+
+        x_in = x * self.beta1
+
+        out1 = self.move11(x_in)
+        out1 = self.binary_activation(out1)
+        out1 = self.binary_3x3(out1)
+
+        if self.stride == 2:
+            x = self.pooling(x_in)
+
+        out1 = x + out1 * self.alpha
+
+        out1 = self.move12(out1)
+        out1 = self.prelu1(out1)
+        out1 = self.move13(out1)
+
+        out1_in = out1 * self.beta2
+
+        out2 = self.move21(out1_in)
+        out2 = self.binary_activation(out2)
+
+        if self.inplanes == self.planes:
+            out2 = self.binary_pw(out2)
+            out2 = out2 * self.alpha + out1
+
+        else:
+            assert self.planes == self.inplanes * 2
+
+            out2_1 = self.binary_pw_down1(out2)
+            out2_2 = self.binary_pw_down2(out2)
+            out2_1 = out2_1 * self.alpha + out1
+            out2_2 = out2_2 * self.alpha + out1
+            out2 = torch.cat([out2_1, out2_2], dim=1)
+
+        out2 = self.move22(out2)
+        out2 = self.prelu2(out2)
+        out2 = self.move23(out2)
+
+        return out2
+
+class reactnet(nn.Module):
+
+    def __init__(self, alpha=0.2, num_classes=1000):
+        super(reactnet, self).__init__()
+
+        self.feature = nn.ModuleList()
+        for i in range(len(stage_out_channel)):
+            if i == 0:
+
+                expected_var = 1.0
+                beta1 = 1. / expected_var ** 0.5
+                expected_var += alpha ** 2
+                beta2 = 1. / expected_var ** 0.5
+
+                self.feature.append(firstconv3x3(3, stage_out_channel[i], 2))
+            elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], alpha, beta1, beta2, 2))
+                # Reset the expected variance at a transition block
+                expected_var = 1.0
+                beta1 = 1. / expected_var ** 0.5
+                expected_var += alpha ** 2
+                beta2 = 1. / expected_var ** 0.5
+
+            else:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], alpha, beta1, beta2, 1))
+
+                expected_var += alpha ** 2
+                beta1 = 1. / expected_var ** 0.5
+                expected_var += alpha ** 2
+                beta2 = 1. / expected_var ** 0.5
+
+        self.pool1 = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(1024, num_classes)
+
+    def forward(self, x):
+        for i, block in enumerate(self.feature):
+            x = block(x)
+
+        x = self.pool1(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
diff --git a/models/Qa_reactnet_A_bn.py b/models/Qa_reactnet_A_bn.py
new file mode 100644
index 0000000..70d2de4
--- /dev/null
+++ b/models/Qa_reactnet_A_bn.py
@@ -0,0 +1,156 @@
+'''
+ReActNet (modified from MobileNetV1)
+
+BN setting: original, all BN layers kept
+Conv setting: original Conv2d
+Binary setting: only activations are binarized
+
+'''
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+import numpy as np
+
+from layers import *
+
+stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class firstconv3x3(nn.Module):
+    def __init__(self, inp, oup, stride):
+        super(firstconv3x3, self).__init__()
+
+        self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(oup)
+
+    def forward(self, x):
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+
+        return out
+
+class BasicBlock(nn.Module):
+    def __init__(self, inplanes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        norm_layer = nn.BatchNorm2d
+
+        self.move11 = LearnableBias(inplanes)
+        self.binary_3x3 = conv3x3(inplanes, inplanes, stride=stride)
+        self.bn1 = norm_layer(inplanes)
+
+        self.move12 = LearnableBias(inplanes)
+        self.prelu1 = nn.PReLU(inplanes)
+        self.move13 = LearnableBias(inplanes)
+
+        self.move21 = LearnableBias(inplanes)
+
+        if inplanes == planes:
+            self.binary_pw = conv1x1(inplanes, planes)
+            self.bn2 = norm_layer(planes)
+        else:
+            self.binary_pw_down1 = conv1x1(inplanes, inplanes)
+            self.binary_pw_down2 = conv1x1(inplanes, inplanes)
+            self.bn2_1 = norm_layer(inplanes)
+            self.bn2_2 = norm_layer(inplanes)
+
+        self.move22 = LearnableBias(planes)
+        self.prelu2 = nn.PReLU(planes)
+        self.move23 = LearnableBias(planes)
+
+        self.binary_activation = BinaryActivation()
+        self.stride = stride
+        self.inplanes = inplanes
+        self.planes = planes
+
+        if self.inplanes != self.planes:
+            self.pooling = nn.AvgPool2d(2, 2)
+
+    def forward(self, x):
+
+        out1 = self.move11(x)
+
+        out1 = self.binary_activation(out1)
+        out1 = self.binary_3x3(out1)
+        out1 = self.bn1(out1)
+
+        if self.stride == 2:
+            x = self.pooling(x)
+
+        out1 = x + out1
+
+        out1 = self.move12(out1)
+        out1 = self.prelu1(out1)
+        out1 = self.move13(out1)
+
+        out2 = self.move21(out1)
+        out2 = self.binary_activation(out2)
+
+        if self.inplanes == self.planes:
+            out2 = self.binary_pw(out2)
+            out2 = self.bn2(out2)
+            out2 += out1
+
+        else:
+            assert self.planes == self.inplanes * 2
+
+            out2_1 = self.binary_pw_down1(out2)
+            out2_2 = self.binary_pw_down2(out2)
+            out2_1 = self.bn2_1(out2_1)
+            out2_2 = self.bn2_2(out2_2)
+            out2_1 += out1
+            out2_2 += out1
+            out2 = torch.cat([out2_1, out2_2], dim=1)
+
+        out2 = self.move22(out2)
+        out2 = self.prelu2(out2)
+        out2 = self.move23(out2)
+
+        return out2
+
+class reactnet(nn.Module):
+    def __init__(self, num_classes=1000):
+        super(reactnet, self).__init__()
+        self.feature = nn.ModuleList()
+        for i in range(len(stage_out_channel)):
+            if i == 0:
+                self.feature.append(firstconv3x3(3, stage_out_channel[i], 2))
+            elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2))
+            else:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1))
+        self.pool1 = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(1024, num_classes)
+
+    def forward(self, x):
+        for i, block in enumerate(self.feature):
+            x = block(x)
+
+        x = self.pool1(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
diff --git a/models/Qa_reactnet_A_none.py b/models/Qa_reactnet_A_none.py
new file mode 100644
index 0000000..50746b9
--- /dev/null
+++ b/models/Qa_reactnet_A_none.py
@@ -0,0 +1,144 @@
+'''
+ReActNet (modified from MobileNetV1)
+
+BN setting: remove all BatchNorm layers
+Conv setting: original Conv2d
+Binary setting: only activations are binarized
+
+'''
+
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+import numpy as np
+
+from layers import *
+
+stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class firstconv3x3(nn.Module):
+    def __init__(self, inp, oup, stride):
+        super(firstconv3x3, self).__init__()
+        self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        return out
+
+class BasicBlock(nn.Module):
+    def __init__(self, inplanes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+
+        self.move11 = LearnableBias(inplanes)
+        self.binary_3x3 = conv3x3(inplanes, inplanes, stride=stride)
+
+        self.move12 = LearnableBias(inplanes)
+        self.prelu1 = nn.PReLU(inplanes)
+        self.move13 = LearnableBias(inplanes)
+
+        self.move21 = LearnableBias(inplanes)
+
+        if inplanes == planes:
+            self.binary_pw = conv1x1(inplanes, planes)
+        else:
+            self.binary_pw_down1 = conv1x1(inplanes, inplanes)
+            self.binary_pw_down2 = conv1x1(inplanes, inplanes)
+
+        self.move22 = LearnableBias(planes)
+        self.prelu2 = nn.PReLU(planes)
+        self.move23 = LearnableBias(planes)
+
+        self.binary_activation = BinaryActivation()
+        self.stride = stride
+        self.inplanes = inplanes
+        self.planes = planes
+
+        if self.inplanes != self.planes:
+            self.pooling = nn.AvgPool2d(2, 2)
+
+    def forward(self, x):
+
+        out1 = self.move11(x)
+
+        out1 = self.binary_activation(out1)
+        out1 = self.binary_3x3(out1)
+
+        if self.stride == 2:
+            x = self.pooling(x)
+
+        out1 = x + out1
+
+        out1 = self.move12(out1)
+        out1 = self.prelu1(out1)
+        out1 = self.move13(out1)
+
+        out2 = self.move21(out1)
+        out2 = self.binary_activation(out2)
+
+        if self.inplanes == self.planes:
+            out2 = self.binary_pw(out2)
+            out2 += out1
+
+        else:
+            assert self.planes == self.inplanes * 2
+
+            out2_1 = self.binary_pw_down1(out2)
+            out2_2 = self.binary_pw_down2(out2)
+            out2_1 += out1
+            out2_2 += out1
+            out2 = torch.cat([out2_1, out2_2], dim=1)
+
+        out2 = self.move22(out2)
+        out2 = self.prelu2(out2)
+        out2 = self.move23(out2)
+
+        return out2
+
+class reactnet(nn.Module):
+    def __init__(self, num_classes=1000):
+        super(reactnet, self).__init__()
+        self.feature = nn.ModuleList()
+        for i in range(len(stage_out_channel)):
+            if i == 0:
+                self.feature.append(firstconv3x3(3, stage_out_channel[i], 2))
+            elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2))
+            else:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1))
+        self.pool1 = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(1024, num_classes)
+
+    def forward(self, x):
+        for i, block in enumerate(self.feature):
+            x = block(x)
+
+        x = self.pool1(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
diff --git a/models/Qaw_reactnet_18_bf.py b/models/Qaw_reactnet_18_bf.py
new file mode 100644
index 0000000..658a6b7
--- /dev/null
+++ b/models/Qaw_reactnet_18_bf.py
@@ -0,0 +1,137 @@
+'''
+ReAct-BiRealNet-18 (modified from ResNet)
+
+BN setting: remove all BatchNorm layers
+Conv setting: replace Conv2d with ScaledStdConv2d (add alpha and beta to each block)
+Binary setting: both activations and weights are binarized
+
+'''
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+from layers import *
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return ScaledStdConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return ScaledStdConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryScaledStdConv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryScaledStdConv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, alpha, beta, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+
+        self.alpha = alpha
+        self.beta = beta
+
+        self.move0 = LearnableBias(inplanes)
+        self.binary_activation = BinaryActivation()
+        self.binary_conv = binaryconv3x3(inplanes, planes, stride=stride)
+        self.move1 = LearnableBias(planes)
+        self.prelu = nn.PReLU(planes)
+        self.move2 = LearnableBias(planes)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+
+        residual = x
+        x_in = x * self.beta
+
+        out = self.move0(x_in)
+        out = self.binary_activation(out)
+        out = self.binary_conv(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x_in)
+
+        out = out * self.alpha + residual
+        out = self.move1(out)
+        out = self.prelu(out)
+        out = self.move2(out)
+
+        return out
+
+class BiRealNet(nn.Module):
+
+    def __init__(self, block, layers, imagenet=True, alpha=0.2, num_classes=1000):
+        super(BiRealNet, self).__init__()
+        self.inplanes = 64
+
+        if imagenet:
+            self.conv1 = ScaledStdConv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        else:
+            self.conv1 = ScaledStdConv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+            self.maxpool = nn.Identity()
+
+        expected_var = 1.0
+        self.layer1, expected_var = self._make_layer(block, 64, layers[0], alpha, expected_var)
+        self.layer2, expected_var = self._make_layer(block, 128, layers[1], alpha, expected_var, stride=2)
+        self.layer3, expected_var = self._make_layer(block, 256, layers[2], alpha, expected_var, stride=2)
+        self.layer4, expected_var = self._make_layer(block, 512, layers[3], alpha, expected_var, stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, alpha, expected_var, stride=1):
+
+        beta = 1. / expected_var ** 0.5
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.AvgPool2d(kernel_size=2, stride=stride),
+                binaryconv1x1(self.inplanes, planes * block.expansion)
+            )
+            # Reset the expected variance at a transition block
+            expected_var = 1.0
+
+        layers = []
+        layers.append(block(self.inplanes, planes, alpha, beta, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            beta = 1. / expected_var ** 0.5
+            layers.append(block(self.inplanes, planes, alpha, beta))
+            expected_var += alpha ** 2
+
+        return nn.Sequential(*layers), expected_var
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def birealnet18(pretrained=False, **kwargs):
+    """Constructs a BiRealNet-18 model. """
+    model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)
+    return model
+
+
diff --git a/models/Qaw_reactnet_18_bn.py b/models/Qaw_reactnet_18_bn.py
new file mode 100644
index 0000000..ed1c854
--- /dev/null
+++ b/models/Qaw_reactnet_18_bn.py
@@ -0,0 +1,131 @@
+'''
+ReAct-BiRealNet-18 (modified from ResNet)
+
+BN setting: original, all BN layers kept
+Conv setting: original Conv2d
+Binary setting: both activations and weights are binarized
+
+'''
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+from layers import *
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+
+        self.move0 = LearnableBias(inplanes)
+        self.binary_activation = BinaryActivation()
+        self.binary_conv = binaryconv3x3(inplanes, planes, stride=stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.move1 = LearnableBias(planes)
+        self.prelu = nn.PReLU(planes)
+        self.move2 = LearnableBias(planes)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.move0(x)
+        out = self.binary_activation(out)
+        out = self.binary_conv(out)
+        out = self.bn1(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.move1(out)
+        out = self.prelu(out)
+        out = self.move2(out)
+
+        return out
+
+class BiRealNet(nn.Module):
+
+    def __init__(self, block, layers, imagenet=True, num_classes=1000):
+        super(BiRealNet, self).__init__()
+        self.inplanes = 64
+
+        if imagenet:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+            self.bn1 = nn.BatchNorm2d(64)
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        else:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+            self.bn1 = nn.BatchNorm2d(64)
+            self.maxpool = nn.Identity()
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.AvgPool2d(kernel_size=2, stride=stride),
+                binaryconv1x1(self.inplanes, planes * block.expansion),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def birealnet18(pretrained=False, **kwargs):
+    """Constructs a BiRealNet-18 model. """
+    model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)
+    return model
+
+
diff --git a/models/Qaw_reactnet_18_none.py b/models/Qaw_reactnet_18_none.py
new file mode 100644
index 0000000..e34920c
--- /dev/null
+++ b/models/Qaw_reactnet_18_none.py
@@ -0,0 +1,124 @@
+'''
+ReAct-BiRealNet-18 (modified from ResNet)
+
+BN setting: remove all BatchNorm layers
+Conv setting: original Conv2d
+Binary setting: both activations and weights are binarized
+
+'''
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+
+from layers import *
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+
+        self.move0 = LearnableBias(inplanes)
+        self.binary_activation = BinaryActivation()
+        self.binary_conv = binaryconv3x3(inplanes, planes, stride=stride)
+        self.move1 = LearnableBias(planes)
+        self.prelu = nn.PReLU(planes)
+        self.move2 = LearnableBias(planes)
+
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.move0(x)
+        out = self.binary_activation(out)
+        out = self.binary_conv(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.move1(out)
+        out = self.prelu(out)
+        out = self.move2(out)
+
+        return out
+
+class BiRealNet(nn.Module):
+
+    def __init__(self, block, layers, imagenet=True, num_classes=1000):
+        super(BiRealNet, self).__init__()
+        self.inplanes = 64
+        if imagenet:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        else:
+            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+            self.maxpool = nn.Identity()
+
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.AvgPool2d(kernel_size=2, stride=stride),
+                binaryconv1x1(self.inplanes, planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+
+def birealnet18(pretrained=False, **kwargs):
+    """Constructs a BiRealNet-18 model. """
+    model = BiRealNet(BasicBlock, [4, 4, 4, 4], **kwargs)
+    return model
+
+
self.binary_activation(out2) + + if self.inplanes == self.planes: + out2 = self.binary_pw(out2) + out2 = out2*self.alpha + out1 + + else: + assert self.planes == self.inplanes * 2 + + out2_1 = self.binary_pw_down1(out2) + out2_2 = self.binary_pw_down2(out2) + out2_1 = out2_1*self.alpha + out1 + out2_2 = out2_2*self.alpha + out1 + out2 = torch.cat([out2_1, out2_2], dim=1) + + out2 = self.move22(out2) + out2 = self.prelu2(out2) + out2 = self.move23(out2) + + return out2 + +class reactnet(nn.Module): + + def __init__(self, alpha=0.2, num_classes=1000): + super(reactnet, self).__init__() + + self.feature = nn.ModuleList() + for i in range(len(stage_out_channel)): + if i == 0: + + expected_var = 1.0 + beta1 = 1. / expected_var ** 0.5 + expected_var += alpha ** 2 + beta2 = 1. / expected_var ** 0.5 + + self.feature.append(firstconv3x3(3, stage_out_channel[i], 2)) + elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], alpha, beta1, beta2, 2)) + # Reset expected var at a transition block + expected_var = 1.0 + beta1 = 1. / expected_var ** 0.5 + expected_var += alpha ** 2 + beta2 = 1. / expected_var ** 0.5 + + else: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], alpha, beta1, beta2, 1)) + + expected_var += alpha ** 2 + beta1 = 1. / expected_var ** 0.5 + expected_var += alpha ** 2 + beta2 = 1. / expected_var ** 0.5 + + self.pool1 = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(1024, num_classes) + + def forward(self, x): + for i, block in enumerate(self.feature): + x = block(x) + + x = self.pool1(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + + + + + diff --git a/models/Qaw_reactnet_A_bn.py b/models/Qaw_reactnet_A_bn.py new file mode 100644 index 0000000..a24b255 --- /dev/null +++ b/models/Qaw_reactnet_A_bn.py @@ -0,0 +1,157 @@ +''' +ReActNet(modified from MobileNetv1) + +BN setting: original all BN +Conv setting: original Conv2d +Binary setting: both activation and weight are binarized + +''' + + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F +import numpy as np + +from layers import * + +stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2 + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +def binaryconv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) + +def binaryconv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0) + +class firstconv3x3(nn.Module): + def __init__(self, inp, oup, stride): + super(firstconv3x3, self).__init__() + + self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False) + self.bn1 = nn.BatchNorm2d(oup) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + + return out + +class BasicBlock(nn.Module): + def __init__(self, inplanes, planes, stride=1): + super(BasicBlock, self).__init__() + norm_layer = nn.BatchNorm2d + + self.move11 = LearnableBias(inplanes) + self.binary_3x3= binaryconv3x3(inplanes, inplanes, 
stride=stride) + self.bn1 = norm_layer(inplanes) + + self.move12 = LearnableBias(inplanes) + self.prelu1 = nn.PReLU(inplanes) + self.move13 = LearnableBias(inplanes) + + self.move21 = LearnableBias(inplanes) + + if inplanes == planes: + self.binary_pw = binaryconv1x1(inplanes, planes) + self.bn2 = norm_layer(planes) + else: + self.binary_pw_down1 = binaryconv1x1(inplanes, inplanes) + self.binary_pw_down2 = binaryconv1x1(inplanes, inplanes) + self.bn2_1 = norm_layer(inplanes) + self.bn2_2 = norm_layer(inplanes) + + self.move22 = LearnableBias(planes) + self.prelu2 = nn.PReLU(planes) + self.move23 = LearnableBias(planes) + + self.binary_activation = BinaryActivation() + self.stride = stride + self.inplanes = inplanes + self.planes = planes + + if self.inplanes != self.planes: + self.pooling = nn.AvgPool2d(2,2) + + def forward(self, x): + + out1 = self.move11(x) + + out1 = self.binary_activation(out1) + out1 = self.binary_3x3(out1) + out1 = self.bn1(out1) + + if self.stride == 2: + x = self.pooling(x) + + out1 = x + out1 + + out1 = self.move12(out1) + out1 = self.prelu1(out1) + out1 = self.move13(out1) + + out2 = self.move21(out1) + out2 = self.binary_activation(out2) + + if self.inplanes == self.planes: + out2 = self.binary_pw(out2) + out2 = self.bn2(out2) + out2 += out1 + + else: + assert self.planes == self.inplanes * 2 + + out2_1 = self.binary_pw_down1(out2) + out2_2 = self.binary_pw_down2(out2) + out2_1 = self.bn2_1(out2_1) + out2_2 = self.bn2_2(out2_2) + out2_1 += out1 + out2_2 += out1 + out2 = torch.cat([out2_1, out2_2], dim=1) + + out2 = self.move22(out2) + out2 = self.prelu2(out2) + out2 = self.move23(out2) + + return out2 + +class reactnet(nn.Module): + def __init__(self, num_classes=1000): + super(reactnet, self).__init__() + self.feature = nn.ModuleList() + for i in range(len(stage_out_channel)): + if i == 0: + self.feature.append(firstconv3x3(3, stage_out_channel[i], 2)) + elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2)) + else: + self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1)) + self.pool1 = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(1024, num_classes) + + def forward(self, x): + for i, block in enumerate(self.feature): + x = block(x) + + x = self.pool1(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + + + + + diff --git a/models/Qaw_reactnet_A_none.py b/models/Qaw_reactnet_A_none.py new file mode 100644 index 0000000..0aef536 --- /dev/null +++ b/models/Qaw_reactnet_A_none.py @@ -0,0 +1,143 @@ +''' +ReActNet(modified from MobileNetv1) + +BN setting: remove all BatchNorm layers +Conv setting: original Conv2d +Binary setting: both activation and weight are binarized + +''' + + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F +import numpy as np + +from layers import * + +stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2 + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +def binaryconv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, 
diff --git a/models/Qaw_reactnet_A_none.py b/models/Qaw_reactnet_A_none.py
new file mode 100644
index 0000000..0aef536
--- /dev/null
+++ b/models/Qaw_reactnet_A_none.py
@@ -0,0 +1,143 @@
+'''
+ReActNet (modified from MobileNetV1)
+
+BN setting: all BatchNorm layers removed
+Conv setting: original Conv2d
+Binary setting: both activations and weights are binarized
+'''
+
+
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+import torch.nn.functional as F
+import numpy as np
+
+from layers import *
+
+stage_out_channel = [32] + [64] + [128] * 2 + [256] * 2 + [512] * 6 + [1024] * 2
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+def binaryconv3x3(in_planes, out_planes, stride=1):
+    """3x3 binary convolution with padding"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
+
+def binaryconv1x1(in_planes, out_planes, stride=1):
+    """1x1 binary convolution"""
+    return HardBinaryConv(in_planes, out_planes, kernel_size=1, stride=stride, padding=0)
+
+class firstconv3x3(nn.Module):
+    def __init__(self, inp, oup, stride):
+        super(firstconv3x3, self).__init__()
+        self.conv1 = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        return out
+
+class BasicBlock(nn.Module):
+    def __init__(self, inplanes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+
+        self.move11 = LearnableBias(inplanes)
+        self.binary_3x3 = binaryconv3x3(inplanes, inplanes, stride=stride)
+
+        self.move12 = LearnableBias(inplanes)
+        self.prelu1 = nn.PReLU(inplanes)
+        self.move13 = LearnableBias(inplanes)
+
+        self.move21 = LearnableBias(inplanes)
+
+        if inplanes == planes:
+            self.binary_pw = binaryconv1x1(inplanes, planes)
+        else:
+            self.binary_pw_down1 = binaryconv1x1(inplanes, inplanes)
+            self.binary_pw_down2 = binaryconv1x1(inplanes, inplanes)
+
+        self.move22 = LearnableBias(planes)
+        self.prelu2 = nn.PReLU(planes)
+        self.move23 = LearnableBias(planes)
+
+        self.binary_activation = BinaryActivation()
+        self.stride = stride
+        self.inplanes = inplanes
+        self.planes = planes
+
+        if self.inplanes != self.planes:
+            self.pooling = nn.AvgPool2d(2, 2)
+
+    def forward(self, x):
+
+        out1 = self.move11(x)
+
+        out1 = self.binary_activation(out1)
+        out1 = self.binary_3x3(out1)
+
+        if self.stride == 2:
+            x = self.pooling(x)
+
+        out1 = x + out1
+
+        out1 = self.move12(out1)
+        out1 = self.prelu1(out1)
+        out1 = self.move13(out1)
+
+        out2 = self.move21(out1)
+        out2 = self.binary_activation(out2)
+
+        if self.inplanes == self.planes:
+            out2 = self.binary_pw(out2)
+            out2 += out1
+
+        else:
+            assert self.planes == self.inplanes * 2
+
+            out2_1 = self.binary_pw_down1(out2)
+            out2_2 = self.binary_pw_down2(out2)
+            out2_1 += out1
+            out2_2 += out1
+            out2 = torch.cat([out2_1, out2_2], dim=1)
+
+        out2 = self.move22(out2)
+        out2 = self.prelu2(out2)
+        out2 = self.move23(out2)
+
+        return out2
+
+class reactnet(nn.Module):
+    def __init__(self, num_classes=1000):
+        super(reactnet, self).__init__()
+        self.feature = nn.ModuleList()
+        for i in range(len(stage_out_channel)):
+            if i == 0:
+                self.feature.append(firstconv3x3(3, stage_out_channel[i], 2))
+            elif stage_out_channel[i-1] != stage_out_channel[i] and stage_out_channel[i] != 64:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 2))
+            else:
+                self.feature.append(BasicBlock(stage_out_channel[i-1], stage_out_channel[i], 1))
+        self.pool1 = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(1024, num_classes)
+
+    def forward(self, x):
+        for i, block in enumerate(self.feature):
+            x = block(x)
+
+        x = self.pool1(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
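All three model variants import `LearnableBias`, `BinaryActivation`, and `HardBinaryConv` from `layers.py`, which is not part of this diff. For orientation, a minimal straight-through sign activation is sketched below; the repository's `BinaryActivation` (ReActNet shapes the backward pass with a piecewise polynomial) may differ in its gradient surrogate, so treat this only as the general pattern:

```
import torch

class SignSTE(torch.autograd.Function):
    """Forward: sign(x). Backward: pass gradients through where |x| <= 1."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * (x.abs() <= 1).float()

x = torch.randn(4, requires_grad=True)
SignSTE.apply(x).sum().backward()
print(x.grad)   # zero outside [-1, 1], identity inside
```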
diff --git a/script/imagenet_reactnet_A_bf.sh b/script/imagenet_reactnet_A_bf.sh
new file mode 100644
index 0000000..cf84f1b
--- /dev/null
+++ b/script/imagenet_reactnet_A_bf.sh
@@ -0,0 +1,47 @@
+# Stage 1: binarize activations only
+DATADIR=data/imagenet
+SAVEDIR1=ReActNet_Qa_bf
+SAVEDIR2=ReActNet_Qaw_bf
+BS=256
+Epoch=256
+python -u train.py \
+    --data ${DATADIR} \
+    --save ${SAVEDIR1} \
+    --dataset imagenet \
+    --batch_size ${BS} \
+    --arch reactnet-A \
+    --bn_type bf \
+    --loss_type kd \
+    --teacher dm_nfnet_f0 \
+    --learning_rate 5e-4 \
+    --epochs ${Epoch} \
+    --weight_decay 1e-5 \
+    --agc \
+    --clip_value 0.02
+
+
+# Stage 2: binarize activations and weights, warm-started from stage 1
+python -u train.py \
+    --data ${DATADIR} \
+    --save ${SAVEDIR2} \
+    --dataset imagenet \
+    --batch_size ${BS} \
+    --arch reactnet-A \
+    --bn_type bf \
+    --binary_w \
+    --pretrained ${SAVEDIR1}/model_best.pth.tar \
+    --loss_type kd \
+    --teacher dm_nfnet_f0 \
+    --learning_rate 5e-4 \
+    --epochs ${Epoch} \
+    --weight_decay 0 \
+    --agc \
+    --clip_value 0.02
+
diff --git a/script/imagenet_reactnet_A_bn.sh b/script/imagenet_reactnet_A_bn.sh
new file mode 100644
index 0000000..579f22e
--- /dev/null
+++ b/script/imagenet_reactnet_A_bn.sh
@@ -0,0 +1,47 @@
+# Stage 1: binarize activations only
+DATADIR=data/imagenet
+SAVEDIR1=ReActNet_Qa_bn
+SAVEDIR2=ReActNet_Qaw_bn
+BS=256
+Epoch=256
+python -u train.py \
+    --data ${DATADIR} \
+    --save ${SAVEDIR1} \
+    --dataset imagenet \
+    --batch_size ${BS} \
+    --arch reactnet-A \
+    --bn_type bn \
+    --loss_type kd \
+    --teacher dm_nfnet_f0 \
+    --learning_rate 5e-4 \
+    --epochs ${Epoch} \
+    --weight_decay 1e-5 \
+    --agc \
+    --clip_value 0.02
+
+
+# Stage 2: binarize activations and weights, warm-started from stage 1
+python -u train.py \
+    --data ${DATADIR} \
+    --save ${SAVEDIR2} \
+    --dataset imagenet \
+    --batch_size ${BS} \
+    --arch reactnet-A \
+    --bn_type bn \
+    --binary_w \
+    --pretrained ${SAVEDIR1}/model_best.pth.tar \
+    --loss_type kd \
+    --teacher dm_nfnet_f0 \
+    --learning_rate 5e-4 \
+    --epochs ${Epoch} \
+    --weight_decay 0 \
+    --agc \
+    --clip_value 0.02
+
diff --git a/script/imagenet_reactnet_A_none.sh b/script/imagenet_reactnet_A_none.sh
new file mode 100644
index 0000000..55c15c5
--- /dev/null
+++ b/script/imagenet_reactnet_A_none.sh
@@ -0,0 +1,47 @@
+# Stage 1: binarize activations only
+DATADIR=data/imagenet
+SAVEDIR1=ReActNet_Qa_none
+SAVEDIR2=ReActNet_Qaw_none
+BS=256
+Epoch=256
+python -u train.py \
+    --data ${DATADIR} \
+    --save ${SAVEDIR1} \
+    --dataset imagenet \
+    --batch_size ${BS} \
+    --arch reactnet-A \
+    --bn_type none \
+    --loss_type kd \
+    --teacher dm_nfnet_f0 \
+    --learning_rate 5e-4 \
+    --epochs ${Epoch} \
+    --weight_decay 1e-5 \
+    --agc \
+    --clip_value 0.02
+
+
+# Stage 2: binarize activations and weights, warm-started from stage 1
+python -u train.py \
+    --data ${DATADIR} \
+    --save ${SAVEDIR2} \
+    --dataset imagenet \
+    --batch_size ${BS} \
+    --arch reactnet-A \
+    --bn_type none \
+    --binary_w \
+    --pretrained ${SAVEDIR1}/model_best.pth.tar \
+    --loss_type kd \
+    --teacher dm_nfnet_f0 \
+    --learning_rate 5e-4 \
+    --epochs ${Epoch} \
+    --weight_decay 0 \
+    --agc \
+    --clip_value 0.02
+
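Each script runs the two stages back to back: stage 1 binarizes activations only, while stage 2 additionally binarizes weights, warm-starts from `${SAVEDIR1}/model_best.pth.tar`, and drops weight decay to 0. Before launching stage 2 by hand, the hand-off can be sanity-checked against the checkpoint layout written by `save_checkpoint` in utils.py (the path below is the stage-1 default from the BN-free script):

```
import torch

ckpt = torch.load('ReActNet_Qa_bf/model_best.pth.tar', map_location='cpu')
print(ckpt['epoch'], ckpt['best_top1_acc'])       # stage-1 epoch and best top-1
print(len(ckpt['state_dict']), 'tensors in state_dict')
```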
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..bca4f4d
--- /dev/null
+++ b/train.py
@@ -0,0 +1,343 @@
+import os
+import sys
+import shutil
+import numpy as np
+import time, datetime
+import torch
+import random
+import logging
+import argparse
+import torch.nn as nn
+import torch.utils
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.utils.data.distributed
+
+import timm
+from utils import *
+from torchvision import datasets, transforms
+from torch.autograd import Variable
+import torchvision.models as models
+from agc import adaptive_clip_grad
+
+parser = argparse.ArgumentParser("normalizer-free BNN")
+################################# basic settings ######################################
+parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
+parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
+parser.add_argument('--dataset', type=str, default='imagenet', help='dataset')
+parser.add_argument('--batch_size', type=int, default=512, help='batch size')
+parser.add_argument('--arch', type=str, default='reactnet', help='architecture')
+parser.add_argument('--bn_type', type=str, default='bn', help='normalization variant: [bn, none, bf]')
+parser.add_argument('--binary_w', action="store_true", help="whether to binarize weights")
+parser.add_argument('--resume', action="store_true", help="whether to resume training")
+parser.add_argument('--pretrained', type=str, default=None, help='pretrained weight')
+parser.add_argument('--loss_type', type=str, default='kd', help='[kd, ce, ls]')
+parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
+parser.add_argument('--teacher', type=str, default='resnet34', help='teacher model architecture for KD')
+parser.add_argument('--teacher_weight', type=str, default=None, help='pretrained teacher weight')
+################################# training settings ######################################
+parser.add_argument('--epochs', type=int, default=120, help='num of training epochs')
+parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate')
+parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
+parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
+parser.add_argument('--agc', action="store_true", help="whether to use adaptive gradient clipping (AGC)")
+parser.add_argument('--clip_value', type=float, default=0.04, help='lambda for AGC')
+################################# other settings ######################################
+parser.add_argument('-j', '--workers', default=40, type=int, metavar='N', help='number of data loading workers (default: 40)')
+
+args = parser.parse_args()
+
+def main():
+    global args
+    print(args)
+
+    if not torch.cuda.is_available():
+        sys.exit(1)
+    start_t = time.time()
+
+    cudnn.benchmark = True
+    cudnn.enabled = True
+
+    train_loader, val_loader, model_student, CLASSES = setup_model_dataloader(args)
+    model_student = nn.DataParallel(model_student).cuda()
+    print(model_student)
+
+    # load teacher model
+    if args.loss_type == 'kd':
+        print('* Loading teacher model')
+        if 'nfnet' not in args.teacher:
+            model_teacher = models.__dict__[args.teacher](pretrained=True)
+            classes_in_teacher = model_teacher.fc.out_features
+            num_features = model_teacher.fc.in_features
+        else:
+            model_teacher = timm.create_model(args.teacher, pretrained=True)
+            classes_in_teacher = model_teacher.head.fc.out_features
+            num_features = model_teacher.head.fc.in_features
+
+        if not classes_in_teacher == CLASSES:
+            print('* change fc layers in teacher')
+            if 'nfnet' not in args.teacher:
+                model_teacher.fc = nn.Linear(num_features, CLASSES)
+            else:
+                model_teacher.head.fc = nn.Linear(num_features, CLASSES)
+            print('* loading pretrained teacher weight from {}'.format(args.teacher_weight))
+            pretrain_teacher = torch.load(args.teacher_weight, map_location='cpu')['state_dict']
+            model_teacher.load_state_dict(pretrain_teacher)
+
+        model_teacher = nn.DataParallel(model_teacher).cuda()
+        for p in model_teacher.parameters():
+            p.requires_grad = False
+        model_teacher.eval()
+
+    # criterion
+    criterion = nn.CrossEntropyLoss().cuda()
+    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth).cuda()
+    criterion_kd = DistributionLoss()
+
+    # optimizer: weight decay is applied only to the conv/weight tensors
+    all_parameters = model_student.parameters()
+    weight_parameters = []
+    for pname, p in model_student.named_parameters():
+        if p.ndimension() == 4 or 'conv' in pname:
+            weight_parameters.append(p)
+    weight_parameters_id = list(map(id, weight_parameters))
+    other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
+    optimizer = torch.optim.Adam(
+        [{'params': other_parameters},
+         {'params': weight_parameters, 'weight_decay': args.weight_decay}],
+        lr=args.learning_rate,)
+
+    # linear decay from the initial LR to zero, stepped once per epoch
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (1.0 - step / args.epochs), last_epoch=-1)
+    start_epoch = 0
+    best_top1_acc = 0
+
+    if args.pretrained:
+        print('* loading pretrained weight {}'.format(args.pretrained))
+        pretrain_student = torch.load(args.pretrained)
+        if 'state_dict' in pretrain_student.keys():
+            pretrain_student = pretrain_student['state_dict']
+
+        for key in pretrain_student.keys():
+            if not key in model_student.state_dict().keys():
+                print('unload key: {}'.format(key))
+
+        model_student.load_state_dict(pretrain_student, strict=False)
+
+    if args.resume:
+        checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
+        if os.path.exists(checkpoint_tar):
+            print('loading checkpoint {} ..........'.format(checkpoint_tar))
+            checkpoint = torch.load(checkpoint_tar)
+            start_epoch = checkpoint['epoch']
+            best_top1_acc = checkpoint['best_top1_acc']
+            model_student.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            scheduler.load_state_dict(checkpoint['scheduler'])
+            print("loaded checkpoint {} epoch = {}".format(checkpoint_tar, checkpoint['epoch']))
+        else:
+            raise ValueError('no checkpoint for resume')
+
+    if args.loss_type == 'kd':
+        if not classes_in_teacher == CLASSES:
+            validate('teacher', val_loader, model_teacher, criterion, args)
+
+    # train the model
+    epoch = start_epoch
+    while epoch < args.epochs:
+
+        if args.loss_type == 'kd':
+            train_obj, train_top1_acc, train_top5_acc = train_kd(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler)
+        elif args.loss_type == 'ce':
+            train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, criterion, optimizer, scheduler)
+        elif args.loss_type == 'ls':
+            train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, criterion_smooth, optimizer, scheduler)
+        else:
+            raise ValueError('unsupported loss_type')
+
+        valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args)
+
+        is_best = False
+        if valid_top1_acc > best_top1_acc:
+            best_top1_acc = valid_top1_acc
+            is_best = True
+
+        save_checkpoint({
+            'epoch': epoch,
+            'state_dict': model_student.state_dict(),
+            'best_top1_acc': best_top1_acc,
+            'optimizer': optimizer.state_dict(),
+            'scheduler': scheduler.state_dict(),
+        }, is_best, args.save)
+
+        epoch += 1
+
+    training_time = (time.time() - start_t) / 3600
+    print('total training time = {:.2f} hours'.format(training_time))
+    print('* best acc = {}'.format(best_top1_acc))
+
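With `LambdaLR(optimizer, lambda step: 1.0 - step / args.epochs)` stepped once per epoch, the learning rate decays linearly from `--learning_rate` to zero. A standalone sketch of the schedule (illustrative values matching the scripts, base LR 5e-4 over 256 epochs; exact values shift by one step because the scheduler is stepped at the start of each epoch):

```
base_lr, epochs = 5e-4, 256
for epoch in (0, 64, 128, 255):
    print('epoch %3d: lr ~ %.2e' % (epoch, base_lr * (1.0 - epoch / epochs)))
```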
+def train_kd(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+
+    progress = ProgressMeter(
+        len(train_loader),
+        [batch_time, data_time, losses, top1, top5],
+        prefix="Epoch: [{}]".format(epoch))
+
+    model_student.train()
+    model_teacher.eval()
+    end = time.time()
+    # advance the per-epoch linear LR decay once at the start of every epoch
+    scheduler.step()
+
+    for param_group in optimizer.param_groups:
+        cur_lr = param_group['lr']
+    print('learning_rate:', cur_lr)
+
+    for i, (images, target) in enumerate(train_loader):
+        data_time.update(time.time() - end)
+        images = images.cuda()
+        target = target.cuda()
+
+        # compute output
+        logits_student = model_student(images)
+        logits_teacher = model_teacher(images)
+        loss = criterion(logits_student, logits_teacher)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(logits_student, target, topk=(1, 5))
+        n = images.size(0)
+        losses.update(loss.item(), n)   # running average of the loss
+        top1.update(prec1.item(), n)
+        top5.update(prec5.item(), n)
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+
+        # clip gradients if requested (AGC is applied to everything but the classifier)
+        if args.agc:
+            parameters_list = []
+            for name, p in model_student.named_parameters():
+                if not 'fc' in name:
+                    parameters_list.append(p)
+            adaptive_clip_grad(parameters_list, clip_factor=args.clip_value)
+
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % 50 == 0:
+            progress.display(i)
+
+    return losses.avg, top1.avg, top5.avg
+
+def train(epoch, train_loader, model_student, criterion, optimizer, scheduler):
+    batch_time = AverageMeter('Time', ':6.3f')
+    data_time = AverageMeter('Data', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+
+    progress = ProgressMeter(
+        len(train_loader),
+        [batch_time, data_time, losses, top1, top5],
+        prefix="Epoch: [{}]".format(epoch))
+
+    model_student.train()
+    end = time.time()
+    # advance the per-epoch linear LR decay once at the start of every epoch
+    scheduler.step()
+
+    for param_group in optimizer.param_groups:
+        cur_lr = param_group['lr']
+    print('learning_rate:', cur_lr)
+
+    for i, (images, target) in enumerate(train_loader):
+        data_time.update(time.time() - end)
+        images = images.cuda()
+        target = target.cuda()
+
+        # compute output
+        logits_student = model_student(images)
+        loss = criterion(logits_student, target)
+
+        # measure accuracy and record loss
+        prec1, prec5 = accuracy(logits_student, target, topk=(1, 5))
+        n = images.size(0)
+        losses.update(loss.item(), n)   # running average of the loss
+        top1.update(prec1.item(), n)
+        top5.update(prec5.item(), n)
+
+        # compute gradient and do SGD step
+        optimizer.zero_grad()
+        loss.backward()
+
+        # clip gradients if requested (AGC is applied to everything but the classifier)
+        if args.agc:
+            parameters_list = []
+            for name, p in model_student.named_parameters():
+                if not 'fc' in name:
+                    parameters_list.append(p)
+            adaptive_clip_grad(parameters_list, clip_factor=args.clip_value)
+
+        optimizer.step()
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % 50 == 0:
+            progress.display(i)
+
+    return losses.avg, top1.avg, top5.avg
+
+def validate(epoch, val_loader, model, criterion, args):
+    batch_time = AverageMeter('Time', ':6.3f')
+    losses = AverageMeter('Loss', ':.4e')
+    top1 = AverageMeter('Acc@1', ':6.2f')
+    top5 = AverageMeter('Acc@5', ':6.2f')
+    progress = ProgressMeter(
+        len(val_loader),
+        [batch_time, losses, top1, top5],
+        prefix='Test: ')
+
+    # switch to evaluation mode
+    model.eval()
+    with torch.no_grad():
+        end = time.time()
+        for i, (images, target) in enumerate(val_loader):
+            images = images.cuda()
+            target = target.cuda()
+
+            # compute output
+            logits = model(images)
+            loss = criterion(logits, target)
+
+            # measure accuracy and record loss
+            prec1, prec5 = accuracy(logits, target, topk=(1, 5))
+            n = images.size(0)
+            losses.update(loss.item(), n)
+            top1.update(prec1.item(), n)
+            top5.update(prec5.item(), n)
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if i % 50 == 0:
+                progress.display(i)
+
+        print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'
+              .format(top1=top1, top5=top5))
+
+    return losses.avg, top1.avg, top5.avg
+
+if __name__ == '__main__':
+    main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..559fd5d
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,284 @@
+import os
+import sys
+import shutil
+import numpy as np
+import time, datetime
+import torch
+import random
+import logging
+import argparse
+import torch.nn as nn
+import torch.utils
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+import torch.backends.cudnn as cudnn
+from torch.nn import functional as F
+from torch.nn.modules import loss
+
+from PIL import Image
+from torch.autograd import Variable
+from dataset import *
+
+# All models
+# reactnet-18
+from models.Qa_reactnet_18_bn import birealnet18 as Qa_reactnet_18_bn
+from models.Qa_reactnet_18_none import birealnet18 as Qa_reactnet_18_none
+from models.Qa_reactnet_18_bf import birealnet18 as Qa_reactnet_18_bf
+
+from models.Qaw_reactnet_18_bn import birealnet18 as Qaw_reactnet_18_bn
+from models.Qaw_reactnet_18_none import birealnet18 as Qaw_reactnet_18_none
+from models.Qaw_reactnet_18_bf import birealnet18 as Qaw_reactnet_18_bf
+
+# reactnet-A
+from models.Qa_reactnet_A_bn import reactnet as Qa_reactnet_A_bn
+from models.Qa_reactnet_A_none import reactnet as Qa_reactnet_A_none
+from models.Qa_reactnet_A_bf import reactnet as Qa_reactnet_A_bf
+
+from models.Qaw_reactnet_A_bn import reactnet as Qaw_reactnet_A_bn
+from models.Qaw_reactnet_A_none import reactnet as Qaw_reactnet_A_none
+from models.Qaw_reactnet_A_bf import reactnet as Qaw_reactnet_A_bf
+
+
+# lighting data augmentation
+imagenet_pca = {
+    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
+    'eigvec': np.asarray([
+        [-0.5675, 0.7192, 0.4009],
+        [-0.5808, -0.0045, -0.8140],
+        [-0.5836, -0.6948, 0.4203],
+    ])
+}
+
+class Lighting(object):
+    def __init__(self, alphastd,
+                 eigval=imagenet_pca['eigval'],
+                 eigvec=imagenet_pca['eigvec']):
+        self.alphastd = alphastd
+        assert eigval.shape == (3,)
+        assert eigvec.shape == (3, 3)
+        self.eigval = eigval
+        self.eigvec = eigvec
+
+    def __call__(self, img):
+        if self.alphastd == 0.:
+            return img
+        rnd = np.random.randn(3) * self.alphastd
+        rnd = rnd.astype('float32')
+        v = rnd
+        old_dtype = np.asarray(img).dtype
+        v = v * self.eigval
+        v = v.reshape((3, 1))
+        inc = np.dot(self.eigvec, v).reshape((3,))
+        img = np.add(img, inc)
+        if old_dtype == np.uint8:
+            img = np.clip(img, 0, 255)
+        img = Image.fromarray(img.astype(old_dtype), 'RGB')
+        return img
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+# label smoothing
+class CrossEntropyLabelSmooth(nn.Module):
+
+    def __init__(self, num_classes, epsilon):
+        super(CrossEntropyLabelSmooth, self).__init__()
+        self.num_classes = num_classes
+        self.epsilon = epsilon
+        self.logsoftmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, inputs, targets):
+        log_probs = self.logsoftmax(inputs)
+        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
+        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
+        loss = (-targets * log_probs).mean(0).sum()
+        return loss
+
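`CrossEntropyLabelSmooth` replaces the one-hot target with a mixture of the one-hot vector and a uniform distribution: the true class receives `1 - epsilon + epsilon / num_classes` and every other class `epsilon / num_classes`. A small numeric check (illustrative values: 10 classes, `epsilon = 0.1` matching the `--label_smooth` default):

```
import torch

epsilon, K = 0.1, 10
targets = torch.zeros(1, K).scatter_(1, torch.tensor([[3]]), 1)
smoothed = (1 - epsilon) * targets + epsilon / K
print(smoothed)   # 0.91 at index 3, 0.01 everywhere else
```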
+class DistributionLoss(loss._Loss):
+    """Distillation loss between the binary student and the real-valued teacher.
+
+    Both model_output and real_output are N x C pre-softmax logits; the teacher
+    logits are turned into probabilities with a softmax inside this loss."""
+
+    def forward(self, model_output, real_output):
+
+        self.size_average = True
+
+        # Target labels are ignored at training time. The loss is the cross entropy
+        # between the softened teacher distribution and the student log-probabilities.
+        if real_output.requires_grad:
+            raise ValueError("real network output should not require gradients.")
+
+        model_output_log_prob = F.log_softmax(model_output, dim=1)
+        real_output_soft = F.softmax(real_output, dim=1)
+        del model_output, real_output
+
+        # Loss is -dot(model_output_log_prob, real_output). Prepare tensors
+        # for batch matrix multiplication.
+        real_output_soft = real_output_soft.unsqueeze(1)
+        model_output_log_prob = model_output_log_prob.unsqueeze(2)
+
+        # Compute the loss, and average/sum over the batch.
+        cross_entropy_loss = -torch.bmm(real_output_soft, model_output_log_prob)
+        if self.size_average:
+            cross_entropy_loss = cross_entropy_loss.mean()
+        else:
+            cross_entropy_loss = cross_entropy_loss.sum()
+
+        return cross_entropy_loss
+
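`DistributionLoss` is the soft cross entropy between the teacher's softmax and the student's log-softmax; the `torch.bmm` of `(N, 1, C)` with `(N, C, 1)` is just a batched dot product. An equivalent, easier-to-read formulation (standalone check with random logits):

```
import torch
import torch.nn.functional as F

student = torch.randn(4, 10)   # student logits
teacher = torch.randn(4, 10)   # teacher logits
loss = -(F.softmax(teacher, dim=1) * F.log_softmax(student, dim=1)).sum(dim=1).mean()
print(loss)
```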
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+def save_checkpoint(state, is_best, save):
+    if not os.path.exists(save):
+        os.makedirs(save)
+    filename = os.path.join(save, 'checkpoint.pth.tar')
+    torch.save(state, filename)
+    if is_best:
+        best_filename = os.path.join(save, 'model_best.pth.tar')
+        shutil.copyfile(filename, best_filename)
+
+def adjust_learning_rate(optimizer, epoch, args):
+    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs.
+    (Unused by main(), which drives the LR with a LambdaLR scheduler instead.)"""
+    lr = args.learning_rate * (0.1 ** (epoch // 30))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            # reshape (not view): the slice of the transposed tensor is non-contiguous
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+def setup_model_dataloader(args):
+
+    # Dataset
+    if args.dataset == 'imagenet':
+        print('* Dataset = ImageNet')
+        train_loader, val_loader = imagenet_dataloaders(args.batch_size, args.data, args.workers)
+        classes = 1000
+
+    elif args.dataset == 'cifar10':
+        print('* Dataset = CIFAR10')
+        train_loader, val_loader = cifar10_dataloaders(args.batch_size, args.data, args.workers)
+        classes = 10
+
+    elif args.dataset == 'cifar100':
+        print('* Dataset = CIFAR100')
+        train_loader, val_loader = cifar100_dataloaders(args.batch_size, args.data, args.workers)
+        classes = 100
+
+    else:
+        raise ValueError('unknown dataset')
+
+    # architecture
+    assert args.bn_type in ['bn', 'none', 'bf'], 'unknown bn_type'
+    if args.arch == 'reactnet-18':
+        print('* Model = ReActNet-18')
+        if args.binary_w:
+            print('* Binarize both activations and weights')
+            if args.bn_type == 'bn':
+                print('* with BN')
+                model = Qaw_reactnet_18_bn(num_classes=classes)
+            elif args.bn_type == 'none':
+                print('* without BN')
+                model = Qaw_reactnet_18_none(num_classes=classes)
+            elif args.bn_type == 'bf':
+                print('* BN-Free')
+                model = Qaw_reactnet_18_bf(num_classes=classes)
+
+        else:
+            print('* Binarize only activations')
+            if args.bn_type == 'bn':
+                print('* with BN')
+                model = Qa_reactnet_18_bn(num_classes=classes)
+            elif args.bn_type == 'none':
+                print('* without BN')
+                model = Qa_reactnet_18_none(num_classes=classes)
+            elif args.bn_type == 'bf':
+                print('* BN-Free')
+                model = Qa_reactnet_18_bf(num_classes=classes)
+
+    elif args.arch == 'reactnet-A':
+        print('* Model = ReActNet-A')
+        if args.binary_w:
+            print('* Binarize both activations and weights')
+            if args.bn_type == 'bn':
+                print('* with BN')
+                model = Qaw_reactnet_A_bn(num_classes=classes)
+            elif args.bn_type == 'none':
+                print('* without BN')
+                model = Qaw_reactnet_A_none(num_classes=classes)
+            elif args.bn_type == 'bf':
+                print('* BN-Free')
+                model = Qaw_reactnet_A_bf(num_classes=classes)
+
+        else:
+            print('* Binarize only activations')
+            if args.bn_type == 'bn':
+                print('* with BN')
+                model = Qa_reactnet_A_bn(num_classes=classes)
+            elif args.bn_type == 'none':
+                print('* without BN')
+                model = Qa_reactnet_A_none(num_classes=classes)
+            elif args.bn_type == 'bf':
+                print('* BN-Free')
+                model = Qa_reactnet_A_bf(num_classes=classes)
+
+    return train_loader, val_loader, model, classes
+
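A minimal smoke test of this factory (hypothetical usage, run from the repository root; `SimpleNamespace` stands in for the parsed CLI arguments, and the CIFAR-10 loader will download the dataset on first use):

```
from types import SimpleNamespace
from utils import setup_model_dataloader

args = SimpleNamespace(dataset='cifar10', data='datasets/cifar10', batch_size=8,
                       workers=0, arch='reactnet-A', binary_w=False, bn_type='bf')
train_loader, val_loader, model, classes = setup_model_dataloader(args)
images, _ = next(iter(train_loader))
print(model(images).shape)   # expected: torch.Size([8, 10])
```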