From dcdab7f6c88135400d251a316476360740d6d6e7 Mon Sep 17 00:00:00 2001 From: Alexandre Quemy Date: Wed, 22 Mar 2023 10:49:09 +0100 Subject: [PATCH] feat: add Fabric Regular Synthesizer to Streamlit app (#252) * feat: Fabric Regular Synthesizer in Streamlit app * feat: add ydata-sdk as requirement for streamlit * feat: allow to overwrite default datatype for Fabric Regular Synthesizer * fix: restore streamlit dependency * feat: rename the SDK synthesizer, improve documentation * fix: type exception --- setup.py | 3 +- src/ydata_synthetic/streamlit_app/About.py | 9 +++ .../pages/1_Train_a_synthesizer.py | 76 ++++++++++++++++--- .../pages/2_Generate_synthetic_data.py | 55 ++++++++++++-- 4 files changed, 127 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 102717da..8daf7d82 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,8 @@ "streamlit==1.18.1", "typing-extensions==3.10.0", "streamlit_pandas_profiling==0.1.3", - "ydata-profiling==4.0.0" + "ydata-profiling==4.0.0", + "ydata-sdk>=0.2.1" ], }, ) diff --git a/src/ydata_synthetic/streamlit_app/About.py b/src/ydata_synthetic/streamlit_app/About.py index 19f28fb3..cec3669d 100644 --- a/src/ydata_synthetic/streamlit_app/About.py +++ b/src/ydata_synthetic/streamlit_app/About.py @@ -44,6 +44,15 @@ def main(): - WGAN - WGANGP - CTGAN + - **ydata-sdk Synthesizer** + ''') + + st.success('''In particular, **ydata-sdk Synthesizer** uses [`ydata-sdk`](https://docs.sdk.ydata.ai/) to leverage the state-of-the-art synthesizer model developed by YData.''') + st.info(''' + Using **ydata-sdk Synthesizer** requires a valid token. The token is attached to a Fabric account. + In case you do not have an account, you can create one at https://ydata.ai/ydata-fabric-free-trial. + To obtain the token, please, login to https://fabric.ydata.ai. + The token is available on the homepage once you are connected. ''') #best practives for synthetic data generation diff --git a/src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py b/src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py index 6251c3ad..f7c88298 100644 --- a/src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py +++ b/src/ydata_synthetic/streamlit_app/pages/1_Train_a_synthesizer.py @@ -1,6 +1,11 @@ from typing import Union +import os +import json import streamlit as st +from ydata.sdk.synthesizers import RegularSynthesizer +from ydata.sdk.common.client import get_client + from ydata_synthetic.synthesizers import ModelParameters, TrainParameters from ydata_synthetic.synthesizers.regular.model import Model @@ -12,7 +17,7 @@ def get_available_models(type: Union[str, DataType]): dtype = DataType(type) if dtype == DataType.TABULAR: - models_list = [e.value.upper() for e in Model if e.value not in ['cgan', 'cwgangp']] + models_list = [e.value.upper() for e in Model if e.value not in ['cgan', 'cwgangp']] + ['ydata-sdk Synthesizer'] else: st.warning('Time-Series models are not yet supported .') models_list = (['']) @@ -35,7 +40,7 @@ def run(): models_list = get_available_models(type=datatype) model_name = st.selectbox('Select your model', models_list) - if model_name !='': + if model_name not in ['', 'ydata-sdk Synthesizer']: st.text("Select your synthesizer model parameters") col1, col2 = st.columns(2) with col1: @@ -50,14 +55,14 @@ def run(): # Create the Train parameters gan_args = ModelParameters(batch_size=batch_size, - lr=lr, - betas=(beta_1, beta_2), - noise_dim=noise_dim, - layers_dim=layer_dim) + lr=lr, + betas=(beta_1, beta_2), + noise_dim=noise_dim, + layers_dim=layer_dim) model = init_synth(datatype=datatype, modelname=model_name, model_parameters=gan_args) - if model!=None: + if model != None: st.text("Set your synthesizer training parameters") #Get the training parameters epochs, label_col = training_parameters(model_name, df.columns) @@ -72,11 +77,64 @@ def run(): else: model.fit(data=df, num_cols=num_cols, cat_cols=cat_cols, train_arguments=train_args) - st.success('Synthesizer was trained succesfully!!') - + st.success('Synthesizer was trained succesfully!') st.info(f"The trained model will be saved at {model_path}.") model.save(model_path) + + + if model_name == 'ydata-sdk Synthesizer': + valid_token = False + st.text("Model parameters") + col1, col2 = st.columns(2) + with col1: + token = st.text_input("SDK Token", type="password") + os.environ['YDATA_TOKEN'] = token + + with col2: + st.write("##") + try: + get_client() + st.text('✅ Valid') + valid_token = True + except Exception: + st.text('❌ Invalid') + + if not valid_token: + st.error("""**ydata-sdk Synthesizer requires a valid token.** + In case you do not have an account, please, create one at https://ydata.ai/ydata-fabric-free-trial. + To obtain the token, please, login to https://fabric.ydata.ai. + The token is available on the homepage once you are connected. + """) + + + with st.expander('**More settings**'): + model_path = st.text_input("Saved trained model to path:", value="trained_synth.pkl") + + st.subheader("3. Train your synthesizer") + if st.button('Click here to start the training process', disabled=not valid_token): + model = RegularSynthesizer() + with st.spinner("Please wait while your synthesizer trains..."): + dtypes = {} + for c in num_cols: + dtypes[c] = 'numerical' + for c in cat_cols: + dtypes[c] = 'categorical' + model.fit(X=df, dtypes=dtypes) + + st.success('Synthesizer was trained succesfully!') + st.info(f"The trained model will be saved at {model_path}.") + + model_data = { + 'uid': model.uid, + 'token': os.environ['YDATA_TOKEN'] + } + with open(model_path, 'w') as outfile: + json.dump(model_data, outfile) + + + + if __name__ == '__main__': run() \ No newline at end of file diff --git a/src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py b/src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py index 4e57a871..5eba0df4 100644 --- a/src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py +++ b/src/ydata_synthetic/streamlit_app/pages/2_Generate_synthetic_data.py @@ -1,18 +1,57 @@ import streamlit as st +import json +import os + +from ydata.sdk.synthesizers import RegularSynthesizer +from ydata.sdk.common.client import get_client from ydata_synthetic.streamlit_app.pages.functions.train import DataType from ydata_synthetic.streamlit_app.pages.functions.generate import load_model, generate_profile def run(): st.subheader("Generate synthetic data from a trained model") - + from_SDK = False + model_data = {} + valid_token = False col1, col2 = st.columns([4, 2]) with col1: input_path = st.text_input("Provide the path to a trained model", value="trained_synth.pkl") + # Try to load as a JSON as SDK + try: + f = open(input_path) + model_data = json.load(f) + from_SDK = True + except: + pass + + if from_SDK: + token = st.text_input("SDK Token", type="password", value=model_data.get('token')) + os.environ['YDATA_TOKEN'] = token + + with col2: datatype = st.selectbox('Select your data type', (DataType.TABULAR.value,)) datatype=DataType(datatype) + if from_SDK and 'YDATA_TOKEN' in os.environ: + st.write("##") + try: + get_client() + st.text('✅ Valid') + valid_token = True + except Exception: + st.text('❌ Invalid') + + if from_SDK and 'token' in model_data and not valid_token: + st.warning("The token used during training is not valid anymore. Please, use a new token.") + + if from_SDK and not valid_token: + st.error("""**ydata-sdk Synthesizer requires a valid token.** + In case you do not have an account, please, create one at https://ydata.ai/ydata-fabric-free-trial. + To obtain the token, please, login to https://fabric.ydata.ai. + The token is available on the homepage once you are connected. + """) + col1, col2 = st.columns([4,2]) with col1: n_samples = st.number_input("Number of samples to generate", min_value=0, value=1000) @@ -21,14 +60,18 @@ def run(): sample_path = st.text_input("Synthetic samples file path", value='synthetic.csv') if st.button('Generate samples'): - #load a trained model - model = load_model(input_path=input_path, - datatype=datatype) + if from_SDK: + model = RegularSynthesizer.get(uid=model_data.get('uid')) + + else: + model = load_model(input_path=input_path, datatype=datatype) + + st.success('The model was properly loaded and is now ready to generate synthetic samples!') - st.success('Trained model was loaded. You can now generate synthetic samples') #sample synthetic data - synth_data = model.sample(n_samples) + with st.spinner('Generating samples... This might take time.'): + synth_data = model.sample(n_samples) st.write(synth_data) #save the synthetic data samples to a given path