-
Notifications
You must be signed in to change notification settings - Fork 249
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: Streamlit demo app to generate synthetic dataset using ydata-s…
…ynthetic on tabular data (#166) * Add files via upload * Delete examples/regular/ydata-synthetic-streamlit directory * Demo app with Streamlit * Update app.py * Update app.py * Update app.py * Add files via upload * Create app.gif * Create README.md * Create requirements.txt * Update README.md * Update README.md * Update requirements.txt * Update README.md * Update README.md * Update README.md
- Loading branch information
Showing
7 changed files
with
162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[theme] | ||
primaryColor="#040000" | ||
backgroundColor="#770303" | ||
secondaryBackgroundColor="#000000" | ||
textColor="#f2f2f3" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Streamlit application to generate synthetic data using ydata-synthetic | ||
|
||
<img src="https://github.com/rajeshai/ydata-synthetic/blob/dev/examples/regular/streamlit%20app/app.JPG" alt="streamlit app to generate synthetic data"> | ||
|
||
This application takes a pre-processed dataset as input and outputs a synthetic dataset based on the given input parameters. This is made with open source libraries streamlit, ydata-synthetic and deployed on the streamlit cloud. | ||
|
||
## How to use | ||
|
||
1. Upload a pre-processed dataset. | ||
2. Choose the numerical features and categorical features. | ||
3. Choose all the training parameters appropriately. | ||
4. Click the 'click here to start the training process' button. | ||
|
||
<img src="https://github.com/rajeshai/ydata-synthetic/blob/dev/examples/regular/streamlit%20app/app.gif" alt="streamlit app to generate synthetic data"> | ||
|
||
Wait for the training to end. You will see a graph comparing the original data and synthetic data after training. | ||
Please use less number of epochs to complete the training process quickly as this application is deployed on the community cloud of streamlit which has computational limits. | ||
|
||
## Contributing | ||
|
||
Find the application here in this link [](https://share.streamlit.io/rajeshai/ydata-synthetic-streamlit/main/app.py) | ||
|
||
Feel free to contribute to this app by adding more features and optimizing its performance further. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import os | ||
import streamlit as st | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
from ydata_synthetic.synthesizers.regular import DRAGAN, CGAN, CRAMERGAN, WGAN_GP | ||
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters | ||
|
||
st.set_page_config(layout="wide",initial_sidebar_state="auto") | ||
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices' | ||
def run(): | ||
#global data_synn | ||
st.sidebar.image('YData_logo.svg') | ||
st.title('Generate synthetic data for a tabular classification dataset using [ydata-synthetic](https://github.com/ydataai/ydata-synthetic)') | ||
st.markdown('This streamlit application can generate synthetic data for your dataset. Please read all the instructions in the sidebar before you start the process.') | ||
data = st.file_uploader('Upload a preprocessed dataset in csv format') | ||
st.sidebar.title('About') | ||
st.sidebar.markdown('[ydata-synthetic](https://github.com/ydataai/ydata-synthetic) is an open-source library and is used to generate synthetic data mimicking the real world data.') | ||
st.sidebar.header('What is synthetic data?') | ||
st.sidebar.markdown('Synthetic data is artificially generated data that is not collected from real world events. It replicates the statistical components of real data without containing any identifiable information, ensuring individuals privacy.') | ||
st.sidebar.header('Why Synthetic Data?') | ||
st.sidebar.markdown('''Synthetic data can be used for many applications: | ||
- Privacy | ||
- Remove bias | ||
- Balance datasets | ||
- Augment datasets''') | ||
|
||
|
||
st.sidebar.header('Steps to follow') | ||
st.sidebar.markdown(''' | ||
- Upload any preprocessed tabular classification dataset. | ||
- Choose the parameters in the adjacent window appropriately. | ||
- Since this is a demo, please choose less number of epochs for quick completion of training. | ||
- After choosing all parameters, Click the button under the parameters to start training. | ||
- After the training is complete, you will see a graph comparing both real data set and synthetic dataset. Categorical columns are used to compare. | ||
- You will also see a button to download your synthetic dataset. Click that button to download your dataset.''') | ||
|
||
st.sidebar.markdown('''[](https://github.com/ydataai/ydata-synthetic)''',unsafe_allow_html=True) | ||
|
||
@st.cache | ||
def train(df): | ||
#models_dir = './cache' | ||
gan_args = ModelParameters(batch_size=batch_size, | ||
lr=learning_rate*0.001, | ||
betas=(beta_1, beta_2), | ||
noise_dim=noise_dim, | ||
layers_dim=layer_dim) | ||
|
||
train_args = TrainParameters(epochs=epochs, | ||
sample_interval=log_step) | ||
synthesizer = model(gan_args, n_discriminator=3) | ||
synthesizer.train(data, train_args, num_cols, cat_cols) | ||
synthesizer.save('data_synth.pkl') | ||
synthesizer = model.load('data_synth.pkl') | ||
data_syn = synthesizer.sample(samples) | ||
return data_syn | ||
@st.cache | ||
def convert_df(df): | ||
return df.to_csv().encode('utf-8') | ||
if data is not None: | ||
data = pd.read_csv(data) | ||
data.dropna(inplace=True) | ||
st.header('Choose the parameters!!') | ||
col1, col2, col3,col4 = st.columns(4) | ||
with col1: | ||
model = st.selectbox('Choose the GAN model', ['DRAGAN','CGAN','CRAMEGAN','WGAN_GP'],key=1) | ||
if model=='DRAGAN': | ||
model = DRAGAN | ||
elif model=='CGAN': | ||
model=CGAN | ||
elif model=='CRAMEGAN': | ||
model = CRAMERGAN | ||
else: | ||
model = WGAN_GP | ||
num_cols = st.multiselect('Choose the numerical columns', data.columns,key=1) | ||
cat_cols = st.multiselect('Choose categorical columns', [x for x in data.columns if x not in num_cols], key=2) | ||
|
||
with col2: | ||
noise_dim = st.number_input('Select noise dimension', 0,200,128,1) | ||
layer_dim = st.number_input('Select the layer dimension', 0,200,128,1) | ||
batch_size = st.number_input('Select batch size', 0,500, 500,1) | ||
|
||
with col3: | ||
log_step = st.number_input('Select sample interval', 0,200,100,1) | ||
epochs = st.number_input('Select the number of epochs',0,50,2,1) | ||
learning_rate = st.number_input('Select learning rate(x1e-3', 0.01, 0.1, 0.05, 0.01) | ||
|
||
with col4: | ||
beta_1 = st.slider('Select first beta co-efficient', 0.0, 1.0, 0.5) | ||
beta_2 = st.slider('Select second beta co-efficient', 0.0, 1.0, 0.9) | ||
samples = st.number_input('Select the number of synthetic samples to be generated', 0, 400000, step=1000) | ||
if st.button('Click here to start the training process'): | ||
if data is not None: | ||
st.write('Model Training is in progress. It may take a few minutes. Please wait for a while.') | ||
data_synn = train(data) | ||
st.success('Synthetic dataset with the given number of samples is generated!!') | ||
st.subheader('Real Data vs Synthetic Data') | ||
f , axes = plt.subplots(len(cat_cols),2, figsize=(20,25)) | ||
f.suptitle('Real data vs Synthetic data') | ||
for i, j in enumerate(cat_cols): | ||
sns.countplot(x=j, data=data, ax = axes[i,0]) | ||
sns.countplot(x=j, data=data_synn, ax = axes[i,1]) | ||
st.pyplot(f) | ||
st.download_button( | ||
label="Download data as CSV", | ||
data=convert_df(data_synn), | ||
file_name='data_syn.csv', | ||
mime='text/csv') | ||
st.balloons() | ||
else: | ||
st.write('Upload a dataset to train!!') | ||
if __name__== '__main__': | ||
run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pandas | ||
matplotlib | ||
numpy | ||
seaborn | ||
streamlit | ||
ydata-synthetic |