Merge branch 'duncan/fix-mamba-text-gen' into 'main'
Fix model instantiation for text gen server

See merge request ADLR/megatron-lm!1906
jaredcasper committed Aug 15, 2024
2 parents e8f8e63 + 7b8d43c commit 2d487b1
Showing 4 changed files with 12 additions and 7 deletions.
6 changes: 3 additions & 3 deletions megatron/training/arguments.py
@@ -190,10 +190,10 @@ def validate_args(args, defaults={}):
     # Checks.
     if args.rank == 0:
         print('using world size: {}, data-parallel size: {}, '
-              'context-parallel size: {} '
+              'context-parallel size: {}, '
               'tensor-model-parallel size: {}, '
-              'encoder-tensor-model-parallel size: {}'
-              'pipeline-model-parallel size: {} '
+              'encoder-tensor-model-parallel size: {}, '
+              'pipeline-model-parallel size: {}, '
               'encoder-pipeline-model-parallel size: {}'.format(
                   args.world_size, args.data_parallel_size,
                   args.context_parallel_size,
4 changes: 3 additions & 1 deletion pretrain_mamba.py
@@ -75,7 +75,9 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel:
         fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
         parallel_output=True,
         share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
-        position_embedding_type=args.position_embedding_type
+        position_embedding_type=args.position_embedding_type,
+        rotary_percent=args.rotary_percent,
+        rotary_base=args.rotary_base
     )

     for l in range(model.decoder.num_layers_per_pipeline_rank):
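The rotary_percent and rotary_base arguments forwarded above matter because they determine the RoPE frequency schedule the checkpoint was trained with; instantiating the model without them silently falls back to defaults that may not match. A minimal, hypothetical illustration of that dependency (plain PyTorch, not Megatron code; the function name is invented for this sketch):

# Hypothetical sketch: how RoPE inverse frequencies depend on
# rotary_base and rotary_percent (standard RoPE formula, not Megatron code).
import torch

def rope_inv_freq(head_dim, rotary_percent=1.0, rotary_base=10000.0):
    # Only a fraction of each head's dimensions receive rotary embeddings.
    rot_dim = int(head_dim * rotary_percent)
    # inv_freq_i = 1 / base^(2i / rot_dim), the usual RoPE frequency schedule.
    return 1.0 / (rotary_base ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))

# Different bases give different frequencies, so a model built with the default
# base would not match a checkpoint trained with, say, base 500000.
print(rope_inv_freq(128, rotary_percent=0.5, rotary_base=10000.0)[:4])
print(rope_inv_freq(128, rotary_percent=0.5, rotary_base=500000.0)[:4])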
6 changes: 4 additions & 2 deletions tools/run_mamba_text_generation_server.py
@@ -63,9 +63,11 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel:
         hybrid_override_pattern=args.hybrid_override_pattern,
         post_process=post_process,
         fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
-        parallel_output=True,
+        parallel_output=False,
         share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
-        position_embedding_type=args.position_embedding_type
+        position_embedding_type=args.position_embedding_type,
+        rotary_percent=args.rotary_percent,
+        rotary_base=args.rotary_base
     )

     for l in range(model.decoder.num_layers_per_pipeline_rank):
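The parallel_output change above is the part of the fix specific to serving: with parallel_output=True each tensor-parallel rank keeps only its shard of the vocabulary logits (fine for training, where the loss is computed in parallel), whereas the generation server needs the full distribution on every rank before it can sample the next token. A rough, hypothetical sketch of the gather that parallel_output=False effectively implies (plain torch.distributed, not Megatron code; the function name is invented for illustration):

# Hypothetical sketch (not Megatron code): gathering vocab-parallel logits
# so every tensor-parallel rank sees the full vocabulary before sampling.
import torch
import torch.distributed as dist

def gather_full_logits(local_logits, tp_group=None):
    # local_logits: [batch, seq, vocab_size / tp_world_size] on each rank.
    world = dist.get_world_size(group=tp_group)
    shards = [torch.empty_like(local_logits) for _ in range(world)]
    dist.all_gather(shards, local_logits, group=tp_group)
    # Concatenate the vocabulary shards back into [batch, seq, vocab_size].
    return torch.cat(shards, dim=-1)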
3 changes: 2 additions & 1 deletion tools/run_text_generation_server.py
@@ -82,7 +82,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
         parallel_output=False,
         share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
         position_embedding_type=args.position_embedding_type,
-        rotary_percent=args.rotary_percent
+        rotary_percent=args.rotary_percent,
+        rotary_base=args.rotary_base
     )

     return model
