diff --git a/src/components/tl/ucp/allgather/allgather_ring.c b/src/components/tl/ucp/allgather/allgather_ring.c index 07178aea25..b0aeeed214 100644 --- a/src/components/tl/ucp/allgather/allgather_ring.c +++ b/src/components/tl/ucp/allgather/allgather_ring.c @@ -40,6 +40,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) size_t count = TASK_ARGS(task).dst.info.count; ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; size_t data_size = (count / tsize) * ucc_dt_size(dt); + ucc_status_t status = UCC_OK; ucc_rank_t sendto, recvfrom, sblock, rblock; int step; void *buf; @@ -69,7 +70,14 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) } } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); - task->super.status = UCC_OK; + if (task->allgather_ring.etask) { + status = ucc_ee_executor_task_test(task->allgather_ring.etask); + if (status == UCC_INPROGRESS) { + return; + } + ucc_ee_executor_task_finalize(task->allgather_ring.etask); + } + task->super.status = status; out: UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_done", 0); } @@ -88,22 +96,50 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; size_t data_size = (count / tsize) * ucc_dt_size(dt); ucc_status_t status; - ucc_rank_t block; + ucc_rank_t sendto, recvfrom, sblock, rblock; + ucc_ee_executor_t *exec; + ucc_ee_executor_task_args_t eargs; + void *buf; UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_start", 0); ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); + sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); + recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); + sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0); + rblock = task->allgather_ring.get_recv_block(&task->subset, trank, tsize, 0); if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - block = task->allgather_ring.get_send_block(&task->subset, trank, tsize, - 0); - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block), - sbuf, data_size, rmem, smem); - if (ucc_unlikely(UCC_OK != status)) { + status = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(status != UCC_OK)) { return status; } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = sbuf; + eargs.copy.dst = PTR_OFFSET(rbuf, data_size * sblock); + eargs.copy.len = data_size; + + status = ucc_ee_executor_task_post(exec, &eargs, + &task->allgather_ring.etask); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + buf = sbuf; + } else { + task->allgather_ring.etask = NULL; + buf = PTR_OFFSET(rbuf, data_size * sblock); } + UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buf, data_size, smem, sendto, team, task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, rblock * data_size), + data_size, rmem, recvfrom, team, task), + task, out); + return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); + +out: + return status; } ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task) @@ -128,6 +164,9 @@ ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task) task->allgather_ring.get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block; task->super.post = ucc_tl_ucp_allgather_ring_start; task->super.progress = ucc_tl_ucp_allgather_ring_progress; + if (!UCC_IS_INPLACE(TASK_ARGS(task))) { + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + } return UCC_OK; } diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index 6ab2c661dd..0843cbb886 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -180,6 +180,7 @@ typedef struct ucc_tl_ucp_task { ucc_rank_t trank, ucc_rank_t tsize, int step); + ucc_ee_executor_task_t *etask; } allgather_ring; struct { ucc_rank_t dist; diff --git a/src/components/tl/ucp/tl_ucp_service_coll.c b/src/components/tl/ucp/tl_ucp_service_coll.c index bf16cf00d7..390a9fb16c 100644 --- a/src/components/tl/ucp/tl_ucp_service_coll.c +++ b/src/components/tl/ucp/tl_ucp_service_coll.c @@ -74,7 +74,7 @@ static ucc_status_t ucc_tl_ucp_service_coll_stop_executor(ucc_coll_task_t *task) ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf, void *rbuf, ucc_datatype_t dt, size_t count, ucc_reduction_op_t op, - ucc_subset_t subset, + ucc_subset_t subset, ucc_coll_task_t **task_p) { ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); @@ -140,7 +140,7 @@ ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf, ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf, void *rbuf, size_t msgsize, - ucc_subset_t subset, + ucc_subset_t subset, ucc_coll_task_t **task_p) { ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); @@ -178,6 +178,14 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf, task->n_polls = npolls; task->super.progress = ucc_tl_ucp_allgather_ring_progress; task->super.finalize = ucc_tl_ucp_coll_finalize; + if (in_place) { + task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; + } + + status = ucc_tl_ucp_service_coll_start_executor(&task->super); + if (status != UCC_OK) { + goto free_task; + } status = ucc_tl_ucp_allgather_ring_start(&task->super); if (status != UCC_OK) { @@ -187,7 +195,8 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf, *task_p = &task->super; return status; finalize_coll: - ucc_tl_ucp_coll_finalize(*task_p); + ucc_tl_ucp_coll_finalize(&task->super); + ucc_tl_ucp_service_coll_stop_executor(&task->super); free_task: ucc_tl_ucp_put_task(task); return status; @@ -195,7 +204,7 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf, ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf, size_t msgsize, ucc_rank_t root, - ucc_subset_t subset, + ucc_subset_t subset, ucc_coll_task_t **task_p) { ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); @@ -239,7 +248,8 @@ ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf, return status; } -void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id) { +void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id) +{ ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t); tl_team->super.super.params.id = id;