diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index e305c77ea..7227916e6 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -4614,6 +4614,9 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm; nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle; nccl_net_ofi_rdma_req_t *req = NULL; + uint16_t msg_seq_num = s_comm->next_msg_seq_num; + bool polled_cq = false; + bool have_ctrl = false; assert(s_comm != NULL); @@ -4644,15 +4647,12 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t * TODO: Use NCCL provided tags when using grouped receives aka * props->maxRecvs > 1. */ - - bool have_ctrl = false; - uint16_t msg_seq_num = s_comm->next_msg_seq_num; - void *elem; nccl_ofi_msgbuff_elemtype_t type; nccl_ofi_msgbuff_status_t msg_stat; nccl_ofi_msgbuff_result_t mb_res; +retry: /* Retrive entry from message buffer for msg_seq_num index */ mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, &elem, &type, &msg_stat); @@ -4689,6 +4689,17 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } + /* look for control messages and then retry the message search + to avoid unnecessary polling / queueing. */ + if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) { + ret = ofi_process_cq_rail(ep, &ep->control_rail); + if (ret != 0) { + goto error; + } + polled_cq = true; + goto retry; + } + /* Determine if this should be sent eagerly. */ bool eager = false; if ((!have_ctrl && size <= eager_max_size) ||