Skip to content

Commit

Permalink
ADLR/megatron-lm!1795 - Added --train-sync-interval to optionally periodically synchronize with GPU during training
Browse files Browse the repository at this point in the history
  • Loading branch information
szmigacz authored and jaredcasper committed Aug 13, 2024
1 parent cf0f9b2 commit 3bd1f4e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,8 @@ def _add_training_args(parser):
group.add_argument('--calculate-per-token-loss', action='store_true',
help=('Scale cross entropy loss by the number of non-padded tokens in the '
'global batch, versus the default behavior of assuming all tokens are non-padded.'))
group.add_argument('--train-sync-interval', type=int, default=None,
help='Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.')

# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
Expand Down
9 changes: 8 additions & 1 deletion megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,7 @@ def get_e2e_base_metrics():
num_floating_point_operations_so_far += num_fp_ops
total_flops += num_fp_ops

# Fault tolerance
# Send heartbeat to FT package and update timeouts.
if args.enable_ft_package:
ft_client = ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.TRAIN_HEARTBEAT)
Expand All @@ -1190,6 +1190,13 @@ def get_e2e_base_metrics():
print_rank_0(f'Updated FT timeouts. New values: \
{ft_integration.get_rank_monitor_client().timeouts}')

# Bring CPU and GPU back in sync if on right iteration.
if (
args.train_sync_interval
and iteration % args.train_sync_interval == 0
):
torch.cuda.synchronize()

# Logging.
loss_scale = optimizer.get_loss_scale().item()
params_norm = None
Expand Down

0 comments on commit 3bd1f4e

Please sign in to comment.