From f4bdede841dffe60dc6ab292d6ffbe85c31447d5 Mon Sep 17 00:00:00 2001 From: Francisco Santos Date: Mon, 8 Nov 2021 16:50:54 +0000 Subject: [PATCH] PATEGAN base implementation --- requirements.txt | 1 + .../synthesizers/regular/__init__.py | 4 +- .../synthesizers/regular/pategan/__init__.py | 0 .../synthesizers/regular/pategan/model.py | 256 ++++++++++++++++++ .../test_activation_interface.py | 72 +++++ .../custom_layers/test_gumbel_softmax.py | 54 ++++ 6 files changed, 386 insertions(+), 1 deletion(-) create mode 100644 src/ydata_synthetic/synthesizers/regular/pategan/__init__.py create mode 100644 src/ydata_synthetic/synthesizers/regular/pategan/model.py create mode 100644 src/ydata_synthetic/tests/custom_layers/test_activation_interface.py create mode 100644 src/ydata_synthetic/tests/custom_layers/test_gumbel_softmax.py diff --git a/requirements.txt b/requirements.txt index 3f806133..b140d5ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ pmlb==1.0.* tqdm<5.0 typeguard==2.13.* pytest==6.2.* +tensorflow_probability==0.12.* diff --git a/src/ydata_synthetic/synthesizers/regular/__init__.py b/src/ydata_synthetic/synthesizers/regular/__init__.py index 9f0464da..435274cb 100644 --- a/src/ydata_synthetic/synthesizers/regular/__init__.py +++ b/src/ydata_synthetic/synthesizers/regular/__init__.py @@ -4,6 +4,7 @@ from ydata_synthetic.synthesizers.regular.wgangp.model import WGAN_GP from ydata_synthetic.synthesizers.regular.dragan.model import DRAGAN from ydata_synthetic.synthesizers.regular.cramergan.model import CRAMERGAN +from ydata_synthetic.synthesizers.regular.pategan.model import PATEGAN __all__ = [ "VanilllaGAN", @@ -11,5 +12,6 @@ "WGAN", "WGAN_GP", "DRAGAN", - "CRAMERGAN" + "CRAMERGAN", + "PATEGAN" ] diff --git a/src/ydata_synthetic/synthesizers/regular/pategan/__init__.py b/src/ydata_synthetic/synthesizers/regular/pategan/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ydata_synthetic/synthesizers/regular/pategan/model.py b/src/ydata_synthetic/synthesizers/regular/pategan/model.py new file mode 100644 index 00000000..650e67e8 --- /dev/null +++ b/src/ydata_synthetic/synthesizers/regular/pategan/model.py @@ -0,0 +1,256 @@ +"PATEGAN implementation supporting Differential Privacy budget specification." +# pylint: disable = W0622, E0401 +from math import log +from typing import List, NamedTuple, Optional + +import tqdm +from tensorflow import (GradientTape, clip_by_value, concat, constant, + expand_dims, ones_like, tensor_scatter_nd_update, + transpose, zeros, zeros_like) +from tensorflow.data import Dataset +from tensorflow.dtypes import cast, float64, int64 +from tensorflow.keras import Model +from tensorflow.keras.layers import Dense, Input, ReLU +from tensorflow.keras.losses import BinaryCrossentropy +from tensorflow.keras.optimizers import Adam +from tensorflow.math import abs, exp, pow, reduce_sum, square +from tensorflow.random import uniform +from tensorflow_probability import distributions + +from ydata_synthetic.synthesizers import TrainParameters +from ydata_synthetic.synthesizers.gan import BaseModel +from ydata_synthetic.utils.gumbel_softmax import ActivationInterface + + +# pylint: disable=R0902 +class PATEGAN(BaseModel): + "A basic PATEGAN synthesizer implementation with configurable differential privacy budget." 
+
+    __MODEL__ = 'PATEGAN'
+
+    def __init__(self, model_parameters, n_teachers: int, target_delta: float, target_epsilon: float):
+        super().__init__(model_parameters)
+        self.n_teachers = n_teachers
+        self.target_epsilon = target_epsilon
+        self.target_delta = target_delta
+
+    # pylint: disable=W0201
+    def define_gan(self, processor_info: Optional[NamedTuple] = None):
+        def discriminator():
+            # The discriminators receive the records concatenated with the conditioning class label
+            # column, hence the + 1 on the input width.
+            return Discriminator(self.batch_size).build_model((self.data_dim + 1,), self.layers_dim)
+
+        # The generator receives the noise vector concatenated with the sampled class label column.
+        self.generator = Generator(self.batch_size). \
+            build_model(input_shape=(self.noise_dim + 1,), dim=self.layers_dim, data_dim=self.data_dim,
+                        processor_info=processor_info)
+        self.s_discriminator = discriminator()
+        self.t_discriminators = [discriminator() for _ in range(self.n_teachers)]
+
+        generator_optimizer = Adam(learning_rate=self.g_lr)
+        discriminator_optimizer = Adam(learning_rate=self.d_lr)
+
+        loss_fn = BinaryCrossentropy(from_logits=True)
+        self.generator.compile(loss=loss_fn, optimizer=generator_optimizer)
+        self.s_discriminator.compile(loss=loss_fn, optimizer=discriminator_optimizer)
+        for teacher in self.t_discriminators:
+            teacher.compile(loss=loss_fn, optimizer=discriminator_optimizer)
+
+    # pylint: disable = C0103
+    @staticmethod
+    def _moments_acc(n_teachers, votes, lap_scale, l_list):
+        "Computes the moments accountant update for a batch of teacher votes."
+        q = (2 + lap_scale * abs(2 * votes - n_teachers)) / (4 * exp(lap_scale * abs(2 * votes - n_teachers)))
+
+        update = []
+        for l in l_list:
+            clip = 2 * square(lap_scale) * l * (l + 1)
+            t = (1 - q) * pow((1 - q) / (1 - exp(2 * lap_scale) * q), l) + q * exp(2 * lap_scale * l)
+            update.append(reduce_sum(clip_by_value(t, clip_value_min=-clip, clip_value_max=clip)))
+        return cast(update, dtype=float64)
+
+    def get_data_loader(self, data) -> List[Dataset]:
+        "Obtain a List of TF Datasets, one per teacher, each holding that teacher's data partition."
+        loader = []
+        SHUFFLE_BUFFER_SIZE = 100
+
+        for teacher_id in range(self.n_teachers):
+            start_id = int(teacher_id * len(data) / self.n_teachers)
+            end_id = int((teacher_id + 1) * len(data) / self.n_teachers if \
+                teacher_id != (self.n_teachers - 1) else len(data))
+            # Shuffle the records before batching so each iteration yields a fresh random batch.
+            loader.append(Dataset.from_tensor_slices(data[start_id: end_id])
+                          .shuffle(SHUFFLE_BUFFER_SIZE).batch(self.batch_size))
+        return loader
+
+    # pylint:disable=R0913
+    def train(self, data, class_ratios, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
+        """
+        Args:
+            data: A pandas DataFrame or a Numpy array with the data to be synthesized
+            class_ratios: Class label probabilities used to sample the conditioning label column
+            train_arguments: GAN training arguments.
+            num_cols: List of columns of the data object to be handled as numerical
+            cat_cols: List of columns of the data object to be handled as categorical
+        """
+        super().train(data, num_cols, cat_cols)
+
+        data = self.processor.transform(data)
+        self.data_dim = data.shape[1]
+        self.define_gan(self.processor.col_transform_info)
+
+        self.class_ratios = class_ratios
+
+        alpha = cast([0.0 for _ in range(train_arguments.num_moments)], float64)
+        l_list = 1 + cast(range(train_arguments.num_moments), float64)
+
+        cross_entropy = BinaryCrossentropy(from_logits=True)
+
+        generator_optimizer = Adam(learning_rate=train_arguments.lr)
+        disc_opt_stu = Adam(learning_rate=train_arguments.lr)
+        disc_opt_t = [Adam(learning_rate=train_arguments.lr) for _ in range(self.n_teachers)]
+
+        train_loader = self.get_data_loader(data)
+
+        steps = 0
+        epsilon = 0
+
+        category_samples = distributions.Categorical(probs=self.class_ratios, dtype=float64)
+
+        # Train until the target privacy budget (epsilon) is exhausted
+        while epsilon < self.target_epsilon:
+            # train the teacher discriminators
+            for _ in range(train_arguments.num_teacher_iters):
+                for i in range(self.n_teachers):
+                    # Take a single random batch from the i-th teacher's data partition
+                    inputs = None
+                    for data_ in train_loader[i]:
+                        inputs = data_
+                        break
+
+                    with GradientTape() as disc_tape:
+                        # train with real - real records are conditioned with a zeros label column
+                        dis_data = concat([inputs, zeros([inputs.shape[0], 1], dtype=float64)], 1)
+                        real_output = self.t_discriminators[i](dis_data, training=True)
+
+                        # train with fake
+                        z = uniform([inputs.shape[0], self.noise_dim], dtype=float64)
+                        sample = expand_dims(category_samples.sample(inputs.shape[0]), axis=1)
+                        fake = self.generator(concat([z, sample], 1))
+                        fake_output = self.t_discriminators[i](concat([fake, sample], 1), training=True)
+
+                        real_loss_disc = cross_entropy(ones_like(real_output), real_output)
+                        fake_loss_disc = cross_entropy(zeros_like(fake_output), fake_output)
+                        disc_loss = real_loss_disc + fake_loss_disc
+
+                    gradients_of_discriminator = disc_tape.gradient(disc_loss, self.t_discriminators[i].trainable_variables)
+                    disc_opt_t[i].apply_gradients(zip(gradients_of_discriminator, self.t_discriminators[i].trainable_variables))
+
+            # train the student discriminator on teacher-labeled synthetic samples
+            for _ in range(train_arguments.num_student_iters):
+                z = uniform([inputs.shape[0], self.noise_dim], dtype=float64)
+                sample = expand_dims(category_samples.sample(inputs.shape[0]), axis=1)
+
+                with GradientTape() as stu_tape:
+                    fake = self.generator(concat([z, sample], 1))
+
+                    # Noisy PATE votes are the student's training labels; clean votes feed the accountant
+                    predictions, clean_votes = self._pate_voting(
+                        concat([fake, sample], 1), self.t_discriminators, train_arguments.lap_scale)
+                    outputs = self.s_discriminator(concat([fake, sample], 1))
+
+                    # update the moments accountant with this batch of clean votes
+                    alpha = alpha + self._moments_acc(self.n_teachers, clean_votes, train_arguments.lap_scale, l_list)
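+                    # alpha accumulates the log moments of the privacy loss; the spent budget is later
+                    # recovered as epsilon = min_l (alpha_l - log(target_delta)) / l, and training stops
+                    # once it reaches target_epsilon.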
print("final_alpha", alpha) + + stu_loss = cross_entropy(predictions, outputs) + gradients_of_stu = stu_tape.gradient(stu_loss, self.s_discriminator.trainable_variables) + # print(gradients_of_stu) + + disc_opt_stu.apply_gradients(zip(gradients_of_stu, self.s_discriminator.trainable_variables)) + + # train the generator + z = uniform([inputs.shape[0], self.z_dim], dtype=float64) + + sample_g = expand_dims(category_samples.sample(inputs.shape[0]), axis=1) + + with GradientTape() as gen_tape: + fake = self.generator(concat([z, sample_g], 1)) + output = self.s_discriminator(concat([fake, sample_g], 1)) + + loss_gen = cross_entropy(ones_like(output), output) + gradients_of_generator = gen_tape.gradient(loss_gen, self.generator.trainable_variables) + generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables)) + + # Calculate the current privacy cost + epsilon = min((alpha - log(self.delta)) / l_list) + if steps % 1 == 0: + print("Step : ", steps, "Loss SD : ", stu_loss, "Loss G : ", loss_gen, "Epsilon : ", epsilon) + + steps += 1 + # self.generator.summary() + + def _pate_voting(self, data, netTD, lap_scale): + # TODO: Validate the logic against original article + ## Faz os votos dos teachers (1/0) netTD para cada record em data e guarda em results + results = zeros([len(netTD), data.shape[0]], dtype=int64) + # print(results) + for i in range(len(netTD)): + output = netTD[i](data, training=True) + pred = transpose(cast((output > 0.5), int64)) + # print(pred) + results = tensor_scatter_nd_update(results, constant([[i]]), pred) + # print(results) + + #guarda o somatorio das probabilidades atribuidas por cada disc a cada record (valores entre 0 e len(netTD)) + clean_votes = expand_dims(cast(reduce_sum(results, 0), dtype=float64), 1) + # print("clean_votes",clean_votes) + noise_sample = distributions.Laplace(loc=0, scale=1/lap_scale).sample(clean_votes.shape) + # print("noise_sample", noise_sample) + noisy_results = clean_votes + cast(noise_sample, float64) + noisy_labels = cast((noisy_results > len(netTD)/2), float64) + + return noisy_labels, clean_votes + + +class Discriminator(Model): + def __init__(self, batch_size): + self.batch_size = batch_size + + def build_model(self, input_shape, dim): + input = Input(shape=input_shape, batch_size=self.batch_size) + x = Dense(dim * 4)(input) + x = ReLU()(x) + x = Dense(dim * 2)(x) + x = Dense(1)(x) + return Model(inputs=input, outputs=x) + + +class Generator(Model): + def __init__(self, batch_size): + self.batch_size = batch_size + + def build_model(self, input_shape, dim, data_dim, processor_info: Optional[NamedTuple] = None): + input = Input(shape=input_shape, batch_size = self.batch_size) + x = Dense(dim)(input) + x = ReLU()(x) + x = Dense(dim * 2)(x) + x = Dense(data_dim)(x) + if processor_info: + x = ActivationInterface(processor_info, 'ActivationInterface')(x) + return Model(inputs=input, outputs=x) diff --git a/src/ydata_synthetic/tests/custom_layers/test_activation_interface.py b/src/ydata_synthetic/tests/custom_layers/test_activation_interface.py new file mode 100644 index 00000000..b6bbec63 --- /dev/null +++ b/src/ydata_synthetic/tests/custom_layers/test_activation_interface.py @@ -0,0 +1,72 @@ +"Activation Interface layer test suite." 
+from itertools import cycle, islice +from re import search + +from numpy import array, cumsum, isin, split +from numpy import sum as npsum +from numpy.random import normal +from pandas import DataFrame, concat +from pytest import fixture +from tensorflow.keras import Model +from tensorflow.keras.layers import Dense, Input + +from ydata_synthetic.preprocessing.regular.processor import \ + RegularDataProcessor +from ydata_synthetic.utils.gumbel_softmax import ActivationInterface + +BATCH_SIZE = 10 + +@fixture(name='noise_batch') +def fixture_noise_batch(): + "Sample noise for mock output generation." + return normal(size=(BATCH_SIZE, 16)) + +@fixture(name='mock_data') +def fixture_mock_data(): + "Creates mock data for the tests." + num_block = DataFrame(normal(size=(BATCH_SIZE, 6)), columns = [f'num_{i}' for i in range(6)]) + cat_block_1 = DataFrame(array(list(islice(cycle(range(2)), BATCH_SIZE))), columns = ['cat_0']) + cat_block_2 = DataFrame(array(list(islice(cycle(range(4)), BATCH_SIZE))), columns = ['cat_1']) + return concat([num_block, cat_block_1, cat_block_2], axis = 1) + +@fixture(name='mock_processor') +def fixture_mock_processor(mock_data): + "Creates a mock data processor for the mock data." + num_cols = [col for col in mock_data.columns if col.startswith('num')] + cat_cols = [col for col in mock_data.columns if col.startswith('cat')] + return RegularDataProcessor(num_cols, cat_cols).fit(mock_data) + +# pylint: disable=C0103 +@fixture(name='mock_generator') +def fixture_mock_generator(noise_batch, mock_processor): + "A mock generator with the Activation Interface as final layer." + input_ = Input(shape=noise_batch.shape[1], batch_size = BATCH_SIZE) + dim = 15 + data_dim = 12 + x = Dense(dim, activation='relu')(input_) + x = Dense(dim * 2, activation='relu')(x) + x = Dense(dim * 4, activation='relu')(x) + x = Dense(data_dim)(x) + x = ActivationInterface(processor_info=mock_processor.col_transform_info, name='act_itf')(x) + return Model(inputs=input_, outputs=x) + +@fixture(name='mock_output') +def fixture_mock_output(noise_batch, mock_generator): + "Returns mock output of the model as a numpy object." + return mock_generator(noise_batch).numpy() + +# pylint: disable=W0632 +def test_io(mock_processor, mock_output): + "Tests the output format of the activation interface for a known input." + num_lens = len(mock_processor.col_transform_info.numerical.feat_names_out) + cat_lens = len(mock_processor.col_transform_info.categorical.feat_names_out) + assert mock_output.shape == (BATCH_SIZE, num_lens + cat_lens), "The output has wrong shape." + num_part, cat_part = split(mock_output, [num_lens], 1) + assert not isin(num_part, [0, 1]).all(), "The numerical block is not expected to contain 0 or 1." + assert isin(cat_part, [0, 1]).all(), "The categorical block is expected to contain only 0 or 1." + cat_i, cat_o = mock_processor.col_transform_info.categorical + cat_blocks = cumsum([len([col for col in cat_o if col.startswith(feat) and search('_[0-9]*$', col)]) \ + for feat in cat_i]) + cat_blocks = split(cat_part, cat_blocks[:-1], 1) + assert all(npsum(abs(block)) == BATCH_SIZE for block in cat_blocks), "There are non one-hot encoded \ + categorical blocks." 
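For reviewers, the following standalone NumPy sketch mirrors the noisy-vote aggregation that PATEGAN._pate_voting implements: each teacher casts a 0/1 vote per record, Laplace noise with scale 1/lap_scale is added to the vote counts, and the noisy majority becomes the student's label. The helper name noisy_majority_vote and the fixed vote matrix are illustrative only and are not part of the patch.

# Standalone illustration of the PATE aggregation step (not part of the patch).
import numpy as np

def noisy_majority_vote(votes, lap_scale, seed=0):
    "votes: array of shape (n_teachers, n_records) with 0/1 entries."
    rng = np.random.default_rng(seed)
    clean_counts = votes.sum(axis=0)  # votes per record, in [0, n_teachers]
    noise = rng.laplace(loc=0.0, scale=1.0 / lap_scale, size=clean_counts.shape)
    return (clean_counts + noise > votes.shape[0] / 2).astype(float)  # noisy majority label

votes = np.array([[1, 0, 1],   # teacher 1
                  [1, 0, 0],   # teacher 2
                  [1, 1, 0]])  # teacher 3
print(noisy_majority_vote(votes, lap_scale=5.0))  # typically [1. 0. 0.] unless the noise flips a count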
diff --git a/src/ydata_synthetic/tests/custom_layers/test_gumbel_softmax.py b/src/ydata_synthetic/tests/custom_layers/test_gumbel_softmax.py
new file mode 100644
index 00000000..dd52c71d
--- /dev/null
+++ b/src/ydata_synthetic/tests/custom_layers/test_gumbel_softmax.py
@@ -0,0 +1,54 @@
+"Test suite for the Gumbel-Softmax layer implementation."
+import tensorflow as tf
+from numpy import amax, amin, isclose, ones
+from numpy import sum as npsum
+from pytest import fixture
+from tensorflow.keras import layers
+
+from ydata_synthetic.utils.gumbel_softmax import GumbelSoftmaxLayer
+
+
+# pylint:disable=W0613
+def custom_initializer(shape_list, dtype):
+    "A constant weight initializer to ensure test reproducibility."
+    return tf.constant(ones((5, 5)), dtype=tf.dtypes.float32)
+
+@fixture(name='rand_input')
+def fixture_rand_input():
+    "A random, reproducible input for the mock model."
+    return tf.constant(tf.random.normal([4, 5], seed=42))
+
+def test_hard_sample_output_format(rand_input):
+    """Tests that the hard output samples are in the expected format.
+    The hard sample should be returned as a one-hot tensor."""
+    affined = layers.Dense(5, use_bias=False, kernel_initializer=custom_initializer)(rand_input)
+    hard_sample, _ = GumbelSoftmaxLayer()(affined)
+    assert npsum(hard_sample) == hard_sample.shape[0], "The sum of the hard samples should equal the number of records."
+    assert all(npsum(hard_sample == 0, 1) == hard_sample.shape[1] - 1), "The hard sample is not a one-hot tensor."
+
+def test_soft_sample_output_format(rand_input):
+    """Tests that the soft output samples are in the expected format.
+    The soft sample should be returned as a tensor of probabilities."""
+    affined = layers.Dense(5, use_bias=False, kernel_initializer=custom_initializer)(rand_input)
+    _, soft_sample = GumbelSoftmaxLayer(tau=0.5)(affined)
+    assert isclose(npsum(soft_sample), soft_sample.shape[0]), "The sum of the soft samples should be close to \
+        the number of records."
+    assert amax(soft_sample) <= 1, "Invalid probability values found."
+    assert amin(soft_sample) >= 0, "Invalid probability values found."
+
+def test_gradients(rand_input):
+    "Performs basic numerical assertions on the gradients of the soft/hard samples."
+    def mock(i):
+        return GumbelSoftmaxLayer()(layers.Dense(5, use_bias=False, kernel_initializer=custom_initializer)(i))
+    with tf.GradientTape() as hard_tape:
+        hard_tape.watch(rand_input)
+        hard_sample, _ = mock(rand_input)
+    with tf.GradientTape() as soft_tape:
+        soft_tape.watch(rand_input)
+        _, soft_sample = mock(rand_input)
+    hard_grads = hard_tape.gradient(hard_sample, rand_input)
+    soft_grads = soft_tape.gradient(soft_sample, rand_input)
+
+    assert hard_grads is None, "The hard sample must not compute gradients."
+    assert soft_grads is not None, "The soft sample is expected to compute gradients."
+    assert npsum(abs(soft_grads)) != 0, "The soft sample is expected to have non-zero gradients."
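To help review the intended API, here is a minimal usage sketch of the new synthesizer. It is not part of the patch: the toy DataFrame, column names, and parameter values are placeholders, and it assumes TrainParameters exposes the PATEGAN-specific fields referenced in PATEGAN.train (lr, num_moments, num_teacher_iters, num_student_iters, lap_scale) alongside the library's existing ModelParameters fields.

# Illustrative usage sketch (not part of the patch); assumes TrainParameters carries the
# PATEGAN-specific fields referenced in PATEGAN.train.
import numpy as np
import pandas as pd

from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.regular import PATEGAN

# Toy dataset: two numerical columns and one binary categorical column (placeholders)
data = pd.DataFrame({'num_0': np.random.normal(size=1000),
                     'num_1': np.random.normal(size=1000),
                     'cat_0': np.random.randint(0, 2, size=1000)})

model_args = ModelParameters(batch_size=64, lr=1e-4, noise_dim=32, layers_dim=128)
train_args = TrainParameters(lr=1e-4, num_moments=100, num_teacher_iters=5,
                             num_student_iters=5, lap_scale=1e-4)

synth = PATEGAN(model_args, n_teachers=10, target_delta=1e-5, target_epsilon=1.0)
synth.train(data, class_ratios=[0.5, 0.5], train_arguments=train_args,
            num_cols=['num_0', 'num_1'], cat_cols=['cat_0'])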