Skip to content

Commit

Permalink
fix(core): use dataset length to get a number of epoch samples
Browse files Browse the repository at this point in the history
  • Loading branch information
mozharovsky committed Sep 17, 2021
1 parent 6d87808 commit 343634b
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions git_t5/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,9 @@ def configure_optimizers(
self,
) -> Tuple[optax.GradientTransformation, optax.Schedule]:
def training_steps() -> int:
total_steps = len(self.data_module.datasets["train"])
batch_size = self.data_module.config.data.train_batch_size
total_batch_size = batch_size * jax.device_count()
num_steps = len(self.data_module.datasets["train"])
num_epochs = self.config.trainer.max_epochs
num_train_steps = (total_steps // total_batch_size) * num_epochs
num_train_steps = num_steps * num_epochs
return num_train_steps

cfg = self.config.optimizer
Expand Down

0 comments on commit 343634b

Please sign in to comment.