Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/UCP: make local copy non-blocking in allgather #867

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions src/components/tl/ucp/allgather/allgather_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
size_t count = TASK_ARGS(task).dst.info.count;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_status_t status = UCC_OK;
ucc_rank_t sendto, recvfrom, sblock, rblock;
int step;
void *buf;
Expand Down Expand Up @@ -69,7 +70,14 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
}
}
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
task->super.status = UCC_OK;
if (task->allgather_ring.etask) {
status = ucc_ee_executor_task_test(task->allgather_ring.etask);
if (status == UCC_INPROGRESS) {
return;
}
ucc_ee_executor_task_finalize(task->allgather_ring.etask);
}
task->super.status = status;
out:
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_done", 0);
}
Expand All @@ -88,22 +96,50 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_status_t status;
ucc_rank_t block;
ucc_rank_t sendto, recvfrom, sblock, rblock;
ucc_ee_executor_t *exec;
ucc_ee_executor_task_args_t eargs;
void *buf;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_ring_start", 0);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0);
rblock = task->allgather_ring.get_recv_block(&task->subset, trank, tsize, 0);
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
block = task->allgather_ring.get_send_block(&task->subset, trank, tsize,
0);
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block),
sbuf, data_size, rmem, smem);
if (ucc_unlikely(UCC_OK != status)) {
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = sbuf;
eargs.copy.dst = PTR_OFFSET(rbuf, data_size * sblock);
eargs.copy.len = data_size;

status = ucc_ee_executor_task_post(exec, &eargs,
&task->allgather_ring.etask);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
buf = sbuf;
} else {
task->allgather_ring.etask = NULL;
buf = PTR_OFFSET(rbuf, data_size * sblock);
}

UCPCHECK_GOTO(ucc_tl_ucp_send_nb(buf, data_size, smem, sendto, team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb(PTR_OFFSET(rbuf, rblock * data_size),
data_size, rmem, recvfrom, team, task),
task, out);

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);

out:
return status;
}

ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
Expand All @@ -128,6 +164,9 @@ ucc_status_t ucc_tl_ucp_allgather_ring_init_common(ucc_tl_ucp_task_t *task)
task->allgather_ring.get_recv_block = ucc_tl_ucp_allgather_ring_get_recv_block;
task->super.post = ucc_tl_ucp_allgather_ring_start;
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
}

return UCC_OK;
}
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ typedef struct ucc_tl_ucp_task {
ucc_rank_t trank,
ucc_rank_t tsize,
int step);
ucc_ee_executor_task_t *etask;
} allgather_ring;
struct {
ucc_rank_t dist;
Expand Down
20 changes: 15 additions & 5 deletions src/components/tl/ucp/tl_ucp_service_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ static ucc_status_t ucc_tl_ucp_service_coll_stop_executor(ucc_coll_task_t *task)
ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf,
void *rbuf, ucc_datatype_t dt,
size_t count, ucc_reduction_op_t op,
ucc_subset_t subset,
ucc_subset_t subset,
ucc_coll_task_t **task_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
Expand Down Expand Up @@ -140,7 +140,7 @@ ucc_status_t ucc_tl_ucp_service_allreduce(ucc_base_team_t *team, void *sbuf,

ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
void *rbuf, size_t msgsize,
ucc_subset_t subset,
ucc_subset_t subset,
ucc_coll_task_t **task_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
Expand Down Expand Up @@ -178,6 +178,14 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
task->n_polls = npolls;
task->super.progress = ucc_tl_ucp_allgather_ring_progress;
task->super.finalize = ucc_tl_ucp_coll_finalize;
if (in_place) {
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
}

status = ucc_tl_ucp_service_coll_start_executor(&task->super);
if (status != UCC_OK) {
goto free_task;
}

status = ucc_tl_ucp_allgather_ring_start(&task->super);
if (status != UCC_OK) {
Expand All @@ -187,15 +195,16 @@ ucc_status_t ucc_tl_ucp_service_allgather(ucc_base_team_t *team, void *sbuf,
*task_p = &task->super;
return status;
finalize_coll:
ucc_tl_ucp_coll_finalize(*task_p);
ucc_tl_ucp_coll_finalize(&task->super);
ucc_tl_ucp_service_coll_stop_executor(&task->super);
free_task:
ucc_tl_ucp_put_task(task);
return status;
}

ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf,
size_t msgsize, ucc_rank_t root,
ucc_subset_t subset,
ucc_subset_t subset,
ucc_coll_task_t **task_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
Expand Down Expand Up @@ -239,7 +248,8 @@ ucc_status_t ucc_tl_ucp_service_bcast(ucc_base_team_t *team, void *buf,
return status;
}

void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id) {
void ucc_tl_ucp_service_update_id(ucc_base_team_t *team, uint16_t id)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);

tl_team->super.super.params.id = id;
Expand Down
Loading