Skip to content

Commit

Permalink
feat: add doppelganger model
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardodcpereira committed Aug 4, 2023
1 parent f53afd3 commit 809008a
Show file tree
Hide file tree
Showing 12 changed files with 1,493 additions and 61 deletions.
34 changes: 34 additions & 0 deletions examples/timeseries/stock_doppelganger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

# Importing necessary libraries
from ydata_synthetic.synthesizers.timeseries import TimeSeriesSynthesizer
from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
import pandas as pd
from os import path

# Columns of the stock dataset; all of them are numerical measurements.
stock_cols = ["Open", "High", "Low", "Close", "Adj_Close", "Volume"]
seq_len = 24
# Location used to cache the trained synthesizer between runs.
model_path = 'doppelganger_stock'

# Read the data: one block per sequence, stacked back into a single flat DataFrame
stock_blocks = processed_stock(path='../../data/stock_data.csv', seq_len=seq_len)
stock_data = pd.concat(
    [pd.DataFrame(block, columns=stock_cols) for block in stock_blocks]
).reset_index(drop=True)

# Define model parameters
model_args = ModelParameters(batch_size=100,
                             lr=0.001,
                             betas=(0.5, 0.9),
                             latent_dim=3,
                             gp_lambda=10,
                             pac=10)

train_args = TrainParameters(epochs=500,
                             sequence_length=seq_len,
                             measurement_cols=stock_cols)

# Training the DoppelGANger synthesizer (or reusing a previously trained one)
if path.exists(model_path):
    model_dop_gan = TimeSeriesSynthesizer.load(model_path)
else:
    model_dop_gan = TimeSeriesSynthesizer(modelname='doppelganger', model_parameters=model_args)
    model_dop_gan.fit(stock_data, train_args, num_cols=stock_cols)
    # Persist the trained model; without this the `path.exists` check above
    # can never succeed and every run retrains from scratch.
    model_dop_gan.save(model_path)

# Generating new synthetic samples
synth_data = model_dop_gan.sample(n_samples=500)
print(synth_data[0])
59 changes: 25 additions & 34 deletions examples/timeseries/stock_timegan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,62 +4,53 @@

# Importing necessary libraries
from os import path
import pandas as pd
from ydata_synthetic.synthesizers.timeseries import TimeSeriesSynthesizer
from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from ydata_synthetic.synthesizers import ModelParameters
from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers.timeseries import TimeGAN

# Define model parameters
seq_len=24
n_seq = 6
hidden_dim=24
gamma=1

noise_dim = 32
dim = 128
batch_size = 128
gan_args = ModelParameters(batch_size=128,
lr=5e-4,
noise_dim=32,
layers_dim=128,
latent_dim=24,
gamma=1)

log_step = 100
learning_rate = 5e-4

gan_args = ModelParameters(batch_size=batch_size,
lr=learning_rate,
noise_dim=noise_dim,
layers_dim=dim)
train_args = TrainParameters(epochs=50000,
sequence_length=24,
number_sequences=6)

# Read the data
stock_data = processed_stock(path='../../data/stock_data.csv', seq_len=seq_len)
print(len(stock_data),stock_data[0].shape)
stock_data = pd.read_csv("../../data/stock_data.csv")
cols = list(stock_data.columns)

# Training the TimeGAN synthesizer
if path.exists('synthesizer_stock.pkl'):
synth = TimeGAN.load('synthesizer_stock.pkl')
synth = TimeSeriesSynthesizer.load('synthesizer_stock.pkl')
else:
synth = TimeGAN(model_parameters=gan_args, hidden_dim=24, seq_len=seq_len, n_seq=n_seq, gamma=1)
synth.train(stock_data, train_steps=50000)
synth = TimeSeriesSynthesizer(modelname='timegan', model_parameters=gan_args)
synth.fit(stock_data, train_args, num_cols=cols)
synth.save('synthesizer_stock.pkl')

# Generating new synthetic samples
synth_data = synth.sample(len(stock_data))
print(synth_data.shape)

# Reshaping the data
cols = ['Open','High','Low','Close','Adj Close','Volume']
stock_data_blocks = processed_stock(path='../../data/stock_data.csv', seq_len=24)
synth_data = synth.sample(n_samples=len(stock_data_blocks))
print(synth_data[0].shape)

# Plotting some generated samples. Both Synthetic and Original data are still standardized with values between [0,1]
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 10))
axes=axes.flatten()

time = list(range(1,25))
obs = np.random.randint(len(stock_data))
obs = np.random.randint(len(stock_data_blocks))

for j, col in enumerate(cols):
df = pd.DataFrame({'Real': stock_data[obs][:, j],
'Synthetic': synth_data[obs][:, j]})
df = pd.DataFrame({'Real': stock_data_blocks[obs][:, j],
'Synthetic': synth_data[obs].iloc[:, j]})
df.plot(ax=axes[j],
title = col,
secondary_y='Synthetic data', style=['-', '--'])
fig.tight_layout()
fig.tight_layout()
2 changes: 1 addition & 1 deletion src/ydata_synthetic/preprocessing/regular/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def fit(self, X: DataFrame) -> RegularDataProcessor:
("scaler", MinMaxScaler()),
])
self._cat_pipeline = Pipeline([
("encoder", OneHotEncoder(sparse=False, handle_unknown='ignore')),
("encoder", OneHotEncoder(sparse_output=False, handle_unknown='ignore')),
])

self.num_pipeline.fit(X[self.num_cols]) if self.num_cols else zeros([len(X), 0])
Expand Down
179 changes: 179 additions & 0 deletions src/ydata_synthetic/preprocessing/timeseries/doppelganger_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from __future__ import annotations

from typing import List, Optional
from dataclasses import dataclass

from numpy import concatenate, ndarray, zeros, ones, expand_dims, reshape, sum as npsum, repeat, array_split, asarray
from pandas import DataFrame
from typeguard import typechecked

from ydata_synthetic.preprocessing.regular.processor import RegularDataProcessor


@dataclass
class ColumnMetadata:
    """
    Metadata record describing a single processed column.

    Attributes:
        discrete (bool):
            Flag telling whether the column is discrete.
        output_dim (int):
            Output dimension associated with the column.
        name (str):
            Name of the column.
    """
    discrete: bool
    output_dim: int
    name: str


@typechecked
class DoppelGANgerProcessor(RegularDataProcessor):
    """
    Main class for the DoppelGANger preprocessing.
    It works like any other transformer in scikit learn with the methods fit, transform and inverse transform.

    Columns listed in ``measurement_cols`` are treated as measurements (values
    that vary along a sequence); every other numerical/categorical column is
    treated as an attribute of the sequence.

    Args:
        num_cols (list of strings):
            List of names of numerical columns.
        cat_cols (list of strings):
            List of names of categorical columns.
        measurement_cols (list of strings):
            List of measurement columns.
        sequence_length (int):
            Sequence length.
    """
    SUPPORTED_MODEL = 'DoppelGANger'

    def __init__(self, num_cols: Optional[List[str]] = None,
                 cat_cols: Optional[List[str]] = None,
                 measurement_cols: Optional[List[str]] = None,
                 sequence_length: Optional[int] = None):
        super().__init__(num_cols, cat_cols)

        if measurement_cols is None:
            measurement_cols = []
        self.sequence_length = sequence_length
        # Split the numerical/categorical columns into measurements and attributes.
        self._measurement_num_cols = [c for c in self.num_cols if c in measurement_cols]
        self._measurement_cat_cols = [c for c in self.cat_cols if c in measurement_cols]
        self._attribute_num_cols = [c for c in self.num_cols if c not in measurement_cols]
        self._attribute_cat_cols = [c for c in self.cat_cols if c not in measurement_cols]
        # The fields below are populated by `transform`.
        self._measurement_cols_metadata = None
        self._attribute_cols_metadata = None
        self._measurement_one_hot_cat_cols = None
        self._attribute_one_hot_cat_cols = None
        self._has_attributes = bool(self._attribute_num_cols or self._attribute_cat_cols)

    @property
    def measurement_cols_metadata(self):
        """Metadata of the measurement columns (``None`` until `transform` is called)."""
        return self._measurement_cols_metadata

    @property
    def attribute_cols_metadata(self):
        """Metadata of the attribute columns (``None`` until `transform` is called)."""
        return self._attribute_cols_metadata

    def _select_one_hot_cols(self, one_hot_cols, selected_cols) -> List[str]:
        """Return the one-hot columns that were expanded from ``selected_cols``.

        One-hot columns are named ``<column>_<category>``. Both parts may contain
        underscores, so the owning column is resolved as the longest categorical
        column name that prefixes the one-hot name, instead of splitting on the
        first underscore (which drops columns such as ``Adj_Close``).
        """
        selected = set(selected_cols)
        result = []
        for one_hot_col in one_hot_cols:
            owners = [c for c in self.cat_cols if one_hot_col.startswith(f"{c}_")]
            if owners and max(owners, key=len) in selected:
                result.append(one_hot_col)
        return result

    def add_gen_flag(self, data_features: ndarray, sample_len: int):
        """Append the two generation-flag channels to the features.

        Args:
            data_features (ndarray):
                Array indexed as (sample, time step, feature).
            sample_len (int):
                Length of the blocks the time axis is divided into.
        Returns:
            data_features (ndarray):
                The input with two extra channels concatenated on the last axis.
        Raises:
            ValueError: if the time axis is not a multiple of ``sample_len``.
        """
        num_sample = data_features.shape[0]
        length = data_features.shape[1]
        if length % sample_len != 0:
            raise ValueError("length must be a multiple of sample_len")
        # Every time step of every sample is flagged as generated.
        data_gen_flag = expand_dims(ones((num_sample, length)), 2)
        # Flags shifted one step to the left, padded with a trailing zero.
        shift_gen_flag = concatenate(
            [data_gen_flag[:, 1:, :],
             zeros((data_gen_flag.shape[0], 1, 1))],
            axis=1)
        # Per-block indicator (block contains at least one flagged step),
        # broadcast back to one value per time step.
        data_gen_flag_t = reshape(
            data_gen_flag,
            [num_sample, length // sample_len, sample_len])
        data_gen_flag_t = npsum(data_gen_flag_t, 2) > 0.5
        data_gen_flag_t = repeat(data_gen_flag_t, sample_len, axis=1)
        data_gen_flag_t = expand_dims(data_gen_flag_t, 2)
        return concatenate(
            [data_features,
             shift_gen_flag,
             (1 - shift_gen_flag) * data_gen_flag_t],
            axis=2)

    def transform(self, X: DataFrame) -> tuple[ndarray, ndarray]:
        """Transforms the passed DataFrame with the fit DataProcessor.
        Args:
            X (DataFrame):
                DataFrame to be transformed.
                Should be aligned with the columns types defined in initialization.
                Its number of rows must be a positive multiple of the sequence length.
        Returns:
            transformed (ndarray, ndarray):
                Measurement features (with the generation flags appended) and
                the attributes averaged over each sequence.
        Raises:
            ValueError: if no measurement column is available, the sequence
                length is missing or not positive, or the number of rows cannot
                be split into whole sequences.
        """
        self._check_is_fitted()

        measurement_cols = self._measurement_num_cols + self._measurement_cat_cols
        if not measurement_cols:
            raise ValueError("At least one measurement column must be supplied.")
        # NOTE(review): measurement_cols is derived from num_cols/cat_cols in
        # __init__, so this condition looks always satisfied; kept as a guard.
        if not all(c in self.num_cols + self.cat_cols for c in measurement_cols):
            raise ValueError("At least one of the supplied measurement columns does not exist in the dataset.")
        if self.sequence_length is None:
            raise ValueError("The sequence length is mandatory.")
        if self.sequence_length <= 0:
            raise ValueError("The sequence length must be a positive integer.")
        # Fail early with a clear message instead of an obscure error while
        # splitting the rows into sequences further below.
        if len(X) == 0 or len(X) % self.sequence_length != 0:
            raise ValueError("The number of rows must be a positive multiple of the sequence length.")

        num_data = DataFrame(self.num_pipeline.transform(X[self.num_cols]) if self.num_cols else zeros([len(X), 0]),
                             columns=self.num_cols)
        one_hot_cat_cols = self.cat_pipeline.get_feature_names_out()
        cat_data = DataFrame(self.cat_pipeline.transform(X[self.cat_cols]) if self.cat_cols else zeros([len(X), 0]),
                             columns=one_hot_cat_cols)

        # Measurements: numerical columns first, then the one-hot categorical ones.
        self._measurement_one_hot_cat_cols = self._select_one_hot_cols(one_hot_cat_cols, self._measurement_cat_cols)
        measurement_num_data = num_data[self._measurement_num_cols].to_numpy() if self._measurement_num_cols else zeros([len(X), 0])
        self._measurement_cols_metadata = [ColumnMetadata(discrete=False, output_dim=1, name=c) for c in self._measurement_num_cols]
        measurement_cat_data = cat_data[self._measurement_one_hot_cat_cols].to_numpy() if self._measurement_one_hot_cat_cols else zeros([len(X), 0])
        self._measurement_cols_metadata += [ColumnMetadata(discrete=True, output_dim=X[c].nunique(), name=c) for c in self._measurement_cat_cols]
        data_features = concatenate([measurement_num_data, measurement_cat_data], axis=1)

        if self._has_attributes:
            self._attribute_one_hot_cat_cols = self._select_one_hot_cols(one_hot_cat_cols, self._attribute_cat_cols)
            attribute_num_data = num_data[self._attribute_num_cols].to_numpy() if self._attribute_num_cols else zeros([len(X), 0])
            self._attribute_cols_metadata = [ColumnMetadata(discrete=False, output_dim=1, name=c) for c in self._attribute_num_cols]
            attribute_cat_data = cat_data[self._attribute_one_hot_cat_cols].to_numpy() if self._attribute_one_hot_cat_cols else zeros([len(X), 0])
            self._attribute_cols_metadata += [ColumnMetadata(discrete=True, output_dim=X[c].nunique(), name=c) for c in self._attribute_cat_cols]
            data_attributes = concatenate([attribute_num_data, attribute_cat_data], axis=1)
        else:
            # No attribute columns: use a single all-zeros placeholder attribute.
            self._attribute_one_hot_cat_cols = []
            data_attributes = zeros((data_features.shape[0], 1))
            self._attribute_cols_metadata = [ColumnMetadata(discrete=False, output_dim=1, name="zeros_attribute")]

        # Split the flat rows into consecutive sequences of `sequence_length` rows.
        num_samples = X.shape[0] // self.sequence_length
        data_features = asarray(array_split(data_features, num_samples))
        data_attributes = asarray(array_split(data_attributes, num_samples))

        data_features = self.add_gen_flag(data_features, sample_len=self.sequence_length)
        self._measurement_cols_metadata += [ColumnMetadata(discrete=True, output_dim=2, name="gen_flags")]
        # Attributes are reduced to one row per sequence by averaging over the time axis.
        return data_features, data_attributes.mean(axis=1)

    def inverse_transform(self, X_features: ndarray, X_attributes: ndarray) -> list[DataFrame]:
        """Inverts the data transformation pipelines on the passed arrays.
        Args:
            X_features (ndarray):
                Numpy array with the measurement data to be brought back to the original format.
            X_attributes (ndarray):
                Numpy array with the attribute data to be brought back to the original format.
        Returns:
            result (list of DataFrame):
                One DataFrame per sample with all performed transformations inverted.
        """
        self._check_is_fitted()

        num_samples = X_attributes.shape[0]
        if self._has_attributes:
            # Repeat each sample's attributes along the time axis so they can be
            # concatenated with the measurements.
            X_attributes = repeat(X_attributes.reshape((num_samples, 1, X_attributes.shape[1])),
                                  repeats=X_features.shape[1], axis=1)
            generated_data = concatenate((X_features, X_attributes), axis=2)
        else:
            generated_data = X_features
        output_cols = self._measurement_num_cols + self._measurement_one_hot_cat_cols + self._attribute_num_cols + self._attribute_one_hot_cat_cols
        if self.cat_cols:
            # Feed the encoder its columns in the order it was fitted with, which
            # may differ from the measurement-then-attribute order used above.
            one_hot_cat_cols = list(self.cat_pipeline.get_feature_names_out())
        else:
            one_hot_cat_cols = []
        restored_cols = self.num_cols + self.cat_cols

        samples = []
        for i in range(num_samples):
            df = DataFrame(generated_data[i], columns=output_cols)
            df_num = self.num_pipeline.inverse_transform(df[self.num_cols]) if self.num_cols else zeros([len(df), 0])
            # Round to the nearest one-hot vertex before decoding the categories.
            df_cat = self.cat_pipeline.inverse_transform(df[one_hot_cat_cols].round(0)) if self.cat_cols else zeros([len(df), 0])
            df = DataFrame(concatenate((df_num, df_cat), axis=1), columns=restored_cols)
            df = df.loc[:, self._col_order_]
            for col in df.columns:
                df[col] = df[col].astype(self._types[col])
            samples.append(df)

        return samples
16 changes: 12 additions & 4 deletions src/ydata_synthetic/synthesizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,23 @@
from ydata_synthetic.preprocessing.timeseries.timeseries_processor import (
TimeSeriesDataProcessor, TimeSeriesModels)
from ydata_synthetic.preprocessing.regular.ctgan_processor import CTGANDataProcessor
from ydata_synthetic.preprocessing.timeseries.doppelganger_processor import DoppelGANgerProcessor
from ydata_synthetic.synthesizers.saving_keras import make_keras_picklable

_model_parameters = ['batch_size', 'lr', 'betas', 'layers_dim', 'noise_dim',
'n_cols', 'seq_len', 'condition', 'n_critic', 'n_features',
'tau_gs', 'generator_dims', 'critic_dims', 'l2_scale',
'latent_dim', 'gp_lambda', 'pac']
'latent_dim', 'gp_lambda', 'pac', 'gamma']
_model_parameters_df = [128, 1e-4, (None, None), 128, 264,
None, None, None, 1, None, 0.2, [256, 256],
[256, 256], 1e-6, 128, 10.0, 10]
[256, 256], 1e-6, 128, 10.0, 10, 1]

_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval',
'labels', 'n_clusters', 'epsilon', 'log_frequency']
'labels', 'n_clusters', 'epsilon', 'log_frequency',
'measurement_cols', 'sequence_length', 'number_sequences']

ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df)
TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, 10, 0.005, True))
TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, 10, 0.005, True, None, 1, 1))

@typechecked
class BaseModel(ABC):
Expand Down Expand Up @@ -185,6 +187,12 @@ def fit(self,
epsilon = train_arguments.epsilon
self.processor = CTGANDataProcessor(n_clusters=n_clusters, epsilon=epsilon,
num_cols=num_cols, cat_cols=cat_cols).fit(data)
elif self.__MODEL__ == DoppelGANgerProcessor.SUPPORTED_MODEL:
measurement_cols = train_arguments.measurement_cols
sequence_length = train_arguments.sequence_length
self.processor = DoppelGANgerProcessor(num_cols=num_cols, cat_cols=cat_cols,
measurement_cols=measurement_cols,
sequence_length=sequence_length).fit(data)
else:
print(f'A DataProcessor is not available for the {self.__MODEL__}.')

Expand Down
4 changes: 2 additions & 2 deletions src/ydata_synthetic/synthesizers/timeseries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ydata_synthetic.synthesizers.timeseries.timegan.model import TimeGAN
from ydata_synthetic.synthesizers.timeseries.model import TimeSeriesSynthesizer

__all__ = [
'TimeGAN',
'TimeSeriesSynthesizer'
]
Empty file.
Loading

0 comments on commit 809008a

Please sign in to comment.