feat: Add new GAN to regular synth - DRAGAN (#48)
* feat(regular): Add new GAN - DRAGAN
* feat(regular): Add DRAGAN train step and example.
* feat(regular): DRAGAN example.
* fix: Remove sigmoid activation from the critics
Showing 7 changed files with 236 additions and 7 deletions.
@@ -0,0 +1,27 @@
from ydata_synthetic.preprocessing.regular.adult import transformations
from ydata_synthetic.synthesizers.regular import DRAGAN

# Load and process the data
data, processed_data, preprocessor = transformations()

# DRAGAN training
# Defining the training parameters of DRAGAN
noise_dim = 128
dim = 128
batch_size = 500

log_step = 100
epochs = 200 + 1
learning_rate = 1e-5
beta_1 = 0.5
beta_2 = 0.9
models_dir = './cache'

gan_args = [batch_size, learning_rate, beta_1, beta_2, noise_dim, processed_data.shape[1], dim]
train_args = ['', epochs, log_step]

synthesizer = DRAGAN(gan_args, n_discriminator=3)
synthesizer.train(processed_data, train_args)

synth_data = synthesizer.sample(1000)
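After training, `sample(1000)` draws 1000 synthetic rows, presumably in the processed feature space. A minimal, hedged sketch of inspecting the output; it assumes `synth_data` is array-like and that `processed_data` is a pandas DataFrame, neither of which is confirmed by this diff:

import pandas as pd

# Assumption: processed_data is a DataFrame; if it is a numpy array,
# supply an explicit list of column names instead.
synth_df = pd.DataFrame(synth_data, columns=processed_data.columns)
print(synth_df.describe())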
@@ -0,0 +1,34 @@
from tensorflow import random
from tensorflow import reshape, shape, math, GradientTape, reduce_mean
from tensorflow import norm as tfnorm

## Original loss code from
## https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Tensorflow-2/blob/master/tf2gan/loss.py
def gradient_penalty(f, real, fake, mode):
    def _gradient_penalty(f, real, fake=None):
        def _interpolate(a, b=None):
            if b is None:  # interpolation in DRAGAN: perturb the real batch
                beta = random.uniform(shape=shape(a), minval=0., maxval=1.)
                b = a + 0.5 * math.reduce_std(a) * beta
            shape_ = [shape(a)[0]] + [1] * (a.shape.ndims - 1)
            alpha = random.uniform(shape=shape_, minval=0., maxval=1.)
            inter = a + alpha * (b - a)
            inter.set_shape(a.shape)
            return inter

        x = _interpolate(real, fake)
        with GradientTape() as t:
            t.watch(x)
            pred = f(x)
        grad = t.gradient(pred, x)
        norm = tfnorm(reshape(grad, [shape(grad)[0], -1]), axis=1)
        gp = reduce_mean((norm - 1.)**2)

        return gp

    if mode == 'dragan':
        gp = _gradient_penalty(f, real)
    elif mode == 'wgangp':
        gp = _gradient_penalty(f, real, fake)
    else:
        # Guard against silently returning an undefined value
        raise ValueError("Unsupported gradient penalty mode: {}".format(mode))

    return gp
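A minimal sanity check for the DRAGAN branch of `gradient_penalty`, assuming nothing beyond the function above; the tiny critic and the shapes are illustrative stand-ins, not library code:

import tensorflow as tf

# Stand-in critic: one untrained Dense layer mapping features to a logit.
critic = tf.keras.Sequential([tf.keras.layers.Dense(1)])

real = tf.random.normal((8, 4))  # 8 samples, 4 features
gp = gradient_penalty(critic, real, None, mode='dragan')
print(float(gp))  # a non-negative scalar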
Empty file.
src/ydata_synthetic/synthesizers/regular/dragan/model.py
170 changes: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
import os
from os import path

import tqdm

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras import Model, initializers

from ydata_synthetic.synthesizers import gan
from ydata_synthetic.synthesizers.loss import gradient_penalty


class DRAGAN(gan.Model):

    def __init__(self, model_parameters, n_discriminator, gradient_penalty_weight=10):
        # As recommended in the DRAGAN paper - https://arxiv.org/abs/1705.07215
        self.n_discriminator = n_discriminator
        self.gradient_penalty_weight = gradient_penalty_weight
        super().__init__(model_parameters)

    def define_gan(self):
        # Define the generator and the discriminator (critic)
        self.generator = Generator(self.batch_size).build_model(
            input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)

        self.discriminator = Discriminator(self.batch_size).build_model(
            input_shape=(self.data_dim,), dim=self.layers_dim)

        self.g_optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)
        self.d_optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)

    def gradient_penalty(self, real, fake):
        gp = gradient_penalty(self.discriminator, real, fake, mode='dragan')
        return gp
    def update_gradients(self, x):
        """
        Compute the gradients and apply the updates for both the Generator and the Discriminator
        :param x: batch of real data
        :return: discriminator loss, generator loss
        """
        # Update the critic n_discriminator times per generator update (training the critic)
        for _ in range(self.n_discriminator):
            with tf.GradientTape() as d_tape:
                d_loss = self.d_lossfn(x)
            # Get the gradients of the critic
            d_gradient = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the critic using the optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Update the generator
        with tf.GradientTape() as g_tape:
            gen_loss = self.g_lossfn(x)

        # Get the gradients of the generator
        gen_gradients = g_tape.gradient(gen_loss, self.generator.trainable_variables)

        # Update the weights of the generator
        self.g_optimizer.apply_gradients(
            zip(gen_gradients, self.generator.trainable_variables)
        )

        return d_loss, gen_loss
    def d_lossfn(self, real):
        """
        Calculates the critic (discriminator) loss
        """
        noise = tf.random.normal((self.batch_size, self.noise_dim), dtype=tf.float32)
        # run noise through the generator
        fake = self.generator(noise)
        # discriminate the real and the generated samples
        logits_real = self.discriminator(real, training=True)
        logits_fake = self.discriminator(fake, training=True)

        # gradient penalty
        gp = self.gradient_penalty(real, fake)

        # getting the loss of the discriminator
        d_loss = (tf.reduce_mean(logits_fake)
                  - tf.reduce_mean(logits_real)
                  + gp * self.gradient_penalty_weight)
        return d_loss

    # generator loss
    def g_lossfn(self, real):
        """
        Calculates the Generator loss
        :param real: batch of real data (used only for its batch size)
        :return: loss of the generator
        """
        # generating noise from a normal distribution
        noise = tf.random.normal((real.shape[0], self.noise_dim), dtype=tf.float32)

        fake = self.generator(noise, training=True)
        logits_fake = self.discriminator(fake, training=True)
        g_loss = -tf.reduce_mean(logits_fake)
        return g_loss
    def get_data_batch(self, train, batch_size):
        buffer_size = len(train)
        # Shuffle before batching so that rows, not whole batches, are shuffled
        train_loader = tf.data.Dataset.from_tensor_slices(train) \
            .shuffle(buffer_size).batch(batch_size)
        return train_loader

    def train_step(self, train_data):
        d_loss, g_loss = self.update_gradients(train_data)
        return d_loss, g_loss
    def train(self, data, train_arguments):
        [cache_prefix, iterations, sample_interval] = train_arguments
        train_loader = self.get_data_batch(data, self.batch_size)

        # Create a summary file
        train_summary_writer = tf.summary.create_file_writer(path.join('..', 'dragan_test', 'summaries', 'train'))

        with train_summary_writer.as_default():
            for iteration in tqdm.trange(iterations):
                for batch_data in train_loader:
                    batch_data = tf.cast(batch_data, dtype=tf.float32)
                    d_loss, g_loss = self.train_step(batch_data)

                print(
                    "Iteration: {} | disc_loss: {} | gen_loss: {}".format(
                        iteration, d_loss, g_loss
                    ))

                if iteration % sample_interval == 0:
                    # save model checkpoints
                    if not path.exists('./cache'):
                        os.mkdir('./cache')
                    model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5'
                    self.generator.save_weights(model_checkpoint_base_name.format('generator', iteration))
                    self.discriminator.save_weights(model_checkpoint_base_name.format('discriminator', iteration))
class Discriminator(Model):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        input_ = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim * 4, kernel_initializer=initializers.TruncatedNormal(mean=0., stddev=0.5), activation='relu')(input_)
        x = Dropout(0.1)(x)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dropout(0.1)(x)
        x = Dense(dim, activation='relu')(x)
        # No sigmoid on the critic output: the Wasserstein-style loss above expects raw logits
        x = Dense(1)(x)
        return Model(inputs=input_, outputs=x)


class Generator(Model):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim):
        input_ = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim, kernel_initializer=initializers.TruncatedNormal(mean=0., stddev=0.5), activation='relu')(input_)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dense(dim * 4, activation='relu')(x)
        x = Dense(data_dim)(x)
        return Model(inputs=input_, outputs=x)
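As a quick smoke test of the two builders in isolation, a hedged sketch; the dims mirror the example script, and `data_dim=30` is an arbitrary placeholder rather than a value from this repo:

# Illustrative dims only: 128-d noise, 128 base units, 30 output features.
gen = Generator(batch_size=500).build_model(input_shape=(128,), dim=128, data_dim=30)
disc = Discriminator(batch_size=500).build_model(input_shape=(30,), dim=128)

print(gen.output_shape)   # (500, 30)
print(disc.output_shape)  # (500, 1)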