@@ -34,6 +34,7 @@ NCCL_PARAM(SharpGroupSizeThresh, "SHARP_GROUP_SIZE_THRESH", 2);
34
34
NCCL_PARAM (SharpV3Datatypes , "SHARP_V3_DATATYPES" , 2 );
35
35
NCCL_PARAM (SharpDisableRS , "SHARP_DISABLE_REDUCE_SCATTER" , 0 );
36
36
NCCL_PARAM (SharpDisableAG , "SHARP_DISABLE_ALLGATHER" , 0 );
37
+ NCCL_PARAM (enableSharpTrace , "SHARP_COLL_TRACE" , 0 );
37
38
38
39
enum ncclSharpRequestType {
39
40
NCCL_SHARP_REQ_SHARP_COLL ,
@@ -500,6 +501,9 @@ ncclResult_t ncclSharpIallreduce(void* collComm, void* sendData, void* recvData,
500
501
reduce_spec .op = op_type ;
501
502
reduce_spec .aggr_mode = SHARP_AGGREGATION_NONE ;
502
503
504
+ if (ncclParamenableSharpTrace () && cComm -> rank == 0 )
505
+ INFO (NCCL_COLL , "Allreduce count:%d, op:%d dtype:%d " , count , op_type , sharp_type );
506
+
503
507
#if BLOCKING == 0
504
508
if (SHARP_COLL_SUCCESS != sharp_coll_do_allreduce_nb (cComm -> sharpCollComm , & reduce_spec , & req -> sharpRequest )) {
505
509
WARN ("SHARP allreduce failed\n" );
@@ -546,6 +550,10 @@ ncclResult_t ncclSharpIallgather(void* collComm, void* sendData, int nRecvParts,
546
550
gather_spec .size = recvParts [0 ].size ;
547
551
gather_spec .offset = windowOffset ;
548
552
553
+ if (ncclParamenableSharpTrace () && cComm -> rank == 0 )
554
+ INFO (NCCL_COLL , "Allgather size:%lu bytesPerRank:%lu windowOffset:%lu windowBytes:%lu" ,
555
+ recvParts [0 ].size , bytesPerRank , windowOffset , windowBytes );
556
+
549
557
#if BLOCKING == 0
550
558
if (SHARP_COLL_SUCCESS != sharp_coll_do_allgather_nb (cComm -> sharpCollComm , & gather_spec , & req -> sharpRequest )) {
551
559
WARN ("SHARP Allgather failed\n" );
@@ -611,6 +619,10 @@ ncclResult_t ncclSharpIreducescatter(void* collComm, int nSendParts, ncclNetSGE_
611
619
reduce_spec .op = op_type ;
612
620
reduce_spec .aggr_mode = SHARP_AGGREGATION_NONE ;
613
621
622
+ if (ncclParamenableSharpTrace () && cComm -> rank == 0 )
623
+ INFO (NCCL_COLL , "ReduceScatter bytesPerRank:%lu windowOffset:%lu windowBytes:%lu op_type:%d dtype:%d" ,
624
+ bytesPerRank , windowOffset , windowBytes , op_type , sharp_type );
625
+
614
626
#if BLOCKING == 0
615
627
if (SHARP_COLL_SUCCESS != sharp_coll_do_reduce_scatter_nb (cComm -> sharpCollComm , & reduce_spec , & req -> sharpRequest )) {
616
628
WARN ("SHARP reduce_scatter failed\n" );
0 commit comments