diff --git a/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py b/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py
index 5d103d31..d92f0dfc 100644
--- a/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py
+++ b/src/ydata_synthetic/synthesizers/timeseries/timegan/model.py
@@ -95,15 +95,6 @@ def define_gan(self):
                                          outputs=Y_real,
                                          name="RealDiscriminator")
 
-        # ----------------------------
-        # Init the optimizers
-        # ----------------------------
-        self.autoencoder_opt = Adam(learning_rate=self.lr)
-        self.supervisor_opt = Adam(learning_rate=self.lr)
-        self.generator_opt = Adam(learning_rate=self.lr)
-        self.discriminator_opt = Adam(learning_rate=self.lr)
-        self.embedding_opt = Adam(learning_rate=self.lr)
-
         # ----------------------------
         # Define the loss functions
         # ----------------------------
@@ -112,7 +103,7 @@ def define_gan(self):
 
     @function
-    def train_autoencoder(self, x):
+    def train_autoencoder(self, x, opt):
         with GradientTape() as tape:
             x_tilde = self.autoencoder(x)
             embedding_loss_t0 = self._mse(x, x_tilde)
             e_loss_0 = 10 * sqrt(embedding_loss_t0)
@@ -120,11 +111,11 @@ def train_autoencoder(self, x):
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss_0, var_list)
-        self.autoencoder_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     @function
-    def train_supervisor(self, x):
+    def train_supervisor(self, x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
@@ -132,11 +123,12 @@ def train_supervisor(self, x):
 
         var_list = self.supervisor.trainable_variables + self.generator.trainable_variables
         gradients = tape.gradient(g_loss_s, var_list)
-        self.supervisor_opt.apply_gradients(zip(gradients, var_list))
+        apply_grads = [(grad, var) for (grad, var) in zip(gradients, var_list) if grad is not None]
+        opt.apply_gradients(apply_grads)
         return g_loss_s
 
     @function
-    def train_embedder(self,x):
+    def train_embedder(self, x, opt):
         with GradientTape() as tape:
             h = self.embedder(x)
             h_hat_supervised = self.supervisor(h)
@@ -148,7 +140,7 @@ def train_embedder(self,x):
 
         var_list = self.embedder.trainable_variables + self.recovery.trainable_variables
         gradients = tape.gradient(e_loss, var_list)
-        self.embedding_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return sqrt(embedding_loss_t0)
 
     def discriminator_loss(self, x, z):
@@ -176,7 +168,7 @@ def calc_generator_moments_loss(y_true, y_pred):
         return g_loss_mean + g_loss_var
 
     @function
-    def train_generator(self, x, z):
+    def train_generator(self, x, z, opt):
         with GradientTape() as tape:
             y_fake = self.adversarial_supervised(z)
             generator_loss_unsupervised = self._bce(y_true=ones_like(y_fake),
@@ -199,17 +191,17 @@ def train_generator(self, x, z):
 
         var_list = self.generator_aux.trainable_variables + self.supervisor.trainable_variables
         gradients = tape.gradient(generator_loss, var_list)
-        self.generator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return generator_loss_unsupervised, generator_loss_supervised, generator_moment_loss
 
     @function
-    def train_discriminator(self, x, z):
+    def train_discriminator(self, x, z, opt):
         with GradientTape() as tape:
             discriminator_loss = self.discriminator_loss(x, z)
 
         var_list = self.discriminator.trainable_variables
         gradients = tape.gradient(discriminator_loss, var_list)
-        self.discriminator_opt.apply_gradients(zip(gradients, var_list))
+        opt.apply_gradients(zip(gradients, var_list))
         return discriminator_loss
 
     def get_batch_data(self, data, n_windows):
@@ -229,16 +221,22 @@ def get_batch_noise(self):
 
     def train(self, data, train_steps):
         ## Embedding network training
+        autoencoder_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Emddeding network training'):
            X_ = next(self.get_batch_data(data, n_windows=len(data)))
-           step_e_loss_t0 = self.train_autoencoder(X_)
+           step_e_loss_t0 = self.train_autoencoder(X_, autoencoder_opt)
 
         ## Supervised Network training
+        supervisor_opt = Adam(learning_rate=self.lr)
         for _ in tqdm(range(train_steps), desc='Supervised network training'):
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
-            step_g_loss_s = self.train_supervisor(X_)
+            step_g_loss_s = self.train_supervisor(X_, supervisor_opt)
 
         ## Joint training
+        generator_opt = Adam(learning_rate=self.lr)
+        embedder_opt = Adam(learning_rate=self.lr)
+        discriminator_opt = Adam(learning_rate=self.lr)
+
         step_g_loss_u = step_g_loss_s = step_g_loss_v = step_e_loss_t0 = step_d_loss = 0
 
         for _ in tqdm(range(train_steps), desc='Joint networks training'):
@@ -250,18 +248,18 @@ def train(self, data, train_steps):
             # --------------------------
             # Train the generator
             # --------------------------
-            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_)
+            step_g_loss_u, step_g_loss_s, step_g_loss_v = self.train_generator(X_, Z_, generator_opt)
 
             # --------------------------
             # Train the embedder
             # --------------------------
-            step_e_loss_t0 = self.train_embedder(X_)
+            step_e_loss_t0 = self.train_embedder(X_, embedder_opt)
 
             X_ = next(self.get_batch_data(data, n_windows=len(data)))
             Z_ = next(self.get_batch_noise())
             step_d_loss = self.discriminator_loss(X_, Z_)
             if step_d_loss > 0.15:
-                step_d_loss = self.train_discriminator(X_, Z_)
+                step_d_loss = self.train_discriminator(X_, Z_, discriminator_opt)
 
     def sample(self, n_samples):
         steps = n_samples // self.batch_size + 1
@@ -273,8 +271,6 @@ def sample(self, n_samples):
         return np.array(np.vstack(data))
 
-
-
 
 class Generator(Model):
     def __init__(self, hidden_dim, net_type='GRU'):
         self.hidden_dim = hidden_dim
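
Note on the optimizer refactor: the five Adam instances move out of define_gan and into train, so each tf.function-decorated step receives its optimizer as a call argument, and each training phase (embedding pretraining, supervised pretraining, joint training) gets its own fresh optimizer. The sketch below reproduces the pattern outside TimeGAN; the names train_step, pretrain_opt, and joint_opt are hypothetical, and the motivation in the comments (fresh Adam moment estimates per phase, slot variables created safely on each phase's first call) is our reading of the change, not the author's words.

    import tensorflow as tf

    mse = tf.keras.losses.MeanSquaredError()

    @tf.function
    def train_step(model, x, opt):
        # One gradient step; `opt` is whichever Adam instance the caller owns.
        with tf.GradientTape() as tape:
            loss = mse(x, model(x))
        grads = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    model = tf.keras.Sequential([tf.keras.layers.Dense(4, activation="relu"),
                                 tf.keras.layers.Dense(8)])
    x = tf.random.normal((16, 8))
    model(x)  # build the model's variables eagerly, before the first trace

    pretrain_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)  # phase 1
    for _ in range(10):
        train_step(model, x, pretrain_opt)

    joint_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)  # phase 2: clean moment estimates
    for _ in range(10):
        train_step(model, x, joint_opt)

Because non-tensor arguments such as the optimizer are part of tf.function's cache key, each optimizer gets its own concrete trace, so Adam can create its slot variables on the first call of each phase without hitting TensorFlow's "tried to create variables on a non-first call" error.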
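
A second note, on the new None-gradient guard in train_supervisor (again an assumption about intent): var_list concatenates supervisor and generator variables, but g_loss_s flows only through the embedder/supervisor path, so tape.gradient returns None for the generator variables, which apply_gradients rejects or warns about depending on the TensorFlow version. A minimal reproduction with throwaway variables a and b:

    import tensorflow as tf

    a = tf.Variable(2.0)  # contributes to the loss
    b = tf.Variable(3.0)  # never touched by the loss

    with tf.GradientTape() as tape:
        loss = a * a

    grads = tape.gradient(loss, [a, b])  # -> [tf.Tensor(4.0), None]

    # Same guard as the patch: drop (gradient, variable) pairs with no gradient.
    pairs = [(g, v) for g, v in zip(grads, [a, b]) if g is not None]

    opt = tf.keras.optimizers.Adam(learning_rate=0.01)
    opt.apply_gradients(pairs)  # updates only `a`

An alternative would be to shrink var_list to just the variables the loss reaches; the filter instead keeps the variable list as-is and applies updates only where gradients exist.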