Skip to content

Commit

Permalink
add VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
yusugomori committed Apr 13, 2019
1 parent 7a050d4 commit 25eeab2
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ $ pip install torch torchvision
* Encoder-Decoder (Attention)
* Transformer
* Deep Q-Network
* Variational Autoencoder
* Generative Adversarial Network
* Conditional GAN

Expand All @@ -38,6 +39,7 @@ models/
├── resnet34_fashion_mnist.py
├── resnet50_fashion_mnist.py
├── transformer.py
├── vae_fashion_mnist.py
└── layers/
   ├── Attention.py
Expand Down
178 changes: 178 additions & 0 deletions models/vae_fashion_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt


class VAE(nn.Module):
'''
Simple Variational Autoencoder
'''
def __init__(self, device='cpu'):
super().__init__()
self.device = device
self.encoder = Encoder(device=device)
self.decoder = Decoder(device=device)

def forward(self, x):
mean, var = self.encoder(x)
z = self.reparameterize(mean, var)
y = self.decoder(z)

return y

def reparameterize(self, mean, var):
eps = torch.randn(mean.size()).to(self.device)
z = mean + torch.sqrt(var) * eps
return z

def lower_bound(self, x):
mean, var = self.encoder(x)
kl = - 1/2 * torch.mean(torch.sum(1
+ torch.log(var)
- mean**2
- var, dim=1))
z = self.reparameterize(mean, var)
y = self.decoder(z)

reconst = torch.mean(torch.sum(x * torch.log(y)
+ (1 - x) * torch.log(1 - y), dim=1))

return reconst - kl


class Encoder(nn.Module):
def __init__(self, device='cpu'):
super().__init__()
self.device = device
self.l1 = nn.Linear(784, 200)
self.l2 = nn.Linear(200, 200)
self.l_mean = nn.Linear(200, 10)
self.l_var = nn.Linear(200, 10)

def forward(self, x):
h = self.l1(x)
h = torch.relu(h)
h = self.l2(h)
h = torch.relu(h)

mean = self.l_mean(h)
var = F.softplus(self.l_var(h))

return mean, var


class Decoder(nn.Module):
def __init__(self, device='cpu'):
super().__init__()
self.device = device
self.l1 = nn.Linear(10, 200)
self.l2 = nn.Linear(200, 200)
self.out = nn.Linear(200, 784)

def forward(self, x):
h = self.l1(x)
h = torch.relu(h)
h = self.l2(h)
h = torch.relu(h)
h = self.out(h)
y = torch.sigmoid(h)

return y


if __name__ == '__main__':
np.random.seed(1234)
torch.manual_seed(1234)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def compute_loss(x):
return -1 * criterion(x)

def train_step(x):
model.train()
loss = compute_loss(x)

optimizer.zero_grad()
loss.backward()
optimizer.step()

return loss

def generate(batch_size=10):
model.eval()
z = gen_noise(batch_size)
gen = model.decoder(z)
gen = gen.view(-1, 28, 28)

return gen

def gen_noise(batch_size):
return torch.empty(batch_size, 10).normal_().to(device)

'''
Load data
'''
root = os.path.join(os.path.dirname(__file__),
'..', 'data', 'fashion_mnist')
transform = transforms.Compose([transforms.ToTensor(),
lambda x: x.view(-1)])
mnist_train = \
torchvision.datasets.FashionMNIST(root=root,
download=True,
train=True,
transform=transform)
train_dataloader = DataLoader(mnist_train,
batch_size=100,
shuffle=True)

'''
Build model
'''
model = VAE(device=device).to(device)
criterion = model.lower_bound
optimizer = optimizers.Adam(model.parameters())

'''
Train model
'''
epochs = 10
out_path = os.path.join(os.path.dirname(__file__),
'..', 'output')

for epoch in range(epochs):
train_loss = 0.

for (x, _) in train_dataloader:
x = x.to(device)
loss = train_step(x)

train_loss += loss.item()

train_loss /= len(train_dataloader)

print('Epoch: {}, Cost: {:.3f}'.format(
epoch+1,
train_loss
))

if epoch % 5 == 4 or epoch == epochs - 1:
images = generate(batch_size=16)
images = images.squeeze().detach().cpu().numpy()
plt.figure(figsize=(6, 6))
for i, image in enumerate(images):
plt.subplot(4, 4, i+1)
plt.imshow(image, cmap='binary')
plt.axis('off')
plt.tight_layout()
# plt.show()
template = '{}/vae_fashion_mnist_epoch_{:0>4}.png'
plt.savefig(template.format(out_path, epoch+1), dpi=300)

0 comments on commit 25eeab2

Please sign in to comment.