Skip to content

Commit 20e1f1c

Browse files
clumsyazzhipa
andauthored
feat: print expert groups on megatron init (#13874)
Signed-off-by: Alexander Zhipa <[email protected]> Co-authored-by: Alexander Zhipa <[email protected]>
1 parent 0ee26ae commit 20e1f1c

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

nemo/lightning/megatron_init.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,16 +499,24 @@ def generator_wrapper(group_type, is_expert=False, **kwargs):
499499
# EP rank
500500
expert_model_parallel_rank = 0
501501
if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
502+
all_expert_model_parallel_ranks = []
502503
for ranks in generator_wrapper('ep', is_expert=True):
504+
all_expert_model_parallel_ranks.append(ranks)
503505
if rank in ranks:
504506
expert_model_parallel_rank = list(ranks).index(rank)
507+
logging.info(f'All expert model parallel group ranks: {all_expert_model_parallel_ranks}')
508+
logging.info(f'Rank {rank} has expert model parallel rank: {expert_model_parallel_rank}')
505509

506510
# ETP
507511
expert_tensor_parallel_rank = 0
508512
if expert_tensor_parallel_size_ is not None and expert_tensor_parallel_size_ > 1:
513+
all_expert_tensor_parallel_ranks = []
509514
for ranks in generator_wrapper('tp', is_expert=True):
515+
all_expert_tensor_parallel_ranks.append(ranks)
510516
if rank in ranks:
511517
expert_tensor_parallel_rank = list(ranks).index(rank)
518+
logging.info(f'All expert tensor parallel group ranks: {all_expert_tensor_parallel_ranks}')
519+
logging.info(f'Rank {rank} has expert tensor parallel rank: {expert_tensor_parallel_rank}')
512520

513521
# Build the pipeline model-parallel groups and embedding groups
514522
# (first and last rank in each pipeline model-parallel group).

0 commit comments

Comments
 (0)