Skip to content

Commit 30b20f6

Browse files
committed
TL/CUDA: fix triggered post a2a
1 parent de3d462 commit 30b20f6

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

src/components/tl/cuda/alltoall/alltoall_ce.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ ucc_status_t ucc_tl_cuda_alltoall_ce_init(ucc_tl_cuda_task_t *task)
3838
size_t data_len;
3939
int i;
4040

41-
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
42-
4341
task->alltoallv_ce.get_size = ucc_tl_cuda_alltoall_get_size;
4442
task->alltoallv_ce.get_offset = ucc_tl_cuda_alltoall_get_offset;
4543
task->alltoallv_ce.sdt = args->src.info.datatype;
@@ -70,6 +68,8 @@ ucc_status_t ucc_tl_cuda_alltoall_ce_init(ucc_tl_cuda_task_t *task)
7068

7169
if (lib->cfg.alltoall_use_copy_engine) {
7270
ucc_debug("ucc_tl_cuda_alltoallv_ce_init: copy engine");
71+
task->super.triggered_post = ucc_tl_cuda_alltoallv_ce_triggered_post;
72+
7373
task->alltoallv_ce.copy_post = cuda_copy_post;
7474
task->alltoallv_ce.evtCompletions = (cudaEvent_t*)ucc_malloc(team->num_streams * sizeof(cudaEvent_t), "alltoallv_ce.evtCompletions");
7575
for (i = 0; i < team->num_streams; i++) {

src/components/tl/cuda/alltoallv/alltoallv.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,6 @@ ucc_status_t ee_copy_post(void *dst, void *src, size_t len,
6262
ucc_ee_executor_t *executor,
6363
ucc_ee_executor_task_t **task, cudaStream_t stream);
6464

65+
ucc_status_t ucc_tl_cuda_alltoallv_ce_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,
66+
ucc_coll_task_t *coll_task);
6567
#endif

src/components/tl/cuda/alltoallv/alltoallv_ce.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ size_t ucc_tl_cuda_alltoallv_get_offset(const ucc_tl_cuda_task_t *task,
562562

563563
//NOLINTNEXTLINE(misc-unused-parameters): ev parameter unused as it's not needed for this implementation
564564
ucc_status_t ucc_tl_cuda_alltoallv_ce_triggered_post(ucc_ee_h ee, ucc_ev_t *ev,
565-
ucc_coll_task_t *coll_task)
565+
ucc_coll_task_t *coll_task)
566566
{
567567
ucc_tl_cuda_task_t *task = ucc_derived_of(coll_task, ucc_tl_cuda_task_t);
568568
ucc_status_t status;

0 commit comments

Comments
 (0)