forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
arguments.py
1826 lines (1653 loc) · 101 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron arguments."""
import argparse
import dataclasses
import json
import logging
import os
import torch
import types
import torch.nn.functional as F
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.models.retro.utils import (
get_config_path as get_retro_config_path,
get_gpt_data_dir as get_retro_data_dir,
)
from megatron.core.transformer import TransformerConfig
from megatron.training.activations import squared_relu
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Build the Megatron-LM argument parser and parse the command line.

    Args:
        extra_args_provider: Optional callable that receives the parser and
            returns a parser; used to register custom arguments.
        ignore_unknown_args: When true, ``parse_known_args`` is used and the
            unrecognised options are discarded; otherwise ``parse_args`` is
            used.

    Returns:
        The parsed arguments with ``rank`` and ``world_size`` filled in from
        the ``RANK`` / ``WORLD_SIZE`` environment variables (defaulting to 0
        and 1). If ``args.yaml_cfg`` is not None, the object returned by
        ``load_yaml`` replaces the command-line namespace.
    """
    parser = argparse.ArgumentParser(allow_abbrev=False,
                                     description='Megatron-LM Arguments')

    # Standard arguments, registered in this fixed order.
    standard_groups = (
        _add_network_size_args,
        _add_regularization_args,
        _add_training_args,
        _add_initialization_args,
        _add_learning_rate_args,
        _add_checkpointing_args,
        _add_mixed_precision_args,
        _add_distributed_args,
        _add_validation_args,
        _add_data_args,
        _add_autoresume_args,
        _add_biencoder_args,
        _add_vision_args,
        _add_moe_args,
        _add_logging_args,
        _add_straggler_detector_args,
        _add_inference_args,
        _add_transformer_engine_args,
        _add_retro_args,
        _add_experimental_args,
        _add_one_logger_args,
        _add_config_logger_args,
    )
    for register in standard_groups:
        parser = register(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args = parser.parse_known_args()[0]
    else:
        args = parser.parse_args()

    # Experimental yaml: the yaml file replaces the parsed namespace.
    if args.yaml_cfg is not None:
        from .yaml_arguments import load_yaml
        assert args.yaml_cfg and not args.use_legacy_models, \
            "Yaml config is not supported with legacy models."
        args = load_yaml(args.yaml_cfg)

    # Args from environment.
    environ = os.environ
    args.rank = int(environ.get('RANK', '0'))
    args.world_size = int(environ.get("WORLD_SIZE", '1'))

    return args
def load_retro_config(retro_project_dir):
    """Load Retro's config.json.

    Args:
        retro_project_dir: Retro project directory; the config file location
            is obtained from ``get_retro_config_path``.

    Returns:
        A ``types.SimpleNamespace`` whose attributes are the top-level keys of
        the JSON document.

    Raises:
        AssertionError: If the config file does not exist.
    """
    # Retro config path.
    config_file = get_retro_config_path(retro_project_dir)
    assert os.path.exists(config_file), \
        "Retro project dir missing config.json."

    # Load retro config.
    with open(config_file) as handle:
        raw_config = json.load(handle)

    return types.SimpleNamespace(**raw_config)
def load_retro_args(args):
    """Override data-related args with the values saved in the Retro config.

    When ``args.retro_project_dir`` is set, the data arguments are replaced by
    the values stored in the project's config.json, so that pretraining uses
    the same data settings as the Retro preprocessing pipeline (see
    `tools/retro/preprocess_data.py`). ``args`` is modified in place; nothing
    is returned. If no project directory is given, the function does nothing.
    """
    project_dir = args.retro_project_dir

    # Nothing to do without a project directory.
    if project_dir is None:
        return

    cfg = load_retro_config(project_dir)

    def _under_project(relative):
        # Optional file names are resolved against the project directory.
        if relative is None:
            return None
        return os.path.join(project_dir, relative)

    # Retro data path is relative to project dir (via hard or soft links).
    data_root = get_retro_data_dir(project_dir)
    blend = list(cfg.retro_gpt_data_path)
    if len(blend) % 2:
        # Odd length is only accepted for a single entry.
        assert len(blend) == 1
        blend[0] = os.path.join(data_root, blend[0])
    else:
        # Even length: prefix every second entry, starting from the last one.
        for idx in range(len(blend) - 1, -1, -2):
            blend[idx] = os.path.join(data_root, blend[idx])

    # Update args.
    args.data_cache_path = cfg.retro_gpt_data_cache_path
    if args.data_path is None:
        args.data_path = blend
    args.eval_interval = cfg.retro_gpt_eval_interval
    args.eval_iters = cfg.retro_gpt_eval_iters
    args.global_batch_size = cfg.retro_gpt_global_batch_size
    args.max_position_embeddings = cfg.retro_gpt_seq_length
    args.merge_file = _under_project(cfg.retro_gpt_merge_file)
    args.seed = cfg.retro_gpt_seed
    args.seq_length = cfg.retro_gpt_seq_length
    args.tokenizer_model = _under_project(cfg.retro_gpt_tokenizer_model)
    args.tokenizer_type = cfg.retro_gpt_tokenizer_type
    args.train_samples = cfg.retro_gpt_train_samples
    args.vocab_file = _under_project(cfg.retro_gpt_vocab_file)

    # Retro-specific args.
    args.retro_block_size = cfg.retro_block_size
    args.retro_chunk_length = cfg.retro_gpt_chunk_length
    args.retro_neighbor_dirs = cfg.retro_neighbor_dirs
    args.retro_split_preprocessing = cfg.retro_gpt_split
    args.retro_bert_tokenizer_type = cfg.retro_bert_tokenizer_type
    args.retro_bert_vocab_file = cfg.retro_bert_vocab_file
def validate_args(args, defaults={}):
    """Validate parsed arguments and derive dependent settings in place.

    Runs, in order: Retro overrides (``load_retro_args``), parallel-size
    clamping and divisibility checks, rejection of deprecated flags,
    application of ``defaults``, batch-size / schedule / dtype /
    sequence-length consistency checks, and a series of feature-compatibility
    checks. The deprecated attributes ``batch_size``, ``warmup``,
    ``model_parallel_size``, ``checkpoint_activations`` and
    ``recompute_activations`` are deleted from ``args``. Finally the
    arguments are printed through ``_print_args``.

    Args:
        args: Namespace-like object carrying the parsed arguments plus
            ``rank`` and ``world_size``. It is mutated in place.
        defaults: Mapping of attribute name to value. A value is applied only
            when the attribute on ``args`` is missing or ``None``; otherwise
            rank 0 prints a warning and the existing value is kept. The
            mapping is only read in this function.

    Returns:
        The same ``args`` object.

    Raises:
        AssertionError: When any of the consistency checks fails.
        RuntimeError: For sequence parallelism or async all-reduce without
            ``CUDA_DEVICE_MAX_CONNECTIONS=1``, for unsupported rotary /
            position-embedding combinations, and for ``use_dist_ckpt``
            together with legacy models.

    Note:
        ``exit()`` is called when ``args.checkpoint_activations`` is set.
    """
    # NOTE(review): `defaults={}` is a mutable default argument; it is only
    # iterated and indexed below, never modified.

    # Temporary
    assert args.non_persistent_ckpt_type in ['global', None], \
        'Currently only global checkpoints are supported'

    # Load saved args from Retro (if applicable).
    load_retro_args(args)

    # Tensor model parallel size (clamped to the world size).
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)

    # Pipeline model parallel size (clamped to world size // tensor size).
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # One pipeline stage is subtracted when a standalone embedding stage is used.
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embedding_stage else
        args.pipeline_model_parallel_size
    )

    # Checks.
    # (encoder pipeline size + pipeline size) * tensor size; the world size
    # must be divisible by this times the context-parallel size.
    model_parallel_size = (args.encoder_pipeline_model_parallel_size + args.pipeline_model_parallel_size) * \
        args.tensor_model_parallel_size
    assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \
        'world size ({}) is not divisible by tensor parallel size ({}) times ' \
        'pipeline parallel size (encoder+decoder) ({}+{}) times context parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size,
            args.encoder_pipeline_model_parallel_size, args.pipeline_model_parallel_size, args.context_parallel_size)
    args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size)
    if args.rank == 0:
        print('using world size: {}, data-parallel size: {}, '
              'context-parallel size: {} '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.context_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    # backwards compatibility.
    # NOTE(review): this changes the encoder / decoder pipeline sizes after
    # data_parallel_size and transformer_pipeline_model_parallel_size were
    # derived above -- confirm that ordering is intended.
    if args.pipeline_model_parallel_split_rank is not None:
        args.encoder_pipeline_model_parallel_size = args.pipeline_model_parallel_split_rank
        args.pipeline_model_parallel_size -= args.encoder_pipeline_model_parallel_size
        assert args.pipeline_model_parallel_size > 0

    if args.tp_comm_overlap:
        assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled'

    # Deprecated arguments: each must be unset, and is then removed from args.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    if args.checkpoint_activations:
        # Only rank 0 prints the message; every rank exits.
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, use --recompute-activations, '
                  'or, for more control, --recompute-granularity and --recompute-method.')
        exit()
    del args.checkpoint_activations

    # --recompute-activations is shorthand for selective recompute granularity.
    if args.recompute_activations:
        args.recompute_granularity = 'selective'
    del args.recompute_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key, None) is not None:
            if args.rank == 0:
                # The message literal continues on the next line through a
                # backslash; the continuation starts at column 0 so that no
                # indentation ends up inside the string.
                print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
                        v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    if args.data_path is not None and args.split is None:
        legacy_default_split_value = '969, 30, 1'
        if args.rank == 0:
            print('WARNING: Please specify --split when using --data-path. Using legacy default value '
                  f'of "{legacy_default_split_value}"')
        args.split = legacy_default_split_value

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    # Interleaved schedule is selected by --num-layers-per-virtual-pipeline-stage.
    if args.num_layers_per_virtual_pipeline_stage is not None:
        if args.overlap_p2p_comm:
            assert args.pipeline_model_parallel_size > 1, \
                'when interleaved schedule is used, pipeline-model-parallel size '\
                'should be greater than 1'
        else:
            assert args.pipeline_model_parallel_size > 2, \
                'when interleaved schedule is used and p2p communication overlap is disabled, '\
                'pipeline-model-parallel size should be greater than 2 to avoid having multiple '\
                'p2p sends and recvs between same 2 ranks per communication batch'
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
            'number of layers should be divisible by the pipeline parallel size'
        num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
        assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage'
        args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None
        # Overlap P2P communication is disabled if not using the interleaved schedule.
        args.overlap_p2p_comm = False
        if args.rank == 0:
            print('WARNING: Setting args.overlap_p2p_comm to False since non-interleaved '
                  'schedule does not support overlapping p2p communication')

    if args.overlap_param_gather:
        assert args.use_distributed_optimizer, \
            '--overlap-param-gather only supported with distributed optimizer'
        assert args.overlap_grad_reduce, \
            '--overlap-grad-reduce should be turned on when using --overlap-param-gather'
        assert not args.use_legacy_models, \
            '--overlap-param-gather only supported with MCore models'

    # Parameters dtype: float by default, half for fp16, bfloat16 for bf16.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
        # Turn off checking for NaNs in loss and grads if using dynamic loss scaling,
        # where NaNs in grads / loss are signal to the loss scaler.
        if not args.loss_scale:
            args.check_for_nan_in_loss_and_grad = False
            if args.rank == 0:
                print('WARNING: Setting args.check_for_nan_in_loss_and_grad to False since '
                      'dynamic loss scaling is being used')
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # data
    assert args.num_dataset_builder_threads > 0

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.skipped_train_samples = 0
    args.consumed_valid_samples = 0

    # Support for variable sequence lengths across batches/microbatches.
    # set it if the dataloader supports generation of variable sequence lengths
    # across batches/microbatches. Due to additional communication overhead
    # during pipeline parallelism, it should not be set if sequence length
    # is constant during training.
    args.variable_seq_lengths = False

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learnig rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # num_layers and encoder_num_layers are aliases: exactly one may be given,
    # and the other is set to the same value.
    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'cannot have both num-layers and encoder-num-layers specified'
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, \
            'either num-layers or encoder-num-layers should be specified'
        args.num_layers = args.encoder_num_layers

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        if args.swiglu:
            # reduce the dimension for MLP since projections happens on
            # two linear layers. this keeps the number of parameters in
            # the same ballpark as the counterpart with 4*h size
            # we keep it a multiple of 64, which means the actual tensor size
            # will be a multiple of 64 / tp_size
            args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
        else:
            args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    if args.seq_length is not None and args.context_parallel_size > 1:
        assert args.seq_length % (args.context_parallel_size * 2) == 0, \
            'seq-length should be a multiple of 2 * context-parallel-size ' \
            'if context-parallel-size > 1.'

    # seq_length and encoder_seq_length are aliases, handled like num_layers above.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

    if args.moe_grouped_gemm:
        assert args.bf16, 'Currently GroupedGEMM for MoE only supports bf16 dtype.'
        # Queries the CUDA device; the major compute capability must be >= 8.
        dc = torch.cuda.get_device_capability()
        assert dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels."

    if args.weight_decay_incr_style == 'constant':
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None

    # Only the first two dot-separated components of the version are parsed.
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model ' \
            'parallel groups'
        assert args.recompute_granularity == 'full', \
            'distributed recompute activations is only '\
            'application to full recompute granularity'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you '\
            'need to use a recompute method '
        assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

    if args.recompute_granularity == 'selective':
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'

    # disable sequence parallelism when tp=1
    # to avoid change in numerics when
    # sequence_parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False

    # Both features below require CUDA_DEVICE_MAX_CONNECTIONS to be exactly "1".
    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
        if args.sequence_parallel:
            raise RuntimeError(
                "Using sequence parallelism requires setting the environment variable "
                "CUDA_DEVICE_MAX_CONNECTIONS to 1")
        if args.async_tensor_model_parallel_allreduce:
            raise RuntimeError(
                "Using async gradient all reduce requires setting the environment "
                "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

    # Disable bias gelu fusion if we are disabling bias altogether
    if not args.add_bias_linear:
        args.bias_gelu_fusion = False

    # Retro checks.
    if args.retro_add_retriever:

        # Train samples should be auto-loaded.
        assert args.train_samples is not None, \
            "args.train_samples should be auto-loaded from the retro config."

        # Sequence parallelism unsupported.
        assert not args.sequence_parallel, \
            "retro currently does not support sequence parallelism."

        # Pipeline parallelism unsupported.
        assert args.pipeline_model_parallel_size == 1, \
            "retro currently does not support pipeline parallelism."

    if args.decoupled_lr is not None or args.decoupled_min_lr is not None:
        assert not args.use_legacy_models, \
            '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
        assert not args.use_dist_ckpt, "Distributed checkpointing does not work with decoupled LR yet."

    # Legacy RoPE arguments
    if args.use_rotary_position_embeddings:
        args.position_embedding_type = 'rope'
    if args.rotary_interleaved and args.apply_rope_fusion:
        raise RuntimeError('--rotary-interleaved does not work with rope_fusion.')
    if args.rotary_interleaved and args.use_legacy_models:
        raise RuntimeError('--rotary-interleaved is not supported in legacy models.')

    # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now
    # don't allow it to keep things simple
    if not args.add_position_embedding and args.position_embedding_type != 'rope':
        raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')

    # MoE Spec check (0 experts is normalised to None first).
    if args.num_experts == 0:
        args.num_experts = None
    if args.num_experts is not None:
        assert args.spec is None, "Model Spec must be None when using MoEs"

    # Context parallel
    if args.context_parallel_size > 1:
        assert not args.use_legacy_models, "Context parallelism is not supported in legacy models."

    # Expert parallelism check
    if args.expert_model_parallel_size > 1:
        assert args.num_experts is not None, "num_experts must be non None to use expert model parallelism"
        assert args.num_experts % args.expert_model_parallel_size == 0, \
            "Number of experts should be a multiple of expert model parallel_size."
        assert not args.fp16, \
            "Expert parallelism is not supported with fp16 training."

    # Distributed checkpointing checks
    if args.use_dist_ckpt and args.use_legacy_models:
        raise RuntimeError('--use-dist-ckpt is not supported in legacy models.')

    # Data blend checks: the three truth values are summed, so at most one of
    # mock data, data_path, or the per-split data paths may be set.
    assert args.mock_data + \
        bool(args.data_path) + \
        any([args.train_data_path, args.valid_data_path, args.test_data_path]) \
        <= 1, "A single data source must be provided in training mode, else None"

    if args.use_tp_pp_dp_mapping:
        assert args.context_parallel_size * args.expert_model_parallel_size <= 1, \
            "context_parallel and expert_model_parallel can't be used with tp-pp-dp mapping."

    # Deterministic mode
    if args.deterministic_mode:
        assert not args.use_flash_attn, "Flash attention can not be used in deterministic mode."
        assert not args.cross_entropy_loss_fusion, "Cross Entropy Fusion is currently not deterministic."

        # NCCL_ALGO must be set in the environment and be one of these values.
        all_reduce_choices = ["Tree", "Ring", "CollnetDirect", "CollnetChain", "^NVLS"]
        assert os.getenv("NCCL_ALGO", -1) != -1 and os.getenv("NCCL_ALGO") in all_reduce_choices, \
            f"NCCL_ALGO must be one of {all_reduce_choices}."

        torch.use_deterministic_algorithms(True)

    # Update the printed args to reflect that `apply_query_key_layer_scaling` also controls `attention_softmax_in_fp32`
    if args.apply_query_key_layer_scaling:
        args.attention_softmax_in_fp32 = True

    # Checkpointing
    if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
        print('--ckpt-fully-parallel-save flag is deprecated and has no effect.'
              ' Use --no-ckpt-fully-parallel-save to disable parallel save.')
    if (
        args.use_dist_ckpt
        and not args.ckpt_fully_parallel_save
        and args.use_distributed_optimizer
        and args.rank == 0
    ):
        print('Warning: With non-parallel ckpt save and DistributedOptimizer,'
              ' it will be impossible to resume training with different parallelism.'
              ' Consider removing flag --no-ckpt-fully-parallel-save.')

    # Print arguments.
    _print_args("arguments", args)

    return args
def _print_args(title, args):
"""Print arguments."""
if args.rank == 0:
print(f'------------------------ {title} ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print(f'-------------------- end of {title} ---------------------',
flush=True)
def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def core_transformer_config_from_args(args, config_class=None):
    """Build a transformer config object from parsed arguments.

    Every dataclass field of ``config_class`` that has a same-named attribute
    on ``args`` is copied over; a fixed set of keys is then derived from
    differently named arguments and overrides any copied value.

    Args:
        args: Parsed arguments (attribute access only).
        config_class: Dataclass to instantiate; ``TransformerConfig`` is used
            when this is None (or otherwise falsy).

    Returns:
        ``config_class(**kwargs)`` built from the collected values.

    Raises:
        AssertionError: If both ``args.squared_relu`` and ``args.swiglu`` are
            set.
    """
    # Config class.
    if not config_class:
        config_class = TransformerConfig

    # Translate args to core transformer configuration.
    config_kwargs = {
        field.name: getattr(args, field.name)
        for field in dataclasses.fields(config_class)
        if hasattr(args, field.name)
    }
    config_kwargs.update(
        persist_layer_norm=not args.no_persist_layer_norm,
        layernorm_zero_centered_gamma=args.apply_layernorm_1p,
        layernorm_epsilon=args.norm_epsilon,
        deallocate_pipeline_outputs=True,
        pipeline_dtype=args.params_dtype,
        batch_p2p_comm=not args.overlap_p2p_comm,
        num_moe_experts=args.num_experts,
        rotary_interleaved=args.rotary_interleaved,
    )

    # Activation function and the matching bias fusion flag.
    if not args.swiglu:
        config_kwargs['bias_activation_fusion'] = args.bias_gelu_fusion
    else:
        config_kwargs['activation_func'] = F.silu
        config_kwargs['gated_linear_unit'] = True
        config_kwargs['bias_activation_fusion'] = args.bias_swiglu_fusion
    if args.squared_relu:
        assert not args.swiglu
        config_kwargs['activation_func'] = squared_relu

    # Xavier-uniform is used for both the plain and the scaled init method.
    if args.init_method_xavier_uniform:
        xavier = torch.nn.init.xavier_uniform_
        config_kwargs['init_method'] = xavier
        config_kwargs['scaled_init_method'] = xavier

    # Query groups are only forwarded when group-query attention is on.
    config_kwargs['num_query_groups'] = (
        args.num_query_groups if args.group_query_attention else None
    )
    config_kwargs['config_logger_dir'] = args.config_logger_dir

    # Return config.
    return config_class(**config_kwargs)
def _add_transformer_engine_args(parser):
group = parser.add_argument_group(title='Transformer-Engine')
group.add_argument('--fp8-format', default=None,
choices=['e4m3', 'hybrid'],
help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass',
dest='fp8')
group.add_argument('--fp8-margin', type=int, default=0,
help='Scaling margin for fp8',
dest='fp8_margin')
group.add_argument('--fp8-interval', type=int, default=1,
help='Scaling update interval for fp8',
dest='fp8_interval')
group.add_argument('--fp8-amax-history-len', type=int, default=1,
help='Number of steps for which amax history is recorded per tensor',
dest='fp8_amax_history_len')
group.add_argument('--fp8-amax-compute-algo', default='most_recent',
choices=['most_recent', 'max'],
help='Algorithm for computing amax from history',
dest='fp8_amax_compute_algo')
group.add_argument('--no-fp8-wgrad', action='store_false',
help='Execute wgrad in higher precision even for FP8 runs',
dest='fp8_wgrad')
group.add_argument('--transformer-impl', default='transformer_engine',
choices=['local', 'transformer_engine'],
help='Which Transformer implementation to use.')
return parser
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
group.add_argument('--max-tokens-to-oom',
type=int, default=12000,
help='Maximum number of tokens during inference'
'tokens here is # in prompt + # to generate'
'Allows us to throw an error before OOM crashes server')
group.add_argument('--output-bert-embeddings', action='store_true',
help='Output Bert embeddings (via mean pooling) from '
'model, rather than its binary head output or entire '
'hidden batch.')
group.add_argument('--bert-embedder-type', default="megatron",
choices=["megatron", "huggingface"],
help='Select either Megatron or Huggingface as the '
'Bert embedder.')
return parser
def _add_retro_args(parser):
group = parser.add_argument_group(title='retro')
group.add_argument('--retro-project-dir', default=None,
help='Retro project directory, which contains the '
'preprocessed data for pretraining. This directory '
'is built during preprocessing (see '
'tools/retro/README.md), and contains subdirectories '
'for the chunk database and pretraining neighbors.')
group.add_argument('--retro-add-retriever',
action='store_true', default=False,
help='Add a retriever to the transformer, for use in '
'pretraining a Retro model.')
group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
help='Set number of training iterations for cyclic '
'Retro training.')
group.add_argument('--retro-encoder-layers', type=int, default=2,
help='Number of layers to use for the retrieval '
'encoder.')
group.add_argument('--retro-encoder-hidden-dropout',
type=float, default=0.1, help='Hidden dropout for '
'retrieval encoder.')
group.add_argument('--retro-encoder-attention-dropout',
type=float, default=0.1, help='Attention dropout for '
'retrieval encoder.')
group.add_argument("--retro-num-neighbors", type=int, default=2,
help='Number of neighbors to retrieve during '
'pretraining.')
group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
help='Number of chunks to retrieve from the retrieval '
'database.')
group.add_argument("--retro-attention-gate", type=float, default=1,
help="Gated cross attention.")
group.add_argument("--retro-no-verify-neighbor-count", action="store_false",
dest="retro_verify_neighbor_count",
help="Skip verifying that len(GPT dataset) == len(saved "
"neighbors).")
# Enforce argument naming convention.
for action in group._group_actions:
prefix = action.dest.split("_")[0]
assert prefix == "retro", \
"Retro args must be prefixed with '--retro-*', for consistent " \
"styling. Please fix '%s'." % ", ".join(action.option_strings)
return parser
def _add_network_size_args(parser):
    """Add the 'network size' argument group to ``parser``.

    Registers the options describing the transformer architecture: layer
    counts, hidden sizes, attention heads, position embeddings,
    normalization and activation choices.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object.
    """

    def _onnx_safe_value(text):
        """Convert the raw ``--onnx-safe`` command-line string to a bool."""
        # ``type=bool`` calls ``bool(text)``, which is True for every
        # non-empty string, so '--onnx-safe False' used to turn the option
        # ON. Map the usual "false" spellings to False; any other value keeps
        # its previous meaning (empty string -> False, otherwise True).
        return text.lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')

    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--encoder-num-layers', type=int, default=None,
                       help='Number of encoder transformer layers.')
    group.add_argument('--decoder-num-layers', type=int, default=None,
                       help='Number of decoder transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       ' args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--group-query-attention', action='store_true',
                       help='Use group-query attention.')
    group.add_argument('--num-query-groups', type=int, default=1)
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--position-embedding-type', type=str,
                       default='learned_absolute',
                       choices=['learned_absolute', 'rope', 'none'],
                       help='Position embedding type.')
    group.add_argument('--use-rotary-position-embeddings', action='store_true',
                       help='Use rotary positional embeddings or not. '
                       'Deprecated: use --position-embedding-type')
    group.add_argument('--rotary-base', type=int, default=10000,
                       help='Base to use for rotary positional embeddings, '
                       'default 10000')
    # '%%' is required: argparse %-formats help strings before printing.
    group.add_argument('--rotary-percent', type=float, default=1.0,
                       help='Percent of rotary dimension to use, default 100%%')
    group.add_argument('--rotary-interleaved', action='store_true',
                       help='Use interleaved rotary embedding.')
    group.add_argument('--rotary-seq-len-interpolation-factor', type=int,
                       default=None,
                       help='Sequence length interpolation factor for rotary '
                       'embeddings.')
    group.add_argument('--no-position-embedding',
                       action='store_false',
                       help='Disable position embedding. Deprecated: use '
                       '--position-embedding-type',
                       dest='add_position_embedding')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--normalization', default='LayerNorm',
                       choices=['LayerNorm', 'RMSNorm'],
                       help='Which normalization technique to use.')
    group.add_argument('--norm-epsilon', type=float, default=1e-5,
                       help='Epsilon for layer norm and RMS norm.')
    group.add_argument('--apply-layernorm-1p', action='store_true',
                       help='Adjust LayerNorm weights such that they are '
                       'centered around zero. This improves numerical '
                       'stability.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAI\'s GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
    group.add_argument('--squared-relu', action='store_true',
                       help='Use squared relu activation instead of default '
                       'gelu')
    group.add_argument('--swiglu', action='store_true',
                       help='Use gated linear units and SiLU activation '
                       'instead of default gelu')
    group.add_argument('--onnx-safe', type=_onnx_safe_value, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--untie-embeddings-and-output-weights',
                       action='store_true',
                       help='Untie embeddings and output weights.')

    return parser
def _add_straggler_detector_args(parser):
    """Add the 'straggler' argument group to ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object.
    """
    straggler = parser.add_argument_group(title='straggler')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--log-straggler',
         dict(action='store_true',
              help='If set, tracks and logs straggler per GPU.')),
        ('--disable-straggler-on-startup',
         dict(action='store_true',
              help='If set, StragglerDetector is disabled on startup.')),
        ('--straggler-ctrlr-port',
         dict(type=int, default=65535,
              help='Port number to toggle StragglerDetector on/off at runtime')),
        ('--straggler-minmax-count',
         dict(type=int, default=1,
              help='Number of ranks to report with high/low estimated throughput')),
    )
    for flag, settings in option_table:
        straggler.add_argument(flag, **settings)

    return parser
def _add_one_logger_args(parser):
    """Add the 'one logger' argument group to ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object.
    """
    group = parser.add_argument_group(title='one logger')

    # Adjacent string literals are concatenated verbatim, so every fragment
    # that is followed by another one must end with an explicit space.
    group.add_argument('--no-one-logger', action='store_false',
                       help='If set, disable using one_logger to track E2E '
                       'metrics. Note that one_logger is an internal tool and '
                       'not available externally. For installation, please go '
                       'to https://confluence.nvidia.com/display/MLWFO/Package+Repositories '
                       'for more details',
                       dest='enable_one_logger')
    group.add_argument('--one-logger-project', type=str, default='megatron-lm',
                       help='The one-logger project name. Will ignore if '
                       '--no-one-logger is set')
    group.add_argument('--one-logger-run-name', type=str, default=None,
                       help='The one-logger run name displayed. Will ignore if '
                       '--no-one-logger is set')
    group.add_argument('--one-logger-async', action='store_true',
                       help='If set, forces one_logger to use async mode.')
    group.add_argument('--app-tag-run-name', type=str, default=None,
                       help='Jobs belonging to the same training run are '
                       'supposed to have the same name. It will be used to '
                       'track progress of a training done over multiple '
                       'different jobs')
    group.add_argument('--app-tag-run-version', type=str, default='0.0.0',
                       help='The version of the training of which current job is '
                       'part of. It will be used to track the changes in the '
                       'application side which might change the performance '
                       'baseline')

    return parser
def _add_config_logger_args(parser):
    """Add the 'config logger' argument group to ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object.
    """
    config_logger = parser.add_argument_group(title='config logger')
    dump_dir_settings = dict(
        dest='config_logger_dir',
        type=str,
        default='',
        help='If set, will dump all configs to --config-logger-dir',
    )
    config_logger.add_argument('--config-logger-dir', **dump_dir_settings)
    return parser
def _add_logging_args(parser):
    """Add the 'logging' argument group to ``parser``.

    Registers the options controlling what is measured and logged, timer
    granularity, and tensorboard / wandb reporting.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object.
    """
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--log-throughput', action='store_true',
                       help='If set, calculate and log throughput per GPU.')
    group.add_argument('--log-progress', action='store_true',
                       help='If set, log progress (in terms of number of processed tokens and '
                       'number of floating-point operations) to progress.txt file in checkpoint '
                       'directory.')
    group.add_argument('--timing-log-level', type=int,
                       default=0, choices=range(0, 3),
                       help='Granularity level to measure and report timing. '
                       ' 0: report only iteration time and make sure timing '
                       ' does not introduce extra overhead.'
                       ' 1: report timing for operations that are executed '
                       ' very limited times (basically once) during '
                       ' each iteration (such as gradient all-reduce) '
                       ' 2: report timing for operations that might be '
                       ' executed numerous times during each iteration. '
                       'Note that setting the level to 1 or 2 might '
                       'cause increase in iteration time.')
    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                       help='If not set, use barrier with level 1 time '
                       'measurements. Note that this is up to the user '
                       'to make sure calling barrier with their timers '
                       'will not result in hangs. This can happen if for '
                       'example the user adds a level 1 timer that is not '
                       'called by all ranks.',
                       dest='barrier_with_L1_time')
    group.add_argument('--timing-log-option', type=str, default='minmax',
                       choices=['max', 'minmax', 'all'],
                       help='Options for logging timing:'
                       ' max: report the max timing across all ranks'
                       ' minmax: report min and max timings across all ranks'
                       ' all: report timings of all ranks.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    # Plain ASCII quotes here: typographic quotes in help text can make
    # printing --help fail on output streams that only encode ASCII.
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the \'add\' calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')
    group.add_argument('--wandb-project', type=str, default='',
                       help='The wandb project name. Ignore wandb by default.')
    group.add_argument('--wandb-exp-name', type=str, default='',
                       help='The wandb experiment name.')
    group.add_argument('--wandb-save-dir', type=str, default='',
                       help='Path to save the wandb results locally.')
    group.add_argument('--logging-level', type=int, default=None,
                       help='Set default logging level')

    return parser
def _add_regularization_args(parser):
    """Add the 'regularization' argument group to ``parser``.

    Registers dropout, weight-decay, gradient-clipping and optimizer
    coefficient options.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object.
    """
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    # No explicit default for the next two: argparse leaves them as None
    # when the flags are not given.
    group.add_argument('--start-weight-decay', type=float,
                       help='Initial weight decay coefficient for L2 regularization.')
    group.add_argument('--end-weight-decay', type=float,
                       help='End of run weight decay coefficient for L2 regularization.')
    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                       choices=['constant', 'linear', 'cosine'],
                       help='Weight decay increment function.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--batch-size', type=int, default=None,
help='Old batch size parameter, do not use. '
'Use --micro-batch-size instead')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '