Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions modules/ui/TrainUI.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def __init__(self):
self.training_callbacks = None
self.training_commands = None

self.start_time = None
self.session_start_epoch = None
self.session_start_epoch_step = None

self.always_on_tensorboard_subprocess = None
self.current_workspace_dir = self.train_config.workspace_dir
self._check_start_always_on_tensorboard()
Expand Down Expand Up @@ -639,14 +643,24 @@ def open_tensorboard(self):
webbrowser.open("http://localhost:" + str(self.train_config.tensorboard_port), new=0, autoraise=False)

def _calculate_eta_string(self, train_progress: TrainProgress, max_step: int, max_epoch: int) -> str | None:
# Guard against None values before first progress callback
if self.start_time is None or self.session_start_epoch is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can never happen. the function is only called immediatly after setting it to non-None
if you want have a logical check, an assert() would do that

return "Estimating ..."

spent_total = time.monotonic() - self.start_time
steps_done = train_progress.epoch * max_step + train_progress.epoch_step

# calculate steps done in THIS SESSION only
current_total_steps = train_progress.epoch * max_step + train_progress.epoch_step
session_start_total_steps = self.session_start_epoch * max_step + self.session_start_epoch_step
steps_done_this_session = current_total_steps - session_start_total_steps

remaining_steps = (max_epoch - train_progress.epoch - 1) * max_step + (max_step - train_progress.epoch_step)
total_eta = spent_total / steps_done * remaining_steps

if train_progress.global_step <= 30:
if steps_done_this_session <= 30:
return "Estimating ..."

total_eta = spent_total / steps_done_this_session * remaining_steps

td = datetime.timedelta(seconds=total_eta)
days = td.days
hours, remainder = divmod(td.seconds, 3600)
Expand All @@ -671,6 +685,11 @@ def delete_eta_label(self):
self.eta_label.configure(text="")

def on_update_train_progress(self, train_progress: TrainProgress, max_step: int, max_epoch: int):
# capture session start on first progress update - hopefully works on cloud, multi and local.
if self.session_start_epoch is None:
self.session_start_epoch = train_progress.epoch
self.session_start_epoch_step = train_progress.epoch_step

self.set_step_progress(train_progress.epoch_step, max_step)
self.set_epoch_progress(train_progress.epoch, max_epoch)
self.set_eta_label(train_progress, max_step, max_epoch)
Expand Down Expand Up @@ -751,6 +770,9 @@ def __training_thread_function(self):
if self.train_config.cloud.enabled:
self.ui_state.get_var("secrets.cloud").update(self.train_config.secrets.cloud)

# Reset session tracking - actual values captured on first progress callback
self.session_start_epoch = None
self.session_start_epoch_step = None
self.start_time = time.monotonic()
trainer.train()
except Exception:
Expand Down