-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2e6d3a4
commit e4e0a23
Showing
23 changed files
with
2,858 additions
and
1 deletion.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import torch | ||
|
||
def unitwise_norm(x, norm_type=2.0):
    """Return the norm of ``x`` computed per output unit.

    Tensors with at most one dimension are reduced to a single norm value.
    For tensors of higher rank the norm is taken over every dimension except
    the first, and the reduced dimensions are kept (size 1) so the result
    broadcasts against ``x``.

    Args:
        x: tensor whose norm is computed.
        norm_type: order of the norm, forwarded to ``Tensor.norm``.
    """
    if x.ndim > 1:
        # Works for nn.ConvNd and nn.Linear, where the output dim comes first
        # in the kernel/weight tensor. Other weights (possibly MHA) might need
        # special cases where this is not true.
        reduce_dims = tuple(range(1, x.ndim))
        return x.norm(norm_type, dim=reduce_dims, keepdim=True)
    return x.norm(norm_type)
|
||
|
||
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
    """Clip gradients in place, unit by unit, relative to the parameter norm.

    For every parameter that has a gradient, the unit-wise gradient norm is
    compared against ``clip_factor`` times the unit-wise parameter norm (the
    parameter norm is floored at ``eps``). Units whose gradient norm is not
    below that limit are rescaled down to it; the others are left untouched.
    The result is written back into ``p.grad``.

    Args:
        parameters: a single tensor or an iterable of tensors.
        clip_factor: allowed ratio between gradient norm and parameter norm.
        eps: lower bound applied to the parameter norm.
        norm_type: order of the norm used for both parameters and gradients.
    """
    params = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    for param in params:
        if param.grad is None:
            continue
        weights = param.detach()
        grads = param.grad.detach()
        # Largest gradient norm tolerated for each unit.
        limit = unitwise_norm(weights, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
        grad_norm = unitwise_norm(grads, norm_type=norm_type)
        # The gradient norm is floored so the division never hits zero.
        rescaled = grads * (limit / grad_norm.clamp(min=1e-6))
        updated = torch.where(grad_norm < limit, grads, rescaled)
        param.grad.detach().copy_(updated)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import os | ||
import numpy as np | ||
from PIL import Image | ||
from torchvision import transforms | ||
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder | ||
from torch.utils.data import DataLoader | ||
|
||
__all__ = ['cifar10_dataloaders', 'cifar100_dataloaders', 'imagenet_dataloaders'] | ||
|
||
# Eigenvalues / eigenvectors used as the default statistics of the
# Lighting data augmentation.
imagenet_pca = dict(
    eigval=np.asarray([0.2175, 0.0188, 0.0045]),
    eigvec=np.asarray([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ]),
)
|
||
class Lighting(object):
    """Random lighting noise along fixed colour eigenvectors.

    Each call draws three Gaussian coefficients with standard deviation
    ``alphastd``, scales them by ``eigval``, projects them through ``eigvec``
    and adds the resulting 3-vector to the image.

    Args:
        alphastd: standard deviation of the random coefficients; ``0``
            turns the transform into a no-op.
        eigval: array of shape ``(3,)``.
        eigvec: array of shape ``(3, 3)``.
    """

    def __init__(self, alphastd,
                 eigval=imagenet_pca['eigval'],
                 eigvec=imagenet_pca['eigvec']):
        assert eigval.shape == (3,)
        assert eigvec.shape == (3, 3)
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        """Return ``img`` with lighting noise added, as an RGB PIL image.

        The input is returned unchanged when ``alphastd`` is zero.
        """
        if self.alphastd == 0.:
            return img
        alpha = (np.random.randn(3) * self.alphastd).astype('float32')
        # Remember the input dtype so the result can be cast back to it.
        source_dtype = np.asarray(img).dtype
        weighted = (alpha * self.eigval).reshape((3, 1))
        shift = np.dot(self.eigvec, weighted).reshape((3,))
        shifted = np.add(img, shift)
        if source_dtype == np.uint8:
            # Keep values inside the valid 8-bit range before casting back.
            shifted = np.clip(shifted, 0, 255)
        return Image.fromarray(shifted.astype(source_dtype), 'RGB')

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
|
||
|
||
def cifar10_dataloaders(batch_size=128, data_dir = 'datasets/cifar10', worker=4):
    """Build the CIFAR-10 training and test data loaders.

    Training images go through a padded random crop and a random horizontal
    flip before tensor conversion and normalization; test images are only
    converted and normalized. Both datasets are requested with
    ``download=True``.

    Args:
        batch_size: batch size used by both loaders.
        data_dir: root directory handed to the dataset.
        worker: number of loader worker processes.

    Returns:
        ``(train_loader, test_loader)``.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2470, 0.2435, 0.2616])

    augmented = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    plain = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = CIFAR10(data_dir, train=True, transform=augmented, download=True)
    test_set = CIFAR10(data_dir, train=False, transform=plain, download=True)

    # Settings shared by both loaders; only the shuffling differs.
    common = dict(batch_size=batch_size, num_workers=worker, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **common)
    test_loader = DataLoader(test_set, shuffle=False, **common)

    return train_loader, test_loader
|
||
def cifar100_dataloaders(batch_size=128, data_dir = 'datasets/cifar100', worker=4):
    """Build the CIFAR-100 training and test data loaders.

    Training images go through a padded random crop and a random horizontal
    flip before tensor conversion and normalization; test images are only
    converted and normalized. Both datasets are requested with
    ``download=True``.

    Args:
        batch_size: batch size used by both loaders.
        data_dir: root directory handed to the dataset.
        worker: number of loader worker processes.

    Returns:
        ``(train_loader, test_loader)``.
    """
    normalize = transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                     std=[0.2673, 0.2564, 0.2762])

    augmented = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    plain = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = CIFAR100(data_dir, train=True, transform=augmented, download=True)
    test_set = CIFAR100(data_dir, train=False, transform=plain, download=True)

    # Settings shared by both loaders; only the shuffling differs.
    common = dict(batch_size=batch_size, num_workers=worker, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **common)
    test_loader = DataLoader(test_set, shuffle=False, **common)

    return train_loader, test_loader
|
||
def imagenet_dataloaders(batch_size=128, data_dir = 'datasets/cifar100', worker=4):
    """Build training and validation loaders from an image-folder dataset.

    Images are read from ``<data_dir>/train`` and ``<data_dir>/val``.
    Training uses a random resized crop to 224, lighting noise and a random
    horizontal flip; validation resizes to 256 and center-crops to 224.

    NOTE(review): the default ``data_dir`` points at the CIFAR-100 folder,
    which looks like a copy-paste leftover -- confirm the intended root.

    Args:
        batch_size: batch size used by both loaders.
        data_dir: directory containing the ``train`` and ``val`` folders.
        worker: number of loader worker processes.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_root = os.path.join(data_dir, 'train')
    val_root = os.path.join(data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Data augmentation settings for training.
    min_crop_fraction = 0.08
    lighting_std = 0.1
    train_pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(min_crop_fraction, 1.0)),
        Lighting(lighting_std),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolder(train_root, transform=train_pipeline)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=True,
        num_workers=worker, pin_memory=True)

    # Validation data: deterministic resize + center crop.
    val_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = ImageFolder(val_root, val_pipeline)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size, shuffle=False,
        num_workers=worker, pin_memory=True)

    return train_loader, val_loader
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
__all__ = ['LambdaLayer', 'ScaledStdConv2d', 'HardBinaryScaledStdConv2d', 'LearnableBias','BinaryActivation', 'HardBinaryConv'] | ||
|
||
def get_weight(module):
    """Return ``module.weight`` standardized per output unit.

    Mean and (biased) standard deviation are computed over dimensions 1-3 of
    the weight; ``module.eps`` is added to the standard deviation before the
    division.
    """
    kernel = module.weight
    sigma, mu = torch.std_mean(kernel, dim=(1, 2, 3), keepdim=True, unbiased=False)
    return (kernel - mu) / (sigma + module.eps)
|
||
# Calculate symmetric padding for a convolution | ||
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    """Return the symmetric padding for the given convolution geometry.

    Extra keyword arguments are accepted and ignored.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (stride - 1 + dilated_extent) // 2
|
||
class LambdaLayer(nn.Module):
    """Module wrapper that applies a stored callable in ``forward``."""

    def __init__(self, lambd):
        super().__init__()
        # Callable invoked on the input of every forward pass.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)
|
||
|
||
class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    The kernel is standardized per output channel, multiplied by
    ``gamma / sqrt(fan-in)`` and by a learnable per-channel ``gain`` before
    the convolution is applied.

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
            bias=False, gamma=1.0, eps=1e-5, use_layernorm=False):
        # Fall back to symmetric padding when none is given.
        resolved_padding = get_padding(kernel_size, stride, dilation) if padding is None else padding
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=resolved_padding,
            dilation=dilation, groups=groups, bias=bias)
        fan_in = self.weight[0].numel()
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.scale = gamma * fan_in ** -0.5  # gamma * 1 / sqrt(fan-in)
        if use_layernorm:
            # The layer-norm path gets the squared epsilon.
            self.eps = eps ** 2
        else:
            self.eps = eps
        # Experimental, slightly faster / less GPU memory to hijack the LN kernel.
        self.use_layernorm = use_layernorm

    def get_weight(self):
        """Return the standardized, scaled kernel multiplied by ``gain``."""
        kernel = self.weight
        if self.use_layernorm:
            standardized = self.scale * F.layer_norm(kernel, kernel.shape[1:], eps=self.eps)
        else:
            sigma, mu = torch.std_mean(kernel, dim=[1, 2, 3], keepdim=True, unbiased=False)
            standardized = self.scale * (kernel - mu) / (sigma + self.eps)
        return self.gain * standardized

    def forward(self, x):
        kernel = self.get_weight()
        return F.conv2d(x, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||
class HardBinaryScaledStdConv2d(nn.Module):
    """Binary-weight convolution on top of scaled weight standardization.

    The latent kernel is standardized per output channel and scaled by
    ``gamma / sqrt(fan-in)``. Its sign, multiplied by the per-channel mean
    absolute value, is used in the forward pass, while gradients flow through
    the clamped standardized kernel. A learnable per-channel ``gain``
    multiplies the result.
    """

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1, gamma=1.0, eps=1e-5, use_layernorm=False):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True)

        fan_in = self.weight[0].numel()
        self.gain = nn.Parameter(torch.ones(out_chn, 1, 1, 1))
        self.scale = gamma * fan_in ** -0.5
        if use_layernorm:
            # The layer-norm path gets the squared epsilon.
            self.eps = eps ** 2
        else:
            self.eps = eps
        self.use_layernorm = use_layernorm

    def get_weight(self):
        """Return the binarized kernel (with straight-through gradient) times ``gain``."""
        latent = self.weight
        if self.use_layernorm:
            standardized = self.scale * F.layer_norm(latent, latent.shape[1:], eps=self.eps)
        else:
            sigma, mu = torch.std_mean(latent, dim=[1, 2, 3], keepdim=True, unbiased=False)
            standardized = self.scale * (latent - mu) / (sigma + self.eps)

        # Per-output-channel mean magnitude, reduced one dimension at a time.
        magnitude = abs(standardized)
        for axis in (3, 2, 1):
            magnitude = torch.mean(magnitude, dim=axis, keepdim=True)
        magnitude = magnitude.detach()

        hard = magnitude * torch.sign(standardized)
        soft = torch.clamp(standardized, -1.0, 1.0)
        # Forward value is `hard`; the gradient is the one of `soft`.
        binarized = hard.detach() - soft.detach() + soft

        return self.gain * binarized

    def forward(self, x):
        kernel = self.get_weight()
        return F.conv2d(x, kernel, stride=self.stride, padding=self.padding)
|
||
class LearnableBias(nn.Module):
    """Adds a learnable per-channel offset to a 4-D input."""

    def __init__(self, out_chn):
        super().__init__()
        # One zero-initialised value per channel, shaped (1, out_chn, 1, 1).
        initial = torch.zeros(1, out_chn, 1, 1)
        self.bias = nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        offset = self.bias.expand_as(x)
        return x + offset
|
||
class BinaryActivation(nn.Module):
    """Sign activation with a piecewise-polynomial surrogate gradient.

    The forward value is ``sign(x)``. The backward pass uses the gradient of
    the following surrogate instead of the (zero almost everywhere) gradient
    of the sign function:

        -1            for x < -1
        x^2 + 2x      for -1 <= x < 0
        -x^2 + 2x     for 0 <= x < 1
        1             for x >= 1

    The surrogate is selected with ``torch.where`` in the dtype of ``x``, so
    half/bfloat16 inputs are not upcast to float32, and a value that
    overflows in a non-selected branch cannot leak into the output as NaN
    (which ``inf * 0`` would produce with mask multiplication).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``sign(x)`` carrying the surrogate gradient described above."""
        hard = torch.sign(x)
        ones = torch.ones_like(x)
        # Build the surrogate from the leftmost piece to the rightmost one.
        soft = torch.where(x < -1, -ones, x * x + 2 * x)
        soft = torch.where(x < 0, soft, -x * x + 2 * x)
        soft = torch.where(x < 1, soft, ones)
        # Straight-through trick: value of `hard`, gradient of `soft`.
        return hard.detach() - soft.detach() + soft
|
||
class HardBinaryConv(nn.Module):
    """2-D convolution whose kernel is binarized on the fly.

    The forward pass convolves with ``sign(weight)`` multiplied by the
    per-output-channel mean absolute weight, while gradients flow through
    the weight clamped to ``[-1, 1]``.
    """

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True)

    def forward(self, x):
        latent = self.weight
        # Per-output-channel mean magnitude, reduced one dimension at a time.
        magnitude = abs(latent)
        for axis in (3, 2, 1):
            magnitude = torch.mean(magnitude, dim=axis, keepdim=True)
        magnitude = magnitude.detach()

        hard = magnitude * torch.sign(latent)
        soft = torch.clamp(latent, -1.0, 1.0)
        # Forward value is `hard`; the gradient is the one of `soft`.
        kernel = hard.detach() - soft.detach() + soft
        return F.conv2d(x, kernel, stride=self.stride, padding=self.padding)
|
Oops, something went wrong.