diff --git a/src/components/tl/nccl/allgatherv/allgatherv.c b/src/components/tl/nccl/allgatherv/allgatherv.c
index 38721f1ce6d..141faa50ac1 100644
--- a/src/components/tl/nccl/allgatherv/allgatherv.c
+++ b/src/components/tl/nccl/allgatherv/allgatherv.c
@@ -48,7 +48,7 @@ ucc_base_coll_alg_info_t
     } \
 } while(0)
 
-ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
+ucc_status_t ucc_tl_nccl_allgatherv_p2p_start_gpu(ucc_coll_task_t *coll_task)
 {
     ucc_tl_nccl_task_t *task  = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
     ucc_coll_args_t    *args  = &TASK_ARGS(task);
@@ -93,6 +93,244 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
     return status;
 }
 
+/* Per-round view of the staging window: which slice of the local send
+ * buffer falls into the window, and which peers' contributions (first and
+ * last possibly clipped) are received in it. */
+typedef struct window_bounds {
+    void      *cpu_sbuf;
+    void      *staged_sbuf;
+    uintptr_t  sbuf_len;
+
+    int        first_peer_rank;
+    void      *first_peer_cpu_rbuf;
+    uintptr_t  first_peer_len;
+
+    int        last_peer_rank;
+    uintptr_t  last_peer_len;
+} window_bounds_t;
+
+#define MIN(a, b) (((a) < (b)) ? (a) : (b))
+#define MAX(a, b) (((a) > (b)) ? (a) : (b))
+
+static void find_window_bounds(ucc_coll_task_t *coll_task, int round,
+                               window_bounds_t *win)
+{
+    ucc_tl_nccl_task_t *task     = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
+    ucc_coll_args_t    *args     = &TASK_ARGS(task);
+    ucc_tl_nccl_team_t *team     = TASK_TEAM(task);
+    size_t              sdt_size = ucc_dt_size(args->src.info.datatype);
+    size_t              rdt_size = ucc_dt_size(args->dst.info_v.datatype);
+    uintptr_t window_start = (uintptr_t)round * UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+    uintptr_t window_end   = window_start + UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+    uintptr_t sbuf_start, sbuf_end, offset, full_size, recv_size;
+    int        first_peer;
+    ucc_rank_t peer;
+
+    /* initialize all fields so callers never see garbage values */
+    win->cpu_sbuf            = NULL;
+    win->staged_sbuf         = NULL;
+    win->sbuf_len            = 0;
+    win->first_peer_rank     = -1;
+    win->first_peer_cpu_rbuf = NULL;
+    win->first_peer_len      = 0;
+    win->last_peer_rank      = -1;
+    win->last_peer_len       = 0;
+
+    /* send buffer: clip this rank's contribution against the window */
+    sbuf_start = 0;
+    for (peer = 0; peer < UCC_TL_TEAM_RANK(team); peer++) {
+        sbuf_start += ucc_coll_args_get_count(args, args->dst.info_v.counts,
+                                              peer) * rdt_size;
+    }
+    sbuf_end = sbuf_start + args->src.info.count * sdt_size;
+
+    if (sbuf_end > window_start && sbuf_start < window_end) {
+        uintptr_t sbuf_offset = 0;
+
+        if (sbuf_start < window_start) {
+            sbuf_offset = window_start - sbuf_start;
+        }
+        win->cpu_sbuf = PTR_OFFSET(args->src.info.buffer, sbuf_offset);
+        if (sbuf_start <= window_start) {
+            win->staged_sbuf = task->cpu_coll_scratch_buf;
+        } else {
+            win->staged_sbuf = PTR_OFFSET(task->cpu_coll_scratch_buf,
+                                          sbuf_start - window_start);
+        }
+        win->sbuf_len = MIN(sbuf_end, window_end) -
+                        MAX(sbuf_start, window_start);
+    }
+
+    /* receive buffers: find the first and last peers whose data intersects
+     * the window and the number of their bytes that belong to it */
+    offset     = 0;
+    first_peer = 1;
+    for (peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) {
+        full_size = ucc_coll_args_get_count(args, args->dst.info_v.counts,
+                                            peer) * rdt_size;
+        if (full_size == 0) {
+            continue;
+        } else if (offset + full_size <= window_start) {
+            offset += full_size;
+            continue;
+        } else if (offset >= window_end) {
+            break;
+        }
+
+        recv_size = MIN(offset + full_size, window_end) -
+                    MAX(offset, window_start);
+        if (first_peer) {
+            uintptr_t displ = ucc_coll_args_get_displacement(
+                args, args->dst.info_v.displacements, peer);
+            /* if this peer's data started in an earlier window, point the
+             * destination past the bytes delivered in previous rounds */
+            uintptr_t skip  = (offset < window_start) ?
+                              (window_start - offset) : 0;
+
+            win->first_peer_rank     = peer;
+            win->first_peer_cpu_rbuf = PTR_OFFSET(args->dst.info_v.buffer,
+                                                  displ * rdt_size + skip);
+            win->first_peer_len      = recv_size;
+            first_peer               = 0;
+        }
+        win->last_peer_rank = peer;
+        win->last_peer_len  = recv_size;
+        /* advance by the full contribution so the global offset stays
+         * exact across rounds */
+        offset += full_size;
+    }
+}
+
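+/* Stream-ordered host callbacks: cudaLaunchHostFunc() executes these on
+ * the stream in submission order, so a copy-in is guaranteed to finish
+ * before the ncclSend that consumes the staged bytes, and a copy-out only
+ * runs after the ncclRecvs of its round have landed in the scratch
+ * buffer. */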
+static void CUDART_CB cpu_allgatherv_copy_in(void *data)
+{
+    ucc_coll_task_t    *coll_task = (ucc_coll_task_t *)data;
+    ucc_tl_nccl_task_t *task      = ucc_derived_of(coll_task,
+                                                   ucc_tl_nccl_task_t);
+    window_bounds_t     win;
+
+    find_window_bounds(coll_task, task->cpu_coll_round, &win);
+    if (win.sbuf_len != 0) {
+        memcpy(win.staged_sbuf, win.cpu_sbuf, win.sbuf_len);
+    }
+}
+
+static void CUDART_CB cpu_allgatherv_copy_out(void *data)
+{
+    ucc_coll_task_t    *coll_task   = (ucc_coll_task_t *)data;
+    ucc_tl_nccl_task_t *task        = ucc_derived_of(coll_task,
+                                                     ucc_tl_nccl_task_t);
+    ucc_tl_nccl_team_t *team        = TASK_TEAM(task);
+    ucc_coll_args_t    *args        = &TASK_ARGS(task);
+    size_t              rdt_size    = ucc_dt_size(args->dst.info_v.datatype);
+    uintptr_t           copied      = 0;
+    uintptr_t           total_bytes = 0;
+    uintptr_t           copy_size;
+    size_t              displ;
+    window_bounds_t     win;
+    int                 peer, num_rounds;
+
+    find_window_bounds(coll_task, task->cpu_coll_round, &win);
+
+    for (peer = win.first_peer_rank; peer <= win.last_peer_rank; peer++) {
+        if (peer == win.first_peer_rank) {
+            memcpy(win.first_peer_cpu_rbuf, task->cpu_coll_scratch_buf,
+                   win.first_peer_len);
+            copied += win.first_peer_len;
+        } else if (peer == win.last_peer_rank) {
+            displ = ucc_coll_args_get_displacement(
+                args, args->dst.info_v.displacements, peer);
+            memcpy(PTR_OFFSET(args->dst.info_v.buffer, displ * rdt_size),
+                   PTR_OFFSET(task->cpu_coll_scratch_buf, copied),
+                   win.last_peer_len);
+            copied += win.last_peer_len;
+        } else {
+            copy_size = ucc_coll_args_get_count(args, args->dst.info_v.counts,
+                                                peer) * rdt_size;
+            displ     = ucc_coll_args_get_displacement(
+                args, args->dst.info_v.displacements, peer);
+            memcpy(PTR_OFFSET(args->dst.info_v.buffer, displ * rdt_size),
+                   PTR_OFFSET(task->cpu_coll_scratch_buf, copied), copy_size);
+            copied += copy_size;
+        }
+    }
+
+    task->cpu_coll_round++;
+
+    for (peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) {
+        total_bytes += ucc_coll_args_get_count(args, args->dst.info_v.counts,
+                                               peer) * rdt_size;
+    }
+    num_rounds = total_bytes / UCC_TL_NCCL_SCRATCH_BUF_SIZE +
+                 !!(total_bytes % UCC_TL_NCCL_SCRATCH_BUF_SIZE);
+    if (task->cpu_coll_round == num_rounds) {
+        /* last round: return the staging buffer to the pool */
+        ucc_mpool_put(task->cpu_coll_scratch_buf);
+        task->cpu_coll_scratch_buf = NULL;
+    }
+}
+
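+/* Host-memory allgatherv: the exchange is pipelined through a managed
+ * scratch buffer of UCC_TL_NCCL_SCRATCH_BUF_SIZE bytes.  Every round stages
+ * the slice of the local send buffer that falls into the current window,
+ * exchanges the window over ncclSend/ncclRecv, and drains the received
+ * bytes back to the user's host buffer from a host callback. */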
+ucc_status_t ucc_tl_nccl_allgatherv_p2p_start_cpu(ucc_coll_task_t *coll_task)
+{
+    ucc_tl_nccl_task_t    *task   = ucc_derived_of(coll_task,
+                                                   ucc_tl_nccl_task_t);
+    ucc_coll_args_t       *args   = &TASK_ARGS(task);
+    ucc_tl_nccl_team_t    *team   = TASK_TEAM(task);
+    ucc_tl_nccl_context_t *ctx    = TASK_CTX(task);
+    ucc_rank_t             size   = UCC_TL_TEAM_SIZE(team);
+    ucc_ee_h               ee     = coll_task->ee;
+    cudaStream_t           stream = (ee) ? (cudaStream_t) ee->ee_context :
+                                           team->stream;
+    ucc_status_t           status = UCC_OK;
+    size_t                 rdt_size = ucc_dt_size(args->dst.info_v.datatype);
+    uintptr_t              total_bytes = 0;
+    uintptr_t              offset, recv_size;
+    window_bounds_t        win;
+    ucc_rank_t             peer;
+    int                    i, num_rounds;
+
+    task->super.status = UCC_INPROGRESS;
+    UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
+
+    for (peer = 0; peer < size; peer++) {
+        total_bytes += ucc_coll_args_get_count(args, args->dst.info_v.counts,
+                                               peer) * rdt_size;
+    }
+    num_rounds = total_bytes / UCC_TL_NCCL_SCRATCH_BUF_SIZE +
+                 !!(total_bytes % UCC_TL_NCCL_SCRATCH_BUF_SIZE);
+    if (num_rounds == 0) {
+        /* zero-byte collective: nothing to stage or exchange */
+        status = ucc_tl_nccl_collective_sync(task, stream);
+        goto exit_coll;
+    }
+
+    task->cpu_coll_scratch_buf = ucc_mpool_get(&ctx->cpu_staging_scratch_mp);
+    if (ucc_unlikely(!task->cpu_coll_scratch_buf)) {
+        status = UCC_ERR_NO_MEMORY;
+        goto exit_coll;
+    }
+    task->cpu_coll_round = 0;
+
+    for (i = 0; i < num_rounds; i++) {
+        find_window_bounds(coll_task, i, &win);
+
+        if (args->src.info.count != 0) {
+            NCCLCHECK_GOTO(cudaLaunchHostFunc(stream, cpu_allgatherv_copy_in,
+                                              (void *)coll_task),
+                           exit_coll, status, UCC_TL_TEAM_LIB(team));
+        }
+
+        NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status,
+                       UCC_TL_TEAM_LIB(team));
+        if (win.sbuf_len != 0) {
+            for (peer = 0; peer < size; peer++) {
+                NCCLCHECK_GOTO(ncclSend(win.staged_sbuf, win.sbuf_len,
+                                        ncclChar, peer, team->nccl_comm,
+                                        stream),
+                               exit_coll, status, UCC_TL_TEAM_LIB(team));
+            }
+        }
+        offset = 0;
+        for (peer = win.first_peer_rank; peer <= win.last_peer_rank; peer++) {
+            if (peer == win.first_peer_rank) {
+                recv_size = win.first_peer_len;
+            } else if (peer == win.last_peer_rank) {
+                recv_size = win.last_peer_len;
+            } else {
+                recv_size = ucc_coll_args_get_count(
+                    args, args->dst.info_v.counts, peer) * rdt_size;
+            }
+            NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(task->cpu_coll_scratch_buf,
+                                               offset),
+                                    recv_size, ncclChar, peer,
+                                    team->nccl_comm, stream),
+                           exit_coll, status, UCC_TL_TEAM_LIB(team));
+            offset += recv_size;
+        }
+        NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status,
+                       UCC_TL_TEAM_LIB(team));
+
+        NCCLCHECK_GOTO(cudaLaunchHostFunc(stream, cpu_allgatherv_copy_out,
+                                          (void *)coll_task),
+                       exit_coll, status, UCC_TL_TEAM_LIB(team));
+    }
+
+    status = ucc_tl_nccl_collective_sync(task, stream);
+exit_coll:
+    return status;
+}
+
 ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
                                              ucc_base_team_t *     team,
                                              ucc_coll_task_t **    task_h)
@@ -108,7 +346,11 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
     if (!task) {
         return UCC_ERR_NO_MESSAGE;
     }
-    task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
+    if (args->src.info.mem_type == UCC_MEMORY_TYPE_HOST) {
+        task->super.post = ucc_tl_nccl_allgatherv_p2p_start_cpu;
+    } else {
+        task->super.post = ucc_tl_nccl_allgatherv_p2p_start_gpu;
+    }
     *task_h = &task->super;
 out:
     return status;
diff --git a/src/components/tl/nccl/allgatherv/allgatherv.h b/src/components/tl/nccl/allgatherv/allgatherv.h
index f5741d6739d..a982568f19f 100644
--- a/src/components/tl/nccl/allgatherv/allgatherv.h
+++ b/src/components/tl/nccl/allgatherv/allgatherv.h
@@ -22,7 +22,8 @@ enum {
 extern ucc_base_coll_alg_info_t
     ucc_tl_nccl_allgatherv_algs[UCC_TL_NCCL_ALLGATHERV_ALG_LAST + 1];
 
-ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task);
+ucc_status_t ucc_tl_nccl_allgatherv_p2p_start_cpu(ucc_coll_task_t *coll_task);
+ucc_status_t ucc_tl_nccl_allgatherv_p2p_start_gpu(ucc_coll_task_t *coll_task);
 
 ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
                                              ucc_base_team_t *     team,
diff --git a/src/components/tl/nccl/tl_nccl.h b/src/components/tl/nccl/tl_nccl.h
index 18152739d00..8ac06ac47c3 100644
--- a/src/components/tl/nccl/tl_nccl.h
+++ b/src/components/tl/nccl/tl_nccl.h
@@ -37,6 +37,8 @@
 #include "utils/profile/ucc_profile_off.h"
 #endif
 
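+/* Window size for staging host-memory collectives through managed memory.
+ * 1 MB trades scratch footprint against the number of pipelined rounds;
+ * it is a default, not a tuned value. */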
+#define UCC_TL_NCCL_SCRATCH_BUF_SIZE (1024 * 1024)
+
 #define UCC_TL_NCCL_PROFILE_FUNC UCC_PROFILE_FUNC
 #define UCC_TL_NCCL_PROFILE_FUNC_VOID UCC_PROFILE_FUNC_VOID
 #define UCC_TL_NCCL_PROFILE_REQUEST_NEW UCC_PROFILE_REQUEST_NEW
@@ -75,7 +77,8 @@ typedef struct ucc_tl_nccl_context {
     ucc_tl_context_t             super;
     ucc_tl_nccl_context_config_t cfg;
     ucc_mpool_t                  req_mp;
-    void                        *scratch_buf;
+    ucc_mpool_t                  cpu_staging_scratch_mp;
+    void                        *barrier_scratch;
 } ucc_tl_nccl_context_t;
 UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,
                   const ucc_base_config_t *);
@@ -93,6 +96,8 @@ typedef struct ucc_tl_nccl_task {
     ucc_status_t                     host_status;
     ucc_status_t                    *dev_status;
     void                            *completed;
+    void                            *cpu_coll_scratch_buf;
+    int                              cpu_coll_round;
     union {
         struct {
             ucc_mc_buffer_header_t *scratch;
diff --git a/src/components/tl/nccl/tl_nccl_coll.c b/src/components/tl/nccl/tl_nccl_coll.c
index 796bb05d7e5..4ec903e4d1e 100644
--- a/src/components/tl/nccl/tl_nccl_coll.c
+++ b/src/components/tl/nccl/tl_nccl_coll.c
@@ -106,6 +106,7 @@ ucc_tl_nccl_task_t * ucc_tl_nccl_init_task(ucc_base_coll_args_t *coll_args,
     task->super.finalize       = ucc_tl_nccl_coll_finalize;
     task->super.triggered_post = ucc_tl_nccl_triggered_post;
     task->completed            = NULL;
+    task->cpu_coll_scratch_buf = NULL;
     if (nccl_ctx->cfg.sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT) {
         status = ucc_ec_create_event(&task->completed, UCC_EE_CUDA_STREAM);
         if (ucc_unlikely(status != UCC_OK)) {
@@ -417,10 +418,55 @@ ucc_status_t ucc_tl_nccl_allgatherv_init(ucc_tl_nccl_task_t *task)
         tl_error(UCC_TASK_LIB(task), "user defined datatype is not supported");
         return UCC_ERR_NOT_SUPPORTED;
     }
-    task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
+
+    ucc_coll_args_t *args = &TASK_ARGS(task);
+
+    if (args->src.info.mem_type == UCC_MEMORY_TYPE_HOST) {
+        task->super.post = ucc_tl_nccl_allgatherv_p2p_start_cpu;
+    } else {
+        task->super.post = ucc_tl_nccl_allgatherv_p2p_start_gpu;
+    }
+
     return UCC_OK;
 }
 
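+/* Host-memory bcast: the root stages one window of the user buffer into
+ * scratch per round (copy-in) and every other rank drains the received
+ * window back out (copy-out); whichever callback runs on a given rank also
+ * advances the round counter and releases the scratch buffer after the
+ * final round. */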
+static void CUDART_CB cpu_bcast_copy_in(void *data)
+{
+    ucc_coll_task_t    *coll_task   = (ucc_coll_task_t *)data;
+    ucc_tl_nccl_task_t *task        = ucc_derived_of(coll_task,
+                                                     ucc_tl_nccl_task_t);
+    ucc_coll_args_t    *args        = &TASK_ARGS(task);
+    uintptr_t           total_bytes = args->src.info.count *
+                                      ucc_dt_size(args->src.info.datatype);
+    uintptr_t           completed   = (uintptr_t)task->cpu_coll_round *
+                                      UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+    uintptr_t           rem_bytes   = total_bytes - completed;
+
+    if (rem_bytes > UCC_TL_NCCL_SCRATCH_BUF_SIZE) {
+        rem_bytes = UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+    }
+    memcpy(task->cpu_coll_scratch_buf,
+           PTR_OFFSET(args->src.info.buffer, completed), rem_bytes);
+
+    /* the root never runs the copy-out callback, so the round counter is
+     * advanced and the scratch buffer released here */
+    task->cpu_coll_round++;
+    if (completed + rem_bytes == total_bytes) {
+        ucc_mpool_put(task->cpu_coll_scratch_buf);
+        task->cpu_coll_scratch_buf = NULL;
+    }
+}
+
+static void CUDART_CB cpu_bcast_copy_out(void *data)
+{
+    ucc_coll_task_t    *coll_task   = (ucc_coll_task_t *)data;
+    ucc_tl_nccl_task_t *task        = ucc_derived_of(coll_task,
+                                                     ucc_tl_nccl_task_t);
+    ucc_coll_args_t    *args        = &TASK_ARGS(task);
+    uintptr_t           total_bytes = args->src.info.count *
+                                      ucc_dt_size(args->src.info.datatype);
+    uintptr_t           completed   = (uintptr_t)task->cpu_coll_round *
+                                      UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+    uintptr_t           rem_bytes   = total_bytes - completed;
+
+    if (rem_bytes > UCC_TL_NCCL_SCRATCH_BUF_SIZE) {
+        rem_bytes = UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+    }
+    memcpy(PTR_OFFSET(args->src.info.buffer, completed),
+           task->cpu_coll_scratch_buf, rem_bytes);
+
+    task->cpu_coll_round++;
+    if (completed + rem_bytes == total_bytes) {
+        ucc_mpool_put(task->cpu_coll_scratch_buf);
+        task->cpu_coll_scratch_buf = NULL;
+    }
+}
+
 ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task)
 {
     ucc_tl_nccl_task_t *task   = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
@@ -464,9 +510,39 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task)
                            exit_coll, status, UCC_TL_TEAM_LIB(team));
         }
     } else {
-        NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root, team->nccl_comm,
-                                     stream),
-                       exit_coll, status, UCC_TL_TEAM_LIB(team));
+        if (args->src.info.mem_type == UCC_MEMORY_TYPE_HOST) {
+            ucc_tl_nccl_context_t *ctx = TASK_CTX(task);
+            uintptr_t total_bytes = count *
+                                    ucc_dt_size(args->src.info.datatype);
+            int num_rounds = total_bytes / UCC_TL_NCCL_SCRATCH_BUF_SIZE +
+                             !!(total_bytes % UCC_TL_NCCL_SCRATCH_BUF_SIZE);
+            uintptr_t chunk;
+            int i;
+
+            if (num_rounds > 0) {
+                task->cpu_coll_scratch_buf =
+                    ucc_mpool_get(&ctx->cpu_staging_scratch_mp);
+                if (ucc_unlikely(!task->cpu_coll_scratch_buf)) {
+                    status = UCC_ERR_NO_MEMORY;
+                    goto exit_coll;
+                }
+                task->cpu_coll_round = 0;
+            }
+            for (i = 0; i < num_rounds; i++) {
+                /* move only the bytes staged for this round: the scratch
+                 * buffer is a single window long */
+                chunk = total_bytes -
+                        (uintptr_t)i * UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+                if (chunk > UCC_TL_NCCL_SCRATCH_BUF_SIZE) {
+                    chunk = UCC_TL_NCCL_SCRATCH_BUF_SIZE;
+                }
+                if (UCC_TL_TEAM_RANK(team) == root) {
+                    NCCLCHECK_GOTO(cudaLaunchHostFunc(stream,
+                                                      cpu_bcast_copy_in,
+                                                      (void *)coll_task),
+                                   exit_coll, status, UCC_TL_TEAM_LIB(team));
+                }
+                NCCLCHECK_GOTO(ncclBroadcast(task->cpu_coll_scratch_buf,
+                                             task->cpu_coll_scratch_buf,
+                                             chunk, ncclChar, root,
+                                             team->nccl_comm, stream),
+                               exit_coll, status, UCC_TL_TEAM_LIB(team));
+                if (UCC_TL_TEAM_RANK(team) != root) {
+                    NCCLCHECK_GOTO(cudaLaunchHostFunc(stream,
+                                                      cpu_bcast_copy_out,
+                                                      (void *)coll_task),
+                                   exit_coll, status, UCC_TL_TEAM_LIB(team));
+                }
+            }
+        } else {
+            NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root,
+                                         team->nccl_comm, stream),
+                           exit_coll, status, UCC_TL_TEAM_LIB(team));
+        }
     }
     status = ucc_tl_nccl_collective_sync(task, stream);
 exit_coll:
@@ -601,7 +677,7 @@ ucc_status_t ucc_tl_nccl_barrier_init(ucc_tl_nccl_task_t *task)
 
     args->flags            |= UCC_COLL_ARGS_FLAG_IN_PLACE;
     args->op                = UCC_OP_SUM;
-    args->dst.info.buffer   = TASK_CTX(task)->scratch_buf;
+    args->dst.info.buffer   = TASK_CTX(task)->barrier_scratch;
     args->src.info.buffer   = args->dst.info.buffer;
     args->dst.info.datatype = args->src.info.datatype = UCC_DT_FLOAT32;
     args->dst.info.count    = args->src.info.count = 1;
@@ -966,26 +1042,30 @@ ucc_status_t ucc_tl_nccl_alg_id_to_init(int alg_id, const char *alg_id_str,
         alg_id = alg_id_from_str(coll_type, alg_id_str);
     }
 
-    switch (coll_type) {
-    case UCC_COLL_TYPE_ALLGATHERV:
-        switch (alg_id) {
-        case UCC_TL_NCCL_ALLGATHERV_ALG_P2P:
-            *init = ucc_tl_nccl_allgatherv_p2p_init;
-            break;
-        case UCC_TL_NCCL_ALLGATHERV_ALG_BCOPY:
-            *init = ucc_tl_nccl_allgatherv_bcopy_init;
-            break;
-        case UCC_TL_NCCL_ALLGATHERV_ALG_BCAST:
-            *init = ucc_tl_nccl_allgatherv_bcast_init;
-            break;
-        default:
-            status = UCC_ERR_INVALID_PARAM;
-            break;
-        };
-        break;
-    default:
-        status = UCC_ERR_NOT_SUPPORTED;
-        break;
+    if (coll_type == UCC_COLL_TYPE_ALLGATHERV &&
+        mem_type == UCC_MEMORY_TYPE_HOST) {
+        /* host buffers are only handled by the staged p2p algorithm */
+        *init = ucc_tl_nccl_allgatherv_p2p_init;
+    } else {
+        switch (coll_type) {
+        case UCC_COLL_TYPE_ALLGATHERV:
+            switch (alg_id) {
+            case UCC_TL_NCCL_ALLGATHERV_ALG_P2P:
+                *init = ucc_tl_nccl_allgatherv_p2p_init;
+                break;
+            case UCC_TL_NCCL_ALLGATHERV_ALG_BCOPY:
+                *init = ucc_tl_nccl_allgatherv_bcopy_init;
+                break;
+            case UCC_TL_NCCL_ALLGATHERV_ALG_BCAST:
+                *init = ucc_tl_nccl_allgatherv_bcast_init;
+                break;
+            default:
+                status = UCC_ERR_INVALID_PARAM;
+                break;
+            };
+            break;
+        default:
+            status = UCC_ERR_NOT_SUPPORTED;
+            break;
+        }
     }
     return status;
 }
diff --git a/src/components/tl/nccl/tl_nccl_context.c b/src/components/tl/nccl/tl_nccl_context.c
index 017878ba12a..58f25160dde 100644
--- a/src/components/tl/nccl/tl_nccl_context.c
+++ b/src/components/tl/nccl/tl_nccl_context.c
@@ -93,6 +93,33 @@ static ucc_mpool_ops_t ucc_tl_nccl_req_mapped_mpool_ops = {
     .obj_cleanup = NULL
 };
 
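+/* The staging pool is carved out of cudaMallocManaged memory: the same
+ * pointer must be valid for host memcpy inside stream callbacks and as a
+ * NCCL send/recv buffer on the device.  cudaMemAttachGlobal keeps the
+ * allocation accessible from every stream. */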
+static ucc_status_t ucc_tl_nccl_managed_mpool_chunk_malloc(ucc_mpool_t *mp,
+                                                           size_t *size_p,
+                                                           void **chunk_p)
+{
+    cudaError_t cu_st;
+
+    cu_st = cudaMallocManaged((void **)chunk_p, *size_p, cudaMemAttachGlobal);
+    if (cu_st != cudaSuccess) {
+        return UCC_ERR_NO_MEMORY;
+    }
+    return UCC_OK;
+}
+
+static void ucc_tl_nccl_managed_mpool_chunk_free(ucc_mpool_t *mp, void *chunk)
+{
+    cudaFree(chunk);
+}
+
+static ucc_mpool_ops_t ucc_tl_nccl_managed_mpool_ops = {
+    .chunk_alloc   = ucc_tl_nccl_managed_mpool_chunk_malloc,
+    .chunk_release = ucc_tl_nccl_managed_mpool_chunk_free,
+    .obj_init      = NULL,
+    .obj_cleanup   = NULL
+};
+
 UCC_CLASS_INIT_FUNC(ucc_tl_nccl_context_t,
                     const ucc_base_context_params_t *params,
                     const ucc_base_config_t *config)
@@ -148,11 +175,23 @@ UCC_CLASS_INIT_FUNC(ucc_tl_nccl_context_t,
                  "failed to initialize tl_nccl_req mpool");
         return status;
     }
+
     // scratch buffer for barrier
-    cudaError_t cuda_st = cudaMalloc(&self->scratch_buf, sizeof(float));
+    cudaError_t cuda_st = cudaMalloc(&self->barrier_scratch, sizeof(float));
     if (cuda_st != cudaSuccess) {
         return UCC_ERR_NO_MEMORY;
     }
+
+    // scratch buffers for staging host-memory collectives
+    status = ucc_mpool_init(&self->cpu_staging_scratch_mp, 0,
+                            UCC_TL_NCCL_SCRATCH_BUF_SIZE, 0, 1, 1, UINT_MAX,
+                            &ucc_tl_nccl_managed_mpool_ops,
+                            params->thread_mode, "tl_nccl_managed_mp");
+    if (status != UCC_OK) {
+        tl_error(self->super.super.lib,
+                 "failed to initialize tl_nccl_managed mpool");
+        cudaFree(self->barrier_scratch);
+        return status;
+    }
+
     tl_info(self->super.super.lib, "initialized tl context: %p", self);
     return UCC_OK;
 }
@@ -161,8 +200,9 @@ UCC_CLASS_CLEANUP_FUNC(ucc_tl_nccl_context_t)
 {
     tl_info(self->super.super.lib, "finalizing tl context: %p", self);
     ucc_mpool_cleanup(&self->req_mp, 1);
-    cudaFree(self->scratch_buf);
-    self->scratch_buf = NULL;
+    cudaFree(self->barrier_scratch);
+    self->barrier_scratch = NULL;
+    ucc_mpool_cleanup(&self->cpu_staging_scratch_mp, 1);
 }
 
 UCC_CLASS_DEFINE(ucc_tl_nccl_context_t, ucc_tl_context_t);
diff --git a/src/components/tl/nccl/tl_nccl_team.c b/src/components/tl/nccl/tl_nccl_team.c
index d03af20a186..148841caf8a 100644
--- a/src/components/tl/nccl/tl_nccl_team.c
+++ b/src/components/tl/nccl/tl_nccl_team.c
@@ -224,8 +224,9 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
         }
     }
 
-    // add barrier, which might be triggered from host memory type
-    // use lower score
+    /* add host-memory (CPU) collectives - barrier, bcast, allgatherv -
+     * at a lower score */
     status = ucc_coll_score_add_range(score, UCC_COLL_TYPE_BARRIER,
                                       UCC_MEMORY_TYPE_HOST, 0, UCC_MSG_MAX, 1,
                                       ucc_tl_nccl_coll_init, tl_team);
@@ -233,6 +234,21 @@ ucc_status_t ucc_tl_nccl_team_get_scores(ucc_base_team_t *tl_team,
         return status;
     }
 
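+    /* bcast and allgatherv reach the staged CPU paths through
+     * ucc_tl_nccl_coll_init, which picks the start routine based on
+     * args->src.info.mem_type; score 1 keeps them behind host-native TLs */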
+    status = ucc_coll_score_add_range(score, UCC_COLL_TYPE_BCAST,
+                                      UCC_MEMORY_TYPE_HOST, 0, UCC_MSG_MAX, 1,
+                                      ucc_tl_nccl_coll_init, tl_team);
+    if (ucc_unlikely(UCC_OK != status)) {
+        return status;
+    }
+
+    status = ucc_coll_score_add_range(score, UCC_COLL_TYPE_ALLGATHERV,
+                                      UCC_MEMORY_TYPE_HOST, 0, UCC_MSG_MAX, 1,
+                                      ucc_tl_nccl_coll_init, tl_team);
+    if (ucc_unlikely(UCC_OK != status)) {
+        return status;
+    }
+
     if (strlen(ctx->score_str) > 0) {
         status = ucc_coll_score_update_from_str(
             ctx->score_str, score, UCC_TL_TEAM_SIZE(team),