diff --git a/src/ucx_rma_plugin.c b/src/ucx_rma_plugin.c index 400bb1d..5770396 100644 --- a/src/ucx_rma_plugin.c +++ b/src/ucx_rma_plugin.c @@ -726,7 +726,8 @@ static ucs_status_t nccl_ucp_shared_put(nccl_ucp_comm_t *comm, void *va, return UCS_PTR_STATUS(status_ptr); } -static int nccl_ucp_mh_update(nccl_ucp_comm_t *comm, nccl_ucp_memh_t *mh) { +static ncclResult_t nccl_ucp_mh_update(nccl_ucp_comm_t *comm, + nccl_ucp_memh_t *mh) { ucs_status_t status; nccl_ucp_packed_rkey_t *packed, *remote; @@ -741,11 +742,16 @@ static int nccl_ucp_mh_update(nccl_ucp_comm_t *comm, nccl_ucp_memh_t *mh) { status = nccl_ucp_shared_put(comm, packed, sizeof(*packed), remote, &comm->inflight_rkey); + if (UCS_STATUS_IS_ERR(status)) { + WARN("Failed to send packed rkey"); + return ncclSystemError; + } + comm->inflight_rkey += (status == UCS_INPROGRESS); - mh->sent = !UCS_STATUS_IS_ERR(status); + mh->sent = 1; } - return mh->sent == 0; + return ncclSuccess; } static ncclResult_t nccl_ucx_rma_regmr(void *reg_comm, void *data, size_t size, @@ -781,7 +787,6 @@ static ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data, void **request) { nccl_ucp_comm_t *comm = recv_comm; nccl_ucp_memh_t **mh = (nccl_ucp_memh_t**)mhandle; - int missed = 0; nccl_ucp_req_t *req; nccl_ucp_rtr_t *rtr; nccl_ucp_atp_t *atp; @@ -804,7 +809,7 @@ static ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data, *request = NULL; for (i = 0; i < n; i++) { - missed += nccl_ucp_mh_update(comm, mh[i]); + NCCLCHECK(nccl_ucp_mh_update(comm, mh[i])); rtr->chunk[i].data = (uint64_t)data[i]; rtr->chunk[i].rkey_id = mh[i]->rkey_id; @@ -822,11 +827,6 @@ static ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data, memcpy(atp->sizes, sizes, sizeof(*sizes) * n); } - if (missed) { - ucp_worker_progress(comm->worker->ucp_worker); - return ncclSuccess; - } - remote = &comm->remote.share->rtr[comm->rtr_id & NCCL_UCP_RING_MASK]; status = nccl_ucp_shared_put( comm, rtr, sizeof(*rtr) - (NCCL_UCP_MAX_RECV - n) * sizeof(*rtr->chunk),