diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index ec1d665215..2cffdec31e 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1135,6 +1135,8 @@ def _add_training_args(parser):
     group.add_argument('--calculate-per-token-loss', action='store_true',
                        help=('Scale cross entropy loss by the number of non-padded tokens in the '
                              'global batch, versus the default behavior of assuming all tokens are non-padded.'))
+    group.add_argument('--train-sync-interval', type=int, default=None,
+                       help='Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.')
 
     # deprecated
     group.add_argument('--checkpoint-activations', action='store_true',
diff --git a/megatron/training/training.py b/megatron/training/training.py
index 2c04a603cc..75a5b0bff7 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1175,7 +1175,7 @@ def get_e2e_base_metrics():
         num_floating_point_operations_so_far += num_fp_ops
         total_flops += num_fp_ops
 
-        # Fault tolerance
+        # Send heartbeat to FT package and update timeouts.
         if args.enable_ft_package:
             ft_client = ft_integration.get_rank_monitor_client(
                 ft_integration.StateMachineActions.TRAIN_HEARTBEAT)
@@ -1190,6 +1190,13 @@ def get_e2e_base_metrics():
                 print_rank_0(f'Updated FT timeouts. New values: \
                     {ft_integration.get_rank_monitor_client().timeouts}')
 
+        # Bring CPU and GPU back in sync if on right iteration.
+        if (
+            args.train_sync_interval
+            and iteration % args.train_sync_interval == 0
+        ):
+            torch.cuda.synchronize()
+
         # Logging.
         loss_scale = optimizer.get_loss_scale().item()
         params_norm = None
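
For context (not part of the diff): CUDA kernel launches are asynchronous, so the host thread can queue many iterations of work ahead of the device. The sketch below shows the same idea in a generic PyTorch training loop rather than Megatron's `train()` function; `train_sync_interval` mirrors the new flag, while `model`, `optimizer`, and `data_loader` are placeholders.

```python
import torch
import torch.nn.functional as F


def train(model, optimizer, data_loader, train_sync_interval=None):
    """Toy training loop with an optional periodic CPU-GPU sync."""
    model.train()
    for iteration, (inputs, targets) in enumerate(data_loader, start=1):
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Kernel launches are asynchronous, so the CPU can enqueue many
        # iterations before the GPU finishes them. A periodic synchronize()
        # blocks the host until all queued work completes, bounding that lead.
        if train_sync_interval and iteration % train_sync_interval == 0:
            torch.cuda.synchronize()
```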