From 7b8d43c5d5aa39acb0b798efade9f8fdec61b731 Mon Sep 17 00:00:00 2001
From: Duncan Riach
Date: Thu, 15 Aug 2024 11:54:23 -0700
Subject: [PATCH] ADLR/megatron-lm!1906 - Fix model instantiation for text gen
 server

---
 megatron/training/arguments.py             | 6 +++---
 pretrain_mamba.py                          | 4 +++-
 tools/run_mamba_text_generation_server.py  | 6 ++++--
 tools/run_text_generation_server.py        | 3 ++-
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index ec1d665215..b313b2d93e 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -190,10 +190,10 @@ def validate_args(args, defaults={}):
     # Checks.
     if args.rank == 0:
         print('using world size: {}, data-parallel size: {}, '
-              'context-parallel size: {} '
+              'context-parallel size: {}, '
               'tensor-model-parallel size: {}, '
-              'encoder-tensor-model-parallel size: {}'
-              'pipeline-model-parallel size: {} '
+              'encoder-tensor-model-parallel size: {}, '
+              'pipeline-model-parallel size: {}, '
               'encoder-pipeline-model-parallel size: {}'.format(
                   args.world_size, args.data_parallel_size,
                   args.context_parallel_size,
diff --git a/pretrain_mamba.py b/pretrain_mamba.py
index 9132ce2c62..f8202b6eac 100644
--- a/pretrain_mamba.py
+++ b/pretrain_mamba.py
@@ -75,7 +75,9 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel:
         fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
         parallel_output=True,
         share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
-        position_embedding_type=args.position_embedding_type
+        position_embedding_type=args.position_embedding_type,
+        rotary_percent=args.rotary_percent,
+        rotary_base=args.rotary_base
     )
 
     for l in range(model.decoder.num_layers_per_pipeline_rank):
diff --git a/tools/run_mamba_text_generation_server.py b/tools/run_mamba_text_generation_server.py
index 844d018055..2c7c6f44c2 100644
--- a/tools/run_mamba_text_generation_server.py
+++ b/tools/run_mamba_text_generation_server.py
@@ -63,9 +63,11 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel:
         hybrid_override_pattern=args.hybrid_override_pattern,
         post_process=post_process,
         fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
-        parallel_output=True,
+        parallel_output=False,
         share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
-        position_embedding_type=args.position_embedding_type
+        position_embedding_type=args.position_embedding_type,
+        rotary_percent=args.rotary_percent,
+        rotary_base=args.rotary_base
     )
 
     for l in range(model.decoder.num_layers_per_pipeline_rank):
diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py
index 9acc66e337..861d8d6d73 100644
--- a/tools/run_text_generation_server.py
+++ b/tools/run_text_generation_server.py
@@ -82,7 +82,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
             parallel_output=False,
             share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
             position_embedding_type=args.position_embedding_type,
-            rotary_percent=args.rotary_percent
+            rotary_percent=args.rotary_percent,
+            rotary_base=args.rotary_base
         )
 
     return model