From 81abe1ddd706226bc5c5ab24b69ab55b8052213b Mon Sep 17 00:00:00 2001 From: Fabiana <30911746+fabclmnt@users.noreply.github.com> Date: Mon, 22 May 2023 20:00:39 -0700 Subject: [PATCH] feat: add new gmm based synth for fast synthesis (#269) * feat: Add new GMM model for fast synthesis * feat: add save and load for new model * fix: synthesis base class * fix: linter * fix: linter warnings --- README.md | 24 ++- .../Fast_Adult_Census_Income_Data.ipynb | 203 ++++++++++++++++++ src/ydata_synthetic/preprocessing/__init__.py | 7 + src/ydata_synthetic/synthesizers/__init__.py | 7 +- .../synthesizers/{gan.py => base.py} | 52 ++++- .../synthesizers/regular/cgan/model.py | 2 +- .../synthesizers/regular/cramergan/model.py | 4 +- .../synthesizers/regular/ctgan/model.py | 4 +- .../synthesizers/regular/cwgangp/model.py | 2 +- .../synthesizers/regular/dragan/model.py | 4 +- .../synthesizers/regular/gmm/__init__.py | 0 .../synthesizers/regular/gmm/model.py | 111 ++++++++++ .../synthesizers/regular/model.py | 14 +- .../synthesizers/regular/vanillagan/model.py | 5 +- .../synthesizers/regular/wgan/model.py | 4 +- .../synthesizers/regular/wgangp/model.py | 4 +- .../synthesizers/timeseries/timegan/model.py | 4 +- 17 files changed, 415 insertions(+), 36 deletions(-) create mode 100644 examples/regular/models/Fast_Adult_Census_Income_Data.ipynb rename src/ydata_synthetic/synthesizers/{gan.py => base.py} (89%) create mode 100644 src/ydata_synthetic/synthesizers/regular/gmm/__init__.py create mode 100644 src/ydata_synthetic/synthesizers/regular/gmm/model.py diff --git a/README.md b/README.md index b3009748..11d131b8 100644 --- a/README.md +++ b/README.md @@ -12,25 +12,30 @@ Join us on [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the # YData Synthetic A package to generate synthetic tabular and time-series data leveraging the state of the art generative models. -## 🎊 We have **big news**: v1.0.0 is here -> We have exciting news for you. The new version of `ydata-synthetic` include new and exciting features: +## 🎊 The exciting features: +> These are must try features whne it comes to synthetic data generation: + > - A new streamlit app that delivers the synthetic data generation experience with a UI interface. A low code experience for the quick generation of synthetic data + > - A new fast synthetic data generation model based on Gaussian Mixture. So you can quickstart in the world of synthetic data generation without the need for a GPU. > - A conditional architecture for tabular data: CTGAN, which will make the process of synthetic data generation easier and with higher quality! - > - A new streamlit app that delivers the synthetic data generation experience with a UI interface - + ## Synthetic data ### What is synthetic data? Synthetic data is artificially generated data that is not collected from real world events. It replicates the statistical components of real data without containing any identifiable information, ensuring individuals' privacy. ### Why Synthetic Data? Synthetic data can be used for many applications: - - Privacy + - Privacy compliance for data-sharing and Machine Learning development - Remove bias - Balance datasets - Augment datasets # ydata-synthetic -This repository contains material related with Generative Adversarial Networks for synthetic data generation, in particular regular tabular data and time-series. -It consists a set of different GANs architectures developed using Tensorflow 2.0. Several example Jupyter Notebooks and Python scripts are included, to show how to use the different architectures. +This repository contains material related with architectures and models for synthetic data, from Generative Adversarial Networks (GANs) to Gaussian Mixtures. +The repo includes a full ecosystem for synthetic data generation, that includes different models for the generation of synthetic structure data and time-series. +All the Deep Learning models are implemented leveraging Tensorflow 2.0. +Several example Jupyter Notebooks and Python scripts are included, to show how to use the different architectures. + +Are you ready to learn more about synthetic data and the bext-practices for synthetic data generation? ## Quickstart The source code is currently hosted on GitHub at: https://github.com/ydataai/ydata-synthetic @@ -78,8 +83,8 @@ The below models are supported: ### Examples Here you can find usage examples of the package and models to synthesize tabular data. - - - Tabular synthetic data generation with CTGAN on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Data-Centric-AI-Community/awesome-python-for-data-science/blob/main/workshop-ds/Workshop%20-%20Data-Centric%20AI%20pipelines%20-%20How%20and%20why.ipynb) + - Fast tabular data synthesis on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/models/Fast_Adult_Census_Income_Data.ipynb) + - Tabular synthetic data generation with CTGAN on adult census income dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/regular/models/CTGAN_Adult_Census_Income_Data.ipynb) - Time Series synthetic data generation with TimeGAN on stock dataset [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ydataai/ydata-synthetic/blob/master/examples/timeseries/TimeGAN_Synthetic_stock_data.ipynb) - More examples are continuously added and can be found in `/examples` directory. @@ -106,6 +111,7 @@ In this repository you can find the several GAN architectures that are used to c - [Cramer GAN (The Cramer Distance as a Solution to Biased Wasserstein Gradients)](https://arxiv.org/abs/1705.10743) - [CWGAN-GP (Conditional Wassertein GAN with Gradient Penalty)](https://cameronfabbri.github.io/papers/conditionalWGAN.pdf) - [CTGAN (Conditional Tabular GAN)](https://arxiv.org/pdf/1907.00503.pdf) + - [Gaussian Mixture](https://towardsdatascience.com/gaussian-mixture-models-explained-6986aaf5a95) ### Sequential data - [TimeGAN](https://papers.nips.cc/paper/2019/file/c9efe5f26cd17ba6216bbe2a7d26d490-Paper.pdf) diff --git a/examples/regular/models/Fast_Adult_Census_Income_Data.ipynb b/examples/regular/models/Fast_Adult_Census_Income_Data.ipynb new file mode 100644 index 00000000..41e66eed --- /dev/null +++ b/examples/regular/models/Fast_Adult_Census_Income_Data.ipynb @@ -0,0 +1,203 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "#Uncomment to install ydata-synthetic lib\n", + "#!pip install ydata-synthetic" + ], + "metadata": { + "id": "fwXSWiYu_tl0", + "pycharm": { + "name": "#%%\n" + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Tabular Synthetic Data Generation with Gaussian Mixture\n", + "- This notebook is an example of how to use a synthetic data generation methods based on [GMM](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html) to generate synthetic tabular data with numeric and categorical features.\n", + "\n", + "## Dataset\n", + "- The data used is the [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income) which we will fecth by importing the `pmlb` library (a wrapper for the Penn Machine Learning Benchmark data repository).\n" + ], + "metadata": { + "id": "6T8gjToi_yKA", + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "source": [ + "from pmlb import fetch_data\n", + "\n", + "from ydata_synthetic.synthesizers.regular import RegularSynthesizer\n", + "from ydata_synthetic.synthesizers import ModelParameters, TrainParameters" + ], + "metadata": { + "id": "Ix4gZ9iSCVZI", + "pycharm": { + "name": "#%%\n" + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Load the data" + ], + "metadata": { + "id": "I0qyPwoECZ5x", + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "source": [ + "# Load data\n", + "data = fetch_data('adult')\n", + "num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n", + "cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n", + " 'native-country', 'target']" + ], + "metadata": { + "id": "YeFPnJVOMVqd", + "pycharm": { + "name": "#%%\n" + } + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Create and Train the synthetic data generator" + ], + "metadata": { + "id": "68MoepO0Cpx6", + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "source": [ + "synth = RegularSynthesizer(modelname='fast')\n", + "synth.fit(data=data, num_cols=num_cols, cat_cols=cat_cols)" + ], + "metadata": { + "id": "oIHMVgSZMg8_", + "pycharm": { + "name": "#%%\n" + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Generate new synthetic data" + ], + "metadata": { + "id": "xHK-SRPyDUin", + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "source": [ + "synth_data = synth.sample(1000)\n", + "print(synth_data)" + ], + "metadata": { + "id": "0aa2g0RLMkqe", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "01808aa4-a700-4385-e7df-b2f7abd162a0", + "pycharm": { + "name": "#%%\n" + } + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " age workclass fnlwgt education education-num \\\n", + "0 38.753654 4 179993.565472 8 10.0 \n", + "1 36.408844 4 245841.807958 9 10.0 \n", + "2 56.251066 4 400895.076058 11 13.0 \n", + "3 26.846605 4 240156.201048 11 10.0 \n", + "4 29.083102 1 5601.059126 11 9.0 \n", + ".. ... ... ... ... ... \n", + "995 79.281276 4 30664.183560 1 10.0 \n", + "996 51.423132 4 414524.980527 1 10.0 \n", + "997 17.342915 6 177716.451926 11 13.0 \n", + "998 39.298867 4 132011.369567 15 12.0 \n", + "999 46.977763 2 92662.371635 9 13.0 \n", + "\n", + " marital-status occupation relationship race sex capital-gain \\\n", + "0 4 0 3 4 0 55.771499 \n", + "1 6 7 0 4 1 124.337939 \n", + "2 4 3 3 4 1 27.968087 \n", + "3 4 6 1 4 0 25.065678 \n", + "4 6 3 0 4 0 126.269337 \n", + ".. ... ... ... ... ... ... \n", + "995 2 0 3 4 1 4.393001 \n", + "996 4 7 3 2 0 54.841598 \n", + "997 4 4 4 4 0 99.394428 \n", + "998 4 14 1 4 1 97.834797 \n", + "999 4 8 1 4 0 51.258308 \n", + "\n", + " capital-loss hours-per-week native-country target \n", + "0 -1.271118 39.749641 39 1 \n", + "1 -2.114950 44.488198 39 1 \n", + "2 1.541738 40.042696 39 1 \n", + "3 1.148560 39.952615 39 1 \n", + "4 -1.786768 39.808085 39 0 \n", + ".. ... ... ... ... \n", + "995 0.224015 50.580637 39 1 \n", + "996 1.319341 4.441194 39 1 \n", + "997 -5.231663 39.779674 39 1 \n", + "998 1.595817 39.731359 13 1 \n", + "999 1.129814 39.838415 39 1 \n", + "\n", + "[1000 rows x 15 columns]\n" + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/src/ydata_synthetic/preprocessing/__init__.py b/src/ydata_synthetic/preprocessing/__init__.py index e69de29b..4c5e2055 100644 --- a/src/ydata_synthetic/preprocessing/__init__.py +++ b/src/ydata_synthetic/preprocessing/__init__.py @@ -0,0 +1,7 @@ +from ydata_synthetic.preprocessing.regular.processor import RegularDataProcessor +from ydata_synthetic.preprocessing.timeseries.timeseries_processor import TimeSeriesDataProcessor + +__all__ = [ + "RegularDataProcessor", + "TimeSeriesDataProcessor" +] \ No newline at end of file diff --git a/src/ydata_synthetic/synthesizers/__init__.py b/src/ydata_synthetic/synthesizers/__init__.py index b2f944a5..65e8da40 100644 --- a/src/ydata_synthetic/synthesizers/__init__.py +++ b/src/ydata_synthetic/synthesizers/__init__.py @@ -1 +1,6 @@ -from ydata_synthetic.synthesizers.gan import ModelParameters, TrainParameters \ No newline at end of file +from ydata_synthetic.synthesizers.base import ModelParameters, TrainParameters + +__all__ = [ + "ModelParameters", + "TrainParameters" +] \ No newline at end of file diff --git a/src/ydata_synthetic/synthesizers/gan.py b/src/ydata_synthetic/synthesizers/base.py similarity index 89% rename from src/ydata_synthetic/synthesizers/gan.py rename to src/ydata_synthetic/synthesizers/base.py index d3666f78..055ee705 100644 --- a/src/ydata_synthetic/synthesizers/gan.py +++ b/src/ydata_synthetic/synthesizers/base.py @@ -1,7 +1,9 @@ "Implements a GAN BaseModel synthesizer, not meant to be directly instantiated." +from abc import ABC, abstractmethod from collections import namedtuple from typing import List, Optional, Union +import pandas as pd import tqdm from numpy import array, vstack, ndarray @@ -40,10 +42,50 @@ ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df) TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, 10, 0.005, True)) +@typechecked +class BaseModel(ABC): + """ + Abstract class for synthetic data generation nmodels + + The main methods are train (for fitting the synthesizer), save/load and sample (generating synthetic records). + + """ + __MODEL__ = None + + @abstractmethod + def fit(self, data: Union[DataFrame, array], + num_cols: Optional[List[str]] = None, + cat_cols: Optional[List[str]] = None): + """ + ### Description: + Trains and fit a synthesizer model to a given input dataset. + + ### Args: + `data` (Union[DataFrame, array]): Training data + `num_cols` (Optional[List[str]]) : List with the names of the categorical columns + `cat_cols` (Optional[List[str]]): List of names of categorical columns + + ### Returns: + **self:** *object* + Fitted synthesizer + """ + ... + @abstractmethod + def sample(self, n_samples:int) -> pd.DataFrame: + assert n_samples>0, "Please insert a value bigger than 0 for n_samples parameter." + ... + + @classmethod + def load(cls, path: str): + ... + + @abstractmethod + def save(self, path: str): + ... # pylint: disable=R0902 @typechecked -class BaseModel(): +class BaseGANModel(BaseModel): """ Base class of GAN synthesizer models. The main methods are train (for fitting the synthesizer), save/load and sample (obtain synthetic records). @@ -51,8 +93,6 @@ class BaseModel(): model_parameters (ModelParameters): Set of architectural parameters for model definition. """ - __MODEL__ = None - def __init__( self, model_parameters: ModelParameters @@ -84,7 +124,7 @@ def __init__( self.gp_lambda = model_parameters.gp_lambda self.pac = model_parameters.pac - self.processor = None + self.processor=None if self.__MODEL__ in RegularModels.__members__ or \ self.__MODEL__ == CTGANDataProcessor.SUPPORTED_MODEL: self.tau = model_parameters.tau_gs @@ -183,8 +223,8 @@ def save(self, path): make_keras_picklable() dump(self, path) - @staticmethod - def load(path): + @classmethod + def load(cls, path): """ ### Description: Loads a saved synthesizer from a pickle. diff --git a/src/ydata_synthetic/synthesizers/regular/cgan/model.py b/src/ydata_synthetic/synthesizers/regular/cgan/model.py index b6673605..33f899c4 100644 --- a/src/ydata_synthetic/synthesizers/regular/cgan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/cgan/model.py @@ -20,7 +20,7 @@ #Import ydata synthetic classes from ....synthesizers import TrainParameters -from ....synthesizers.gan import ConditionalModel +from ....synthesizers.base import ConditionalModel class CGAN(ConditionalModel): "CGAN model for discrete conditions" diff --git a/src/ydata_synthetic/synthesizers/regular/cramergan/model.py b/src/ydata_synthetic/synthesizers/regular/cramergan/model.py index c7c2b57a..1cf3393a 100644 --- a/src/ydata_synthetic/synthesizers/regular/cramergan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/cramergan/model.py @@ -15,10 +15,10 @@ #Import ydata synthetic classes from ....synthesizers import TrainParameters -from ....synthesizers.gan import BaseModel +from ....synthesizers.base import BaseGANModel from ....synthesizers.loss import Mode, gradient_penalty -class CRAMERGAN(BaseModel): +class CRAMERGAN(BaseGANModel): __MODEL__='CRAMERGAN' diff --git a/src/ydata_synthetic/synthesizers/regular/ctgan/model.py b/src/ydata_synthetic/synthesizers/regular/ctgan/model.py index e7feb57c..3599d7fd 100644 --- a/src/ydata_synthetic/synthesizers/regular/ctgan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/ctgan/model.py @@ -12,10 +12,10 @@ import ConditionalLoss, RealDataSampler, ConditionalSampler from ydata_synthetic.synthesizers.loss import gradient_penalty, Mode as ModeGP -from ydata_synthetic.synthesizers.gan import BaseModel, ModelParameters, TrainParameters +from ydata_synthetic.synthesizers.base import BaseGANModel, ModelParameters, TrainParameters from ydata_synthetic.preprocessing.regular.ctgan_processor import CTGANDataProcessor -class CTGAN(BaseModel): +class CTGAN(BaseGANModel): """ Conditional Tabular GAN model. Based on the paper https://arxiv.org/abs/1907.00503. diff --git a/src/ydata_synthetic/synthesizers/regular/cwgangp/model.py b/src/ydata_synthetic/synthesizers/regular/cwgangp/model.py index 7c3ba7ab..840ef527 100644 --- a/src/ydata_synthetic/synthesizers/regular/cwgangp/model.py +++ b/src/ydata_synthetic/synthesizers/regular/cwgangp/model.py @@ -15,7 +15,7 @@ #Import ydata synthetic classes from ....synthesizers import TrainParameters -from ....synthesizers.gan import ConditionalModel +from ....synthesizers.base import ConditionalModel from ....synthesizers.regular.wgangp.model import WGAN_GP class CWGANGP(ConditionalModel, WGAN_GP): diff --git a/src/ydata_synthetic/synthesizers/regular/dragan/model.py b/src/ydata_synthetic/synthesizers/regular/dragan/model.py index 47d5d4e3..e1555b0f 100644 --- a/src/ydata_synthetic/synthesizers/regular/dragan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/dragan/model.py @@ -12,10 +12,10 @@ from keras.optimizers import Adam #Import ydata synthetic classes -from ....synthesizers.gan import BaseModel +from ....synthesizers.base import BaseGANModel from ....synthesizers.loss import Mode, gradient_penalty -class DRAGAN(BaseModel): +class DRAGAN(BaseGANModel): __MODEL__='DRAGAN' diff --git a/src/ydata_synthetic/synthesizers/regular/gmm/__init__.py b/src/ydata_synthetic/synthesizers/regular/gmm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ydata_synthetic/synthesizers/regular/gmm/model.py b/src/ydata_synthetic/synthesizers/regular/gmm/model.py new file mode 100644 index 00000000..466d1978 --- /dev/null +++ b/src/ydata_synthetic/synthesizers/regular/gmm/model.py @@ -0,0 +1,111 @@ +""" + GMM based synthetic data generation model +""" +from typing import List, Optional, Union + +from joblib import dump, load +from tqdm import tqdm + +from pandas import DataFrame +from numpy import (array, arange) + +from sklearn.mixture import GaussianMixture +from sklearn.metrics import silhouette_score + +from ydata_synthetic.synthesizers.base import BaseModel +from ydata_synthetic.preprocessing import RegularDataProcessor + +class GMM(BaseModel): + + def __init__(self, + covariance_type:str="full", + random_state:int=0): + self.covariance_type = covariance_type + self.random_state = random_state + self.__MODEL__ = GaussianMixture(covariance_type=covariance_type, + random_state=random_state) + self.processor = RegularDataProcessor + + def __optimize(self, prep_data: array): + """ + Auxiliary method to optimize the number of components to be considered for the Gaussian or Bayesian Mixture + Returns: + n_components (int): Optimal number of components calculated based on Silhouette score + """ + c = arange(2, 40, 5) + n_components=2 + max_silhouette=0 + for n in tqdm(c, desc="Hyperparameter search"): + model = GaussianMixture(n, covariance_type=self.covariance_type, random_state=self.random_state) + labels = model.fit_predict(prep_data) + s = silhouette_score(prep_data, labels, metric='euclidean') + if model.converged_: + if max_silhouette < s: + n_components = n + max_silhouette=s + return n_components + + def fit(self, data: Union[DataFrame, array], + num_cols: Optional[List[str]] = None, + cat_cols: Optional[List[str]] = None,): + """ + ### Description: + Trains and fit a synthesizer model to a given input dataset. + + ### Args: + `data` (Union[DataFrame, array]): Training data + `num_cols` (Optional[List[str]]) : List with the names of the categorical columns + `cat_cols` (Optional[List[str]]): List of names of categorical columns + + ### Returns: + **self:** *object* + Fitted synthesizer + """ + self.processor = RegularDataProcessor(num_cols=num_cols, cat_cols=cat_cols).fit(data) + train_data = self.processor.transform(data) + + #optimize the n_components selection + n_components = self.__optimize(train_data) + + self.__MODEL__.n_components=n_components + #Fit the gaussian model + self.__MODEL__.fit(train_data) + + def sample(self, n_samples: int): + """ + ### Description: + Generates samples from the trained synthesizer. + + ### Args: + `n_samples` (int): Number of rows to generated. + + ### Returns: + **synth_sample:** pandas.DataFrame, shape (n_samples, n_features) + Returns the generated synthetic samples. + """ + sample = self.__MODEL__.sample(n_samples=n_samples)[0] + + return self.processor.inverse_transform(sample) + + def save(self, path='str'): + """ + Save a model as a pickle + Args: + path (str): The path where the model should be saved as pickle + """ + try: + with open(path, 'wb') as f: + dump(self, f) + except: + raise Exception(f'The path {path} provided is not valid. Please validate your inputs') + + @classmethod + def load(cls, path:str): + """ + Load a trained synthesizer from a given path + Returns: + model (GMM): A trained GMM model + """ + with open(path, 'rb') as f: + model = load(f) + return model diff --git a/src/ydata_synthetic/synthesizers/regular/model.py b/src/ydata_synthetic/synthesizers/regular/model.py index 7332ab5a..3e3b2cfc 100644 --- a/src/ydata_synthetic/synthesizers/regular/model.py +++ b/src/ydata_synthetic/synthesizers/regular/model.py @@ -15,6 +15,7 @@ from ydata_synthetic.synthesizers.regular.cramergan.model import CRAMERGAN from ydata_synthetic.synthesizers.regular.dragan.model import DRAGAN from ydata_synthetic.synthesizers.regular.ctgan.model import CTGAN +from ydata_synthetic.synthesizers.regular.gmm.model import GMM @unique @@ -27,6 +28,7 @@ class Model(Enum): CRAMER = 'cramer' DEEPREGRET = 'dragan' CONDITIONALTABULAR = 'ctgan' + FAST = 'fast' __MAPPING__ = { VANILLA : VanilllaGAN, @@ -36,7 +38,8 @@ class Model(Enum): CWASSERTEINGP: CWGANGP, CRAMER: CRAMERGAN, DEEPREGRET: DRAGAN, - CONDITIONALTABULAR: CTGAN + CONDITIONALTABULAR: CTGAN, + FAST: GMM } @property @@ -45,8 +48,13 @@ def function(self): class RegularSynthesizer(): "Abstraction class " - def __new__(cls, modelname: str, model_parameters, **kwargs): - return Model(modelname).function(model_parameters, **kwargs) + def __new__(cls, modelname: str, model_parameters =None, **kwargs): + model = None + if Model(modelname) == Model.FAST: + model=Model(modelname).function(**kwargs) + else: + model=Model(modelname).function(model_parameters, **kwargs) + return model @staticmethod def load(path): diff --git a/src/ydata_synthetic/synthesizers/regular/vanillagan/model.py b/src/ydata_synthetic/synthesizers/regular/vanillagan/model.py index 5c29766a..45bcc0cc 100644 --- a/src/ydata_synthetic/synthesizers/regular/vanillagan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/vanillagan/model.py @@ -14,11 +14,10 @@ from keras.optimizers import Adam #Import ydata synthetic classes -from ....synthesizers.gan import BaseModel +from ....synthesizers.base import BaseGANModel from ....synthesizers import TrainParameters -from ....utils.gumbel_softmax import GumbelSoftmaxActivation -class VanilllaGAN(BaseModel): +class VanilllaGAN(BaseGANModel): __MODEL__='GAN' diff --git a/src/ydata_synthetic/synthesizers/regular/wgan/model.py b/src/ydata_synthetic/synthesizers/regular/wgan/model.py index c836ea86..3093d96b 100644 --- a/src/ydata_synthetic/synthesizers/regular/wgan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/wgan/model.py @@ -17,7 +17,7 @@ #Import ydata synthetic classes from ....synthesizers import TrainParameters -from ....synthesizers.gan import BaseModel +from ....synthesizers.base import BaseGANModel #Auxiliary Keras backend class to calculate the Random Weighted average #https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras @@ -33,7 +33,7 @@ def call(self, inputs, **kwargs): def compute_output_shape(self, input_shape): return input_shape[0] -class WGAN(BaseModel): +class WGAN(BaseGANModel): __MODEL__='WGAN' diff --git a/src/ydata_synthetic/synthesizers/regular/wgangp/model.py b/src/ydata_synthetic/synthesizers/regular/wgangp/model.py index 667c2646..1c1d48be 100644 --- a/src/ydata_synthetic/synthesizers/regular/wgangp/model.py +++ b/src/ydata_synthetic/synthesizers/regular/wgangp/model.py @@ -16,9 +16,9 @@ #Import ydata synthetic classes from ....synthesizers import TrainParameters -from ....synthesizers.gan import BaseModel +from ....synthesizers.base import BaseGANModel -class WGAN_GP(BaseModel): +class WGAN_GP(BaseGANModel): __MODEL__='WGAN_GP' diff --git a/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py b/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py index baf230ac..60464566 100644 --- a/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py +++ b/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py @@ -13,7 +13,7 @@ from keras.optimizers import Adam from keras.losses import (BinaryCrossentropy, MeanSquaredError) -from ....synthesizers.gan import BaseModel +from ....synthesizers.base import BaseGANModel def make_net(model, n_layers, hidden_units, output_units, net_type='GRU'): if net_type=='GRU': @@ -33,7 +33,7 @@ def make_net(model, n_layers, hidden_units, output_units, net_type='GRU'): return model -class TimeGAN(BaseModel): +class TimeGAN(BaseGANModel): __MODEL__='TimeGAN'