Add some CPU collectives to the NCCL TL #570
base: master
@@ -48,7 +48,7 @@ ucc_base_coll_alg_info_t
     }                                                                        \
 } while(0)

-ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
+ucc_status_t ucc_tl_nccl_allgatherv_p2p_start_gpu(ucc_coll_task_t *coll_task)
 {
     ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
     ucc_coll_args_t    *args = &TASK_ARGS(task);

@@ -93,6 +93,244 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
     return status;
 }
typedef struct {
    void      *cpu_sbuf;
    void      *staged_sbuf;
    uintptr_t  sbuf_len;

    int        first_peer_rank;
    void      *first_peer_cpu_rbuf;
    uintptr_t  first_peer_len;

    int        last_peer_rank;
    uintptr_t  last_peer_len;
} window_bounds_t;

Review comment (on the rank fields): int -> ucc_rank_t
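A minimal sketch of the struct with that suggestion applied. Since ucc_rank_t is unsigned and the code below uses -1 as a "not set" marker, the sentinel would also need to change; UCC_RANK_INVALID is assumed here to be the appropriate constant from UCC's headers:

typedef struct {
    void       *cpu_sbuf;
    void       *staged_sbuf;
    uintptr_t   sbuf_len;

    ucc_rank_t  first_peer_rank;     /* UCC_RANK_INVALID when unset */
    void       *first_peer_cpu_rbuf;
    uintptr_t   first_peer_len;

    ucc_rank_t  last_peer_rank;      /* UCC_RANK_INVALID when unset */
    uintptr_t   last_peer_len;
} window_bounds_t;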
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#define MAX(a, b) (((a) > (b)) ? (a) : (b))

Review comment: we have ucc_min and ucc_max in ucc_math.h
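The suggestion would drop the local macros in favor of the shared helpers; a sketch, assuming ucc_min/ucc_max are the function-like macros in utils/ucc_math.h that the reviewer refers to:

#include "utils/ucc_math.h"

/* the window clamp below, without the local MIN/MAX macros */
win->sbuf_len = ucc_min(sbuf_end, window_end) - ucc_max(sbuf_start, window_start);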
static void find_window_bounds(ucc_coll_task_t *coll_task, int round,
                               window_bounds_t *win)
{
    ucc_tl_nccl_task_t *task     = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
    ucc_coll_args_t    *args     = &TASK_ARGS(task);
    ucc_tl_nccl_team_t *team     = TASK_TEAM(task);
    size_t              sdt_size = ucc_dt_size(args->src.info.datatype);
    size_t              rdt_size = ucc_dt_size(args->dst.info_v.datatype);

    /* initialize all fields, so we don't accidentally use garbage values */
    win->cpu_sbuf            = NULL;
    win->staged_sbuf         = NULL;
    win->sbuf_len            = 0;
    win->first_peer_rank     = -1;
    win->first_peer_cpu_rbuf = NULL;
    win->first_peer_len      = 0;
    win->last_peer_rank      = -1;
    win->last_peer_len       = 0;

    uintptr_t window_start = round * UCC_TL_NCCL_SCRATCH_BUF_SIZE;
    uintptr_t window_end   = window_start + UCC_TL_NCCL_SCRATCH_BUF_SIZE;

    /* sbuf setup: this rank's contribution starts after the contributions of
     * all lower ranks in the destination layout */
    uintptr_t sbuf_start = 0;
    for (int peer = 0; peer < UCC_TL_TEAM_RANK(team); peer++) {
        sbuf_start += ucc_coll_args_get_count(args, args->dst.info_v.counts,
                                              peer) * rdt_size;
    }
    uintptr_t sbuf_end = sbuf_start + args->src.info.count * sdt_size;

Review comment (on the loop above): ucc_rank_t peer. Please move the declaration to the beginning.

    if (sbuf_end > window_start && sbuf_start < window_end) {
        uintptr_t sbuf_offset = 0;
        if (sbuf_start < window_start) {
            sbuf_offset = window_start - sbuf_start;
        }

        win->cpu_sbuf = PTR_OFFSET(args->src.info.buffer, sbuf_offset);
        if (sbuf_start <= window_start) {
            win->staged_sbuf = task->cpu_coll_scratch_buf;
        } else {
            win->staged_sbuf = PTR_OFFSET(task->cpu_coll_scratch_buf,
                                          sbuf_start - window_start);
        }
        win->sbuf_len = MIN(sbuf_end, window_end) - MAX(sbuf_start, window_start);
    }

    /* rbuf setup */
    uintptr_t offset     = 0;
    int       first_peer = 1;
    for (int peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) {
        uintptr_t recv_size = ucc_coll_args_get_count(args, args->dst.info_v.counts,
                                                      peer) * rdt_size;

        if (recv_size == 0) {
            continue;
        } else if (offset + recv_size <= window_start) {
            offset += recv_size;
            continue;
        } else if (offset >= window_end) {
            break;
        }

        /* the part of this peer's data that falls inside the window */
        uintptr_t in_window = MIN(offset + recv_size, window_end) -
                              MAX(offset, window_start);

        if (first_peer) {
            win->first_peer_rank = peer;
            uintptr_t displ = ucc_coll_args_get_displacement(
                args, args->dst.info_v.displacements, peer);
            /* if this peer straddles window_start, earlier rounds already
             * delivered its first (window_start - offset) bytes */
            uintptr_t skip = (offset < window_start) ? window_start - offset : 0;
            win->first_peer_cpu_rbuf = PTR_OFFSET(args->dst.info_v.buffer,
                                                  displ * rdt_size + skip);
            win->first_peer_len = in_window;

            first_peer = 0;
        }

        win->last_peer_rank = peer;
        win->last_peer_len  = in_window;

        /* advance by the peer's full contribution, not the clamped size, so
         * offset stays an absolute position in the destination layout */
        offset += recv_size;
    }
}
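To make the window arithmetic concrete, here is a small worked example with made-up sizes (the scratch window is shrunk to 100 bytes for readability):

/* 3 ranks, per-rank receive sizes {150, 30, 100} bytes, window = 100 bytes
 * => total_bytes = 280, num_rounds = 3
 *
 * round 0, window [0, 100):   first 100 bytes of rank 0
 * round 1, window [100, 200): rank 0's bytes [100, 150), all 30 bytes of
 *                             rank 1, first 20 bytes of rank 2
 * round 2, window [200, 280): remaining 80 bytes of rank 2
 *
 * in round 1, rank 0 straddles window_start, so first_peer_cpu_rbuf must be
 * advanced 100 bytes past rank 0's displacement, and offset must advance by
 * the full 150 bytes so ranks 1 and 2 keep their absolute positions */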
static void CUDART_CB cpu_allgatherv_copy_in(void *data)
{
    ucc_coll_task_t    *coll_task = (ucc_coll_task_t *) data;
    ucc_tl_nccl_task_t *task      = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);

    window_bounds_t win;
    find_window_bounds(coll_task, task->cpu_coll_round, &win);

    if (win.sbuf_len != 0) {
        memcpy(win.staged_sbuf, win.cpu_sbuf, win.sbuf_len);
    }
}

Review comment: "window_bounds_t win;" goes together with the variable declarations, and the empty line goes after.
static void CUDART_CB cpu_allgatherv_copy_out(void *data)
{
    ucc_coll_task_t    *coll_task = (ucc_coll_task_t *) data;
    ucc_tl_nccl_task_t *task      = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
    ucc_tl_nccl_team_t *team      = TASK_TEAM(task);
    ucc_coll_args_t    *args      = &TASK_ARGS(task);
    size_t              rdt_size  = ucc_dt_size(args->dst.info_v.datatype);

    window_bounds_t win;
    find_window_bounds(coll_task, task->cpu_coll_round, &win);

    void     *rbuf   = task->cpu_coll_scratch_buf;
    uintptr_t copied = 0;
    for (int peer = win.first_peer_rank; peer <= win.last_peer_rank; peer++) {
        if (peer == win.first_peer_rank) {
            memcpy(win.first_peer_cpu_rbuf, rbuf, win.first_peer_len);
            copied += win.first_peer_len;
        } else if (peer == win.last_peer_rank) {
            size_t displ = ucc_coll_args_get_displacement(
                args, args->dst.info_v.displacements, peer);
            memcpy(PTR_OFFSET(args->dst.info_v.buffer, displ * rdt_size),
                   PTR_OFFSET(task->cpu_coll_scratch_buf, copied),
                   win.last_peer_len);
            copied += win.last_peer_len;
        } else {
            uintptr_t copy_size = ucc_coll_args_get_count(
                args, args->dst.info_v.counts, peer) * rdt_size;
            size_t displ = ucc_coll_args_get_displacement(
                args, args->dst.info_v.displacements, peer);
            memcpy(PTR_OFFSET(args->dst.info_v.buffer, displ * rdt_size),
                   PTR_OFFSET(task->cpu_coll_scratch_buf, copied), copy_size);
            copied += copy_size;
        }
    }

    task->cpu_coll_round++;

    uintptr_t total_bytes = 0;
    for (int peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) {
        total_bytes += ucc_coll_args_get_count(args, args->dst.info_v.counts,
                                               peer) * rdt_size;
    }
    int num_rounds = total_bytes / UCC_TL_NCCL_SCRATCH_BUF_SIZE +
                     !!(total_bytes % UCC_TL_NCCL_SCRATCH_BUF_SIZE);

    if (task->cpu_coll_round == num_rounds) {
        ucc_mpool_put(task->cpu_coll_scratch_buf);
        task->cpu_coll_scratch_buf = NULL;
    }
}

Review comment (on the num_rounds computation): I would cache num_rounds on the task as well, to avoid repeated re-calculation.

Review comment (on ucc_mpool_put): why is mpool_put not done in collective_finalize?
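Two sketches for the review suggestions above. First, caching the round count; the helper and the task field name (cpu_coll_num_rounds) are illustrative, not from the PR:

static int cpu_allgatherv_num_rounds(ucc_coll_args_t *args,
                                     ucc_tl_nccl_team_t *team)
{
    size_t    rdt_size    = ucc_dt_size(args->dst.info_v.datatype);
    uintptr_t total_bytes = 0;

    for (int peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) {
        total_bytes += ucc_coll_args_get_count(args, args->dst.info_v.counts,
                                               peer) * rdt_size;
    }
    return total_bytes / UCC_TL_NCCL_SCRATCH_BUF_SIZE +
           !!(total_bytes % UCC_TL_NCCL_SCRATCH_BUF_SIZE);
}

/* at post time:  task->cpu_coll_num_rounds = cpu_allgatherv_num_rounds(args, team);
 * in copy_out:   if (task->cpu_coll_round == task->cpu_coll_num_rounds) ... */

Second, releasing the scratch buffer from the task's finalize path instead of the last copy-out, as the reviewer implies; this assumes the TL's existing ucc_tl_nccl_coll_finalize can be wrapped, and the wrapper name is hypothetical:

ucc_status_t ucc_tl_nccl_coll_finalize_cpu_scratch(ucc_coll_task_t *coll_task)
{
    ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);

    if (task->cpu_coll_scratch_buf) {
        ucc_mpool_put(task->cpu_coll_scratch_buf);
        task->cpu_coll_scratch_buf = NULL;
    }
    return ucc_tl_nccl_coll_finalize(coll_task); /* existing finalize path */
}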
ucc_status_t ucc_tl_nccl_allgatherv_p2p_start_cpu(ucc_coll_task_t *coll_task)
{
    ucc_tl_nccl_task_t *task   = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
    ucc_coll_args_t    *args   = &TASK_ARGS(task);
    ucc_tl_nccl_team_t *team   = TASK_TEAM(task);
    ucc_rank_t          size   = UCC_TL_TEAM_SIZE(team);
    ucc_ee_h            ee     = coll_task->ee;
    cudaStream_t        stream = (ee) ? (cudaStream_t) ee->ee_context :
                                        team->stream;
    ucc_status_t        status = UCC_OK;
    size_t              rdt_size;
    ucc_rank_t          peer;

    task->super.status = UCC_INPROGRESS;
    rdt_size = ucc_dt_size(args->dst.info_v.datatype);
    UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);

    ucc_tl_nccl_context_t *ctx = TASK_CTX(task);
    task->cpu_coll_scratch_buf = ucc_mpool_get(&ctx->cpu_staging_scratch_mp);
    if (ucc_unlikely(!task->cpu_coll_scratch_buf)) {
        status = UCC_ERR_NO_MEMORY;
        goto exit_coll;
    }
    task->cpu_coll_round = 0;

    uintptr_t total_bytes = 0;
    for (peer = 0; peer < size; peer++) {
        total_bytes += ucc_coll_args_get_count(args, args->dst.info_v.counts,
                                               peer) * rdt_size;
    }
    int num_rounds = total_bytes / UCC_TL_NCCL_SCRATCH_BUF_SIZE +
                     !!(total_bytes % UCC_TL_NCCL_SCRATCH_BUF_SIZE);

    for (int i = 0; i < num_rounds; i++) {
        /* stage this round's slice of the user sbuf into the scratch buffer,
         * in stream order, once prior rounds have drained */
        if (args->src.info.count != 0) {
            NCCLCHECK_GOTO(cudaLaunchHostFunc(stream, cpu_allgatherv_copy_in,
                                              (void *) coll_task),
                           exit_coll, status, UCC_TL_TEAM_LIB(team));
        }

        window_bounds_t win;
        find_window_bounds(coll_task, i, &win);

        NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
        if (win.sbuf_len != 0) {
            for (peer = 0; peer < size; peer++) {
                NCCLCHECK_GOTO(ncclSend(win.staged_sbuf, win.sbuf_len, ncclChar,
                                        peer, team->nccl_comm, stream),
                               exit_coll, status, UCC_TL_TEAM_LIB(team));
            }
        }

        uintptr_t offset = 0;
        for (peer = win.first_peer_rank; peer <= win.last_peer_rank; peer++) {
            uintptr_t recv_size;

            if (peer == win.first_peer_rank) {
                recv_size = win.first_peer_len;
            } else if (peer == win.last_peer_rank) {
                recv_size = win.last_peer_len;
            } else {
                recv_size = ucc_coll_args_get_count(args, args->dst.info_v.counts,
                                                    peer) * rdt_size;
            }

            NCCLCHECK_GOTO(ncclRecv(PTR_OFFSET(task->cpu_coll_scratch_buf, offset),
                                    recv_size, ncclChar, peer, team->nccl_comm,
                                    stream),
                           exit_coll, status, UCC_TL_TEAM_LIB(team));

            offset += recv_size;
        }
        NCCLCHECK_GOTO(ncclGroupEnd(), exit_coll, status, UCC_TL_TEAM_LIB(team));

        /* scatter the received window from the scratch buffer to the user
         * rbuf and advance the round counter */
        NCCLCHECK_GOTO(cudaLaunchHostFunc(stream, cpu_allgatherv_copy_out,
                                          (void *) coll_task),
                       exit_coll, status, UCC_TL_TEAM_LIB(team));
    }

    status = ucc_tl_nccl_collective_sync(task, stream);
exit_coll:
    return status;
}
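The per-round pipeline relies entirely on CUDA stream ordering; every operation below is enqueued on the same stream, so the host callbacks and the NCCL operations serialize without explicit synchronization:

/* per-round stream timeline (host functions run on a CUDA-internal thread,
 * in stream order; they may touch host memory but must not call the CUDA API):
 *
 *   cudaLaunchHostFunc(copy_in)   -- memcpy user sbuf slice -> staging buffer
 *   ncclGroupStart/Send/Recv/End  -- exchange the staged window
 *   cudaLaunchHostFunc(copy_out)  -- memcpy received window -> user rbuf
 *
 * because everything sits on one stream, round i+1's copy_in cannot
 * overwrite the staging buffer before round i's sends have completed */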
 ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
                                              ucc_base_team_t      *team,
                                              ucc_coll_task_t     **task_h)

@@ -108,7 +346,11 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_init(ucc_base_coll_args_t *coll_args,
     if (!task) {
         return UCC_ERR_NO_MESSAGE;
     }
-    task->super.post = ucc_tl_nccl_allgatherv_p2p_start;
+    if (args->src.info.mem_type == UCC_MEMORY_TYPE_HOST) {
+        task->super.post = ucc_tl_nccl_allgatherv_p2p_start_cpu;
+    } else {
+        task->super.post = ucc_tl_nccl_allgatherv_p2p_start_gpu;
+    }
     *task_h = &task->super;
 out:
     return status;
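For reference, a sketch of the collective arguments that would select the new CPU path; this is standard UCC API usage, and the buffer variables are assumed to be set up by the caller:

/* buffers and counts below are assumed to be allocated by the caller */
void        *sbuf, *rbuf;
ucc_count_t  scount;
ucc_count_t *counts;
ucc_aint_t  *displs;

ucc_coll_args_t args = {
    .coll_type = UCC_COLL_TYPE_ALLGATHERV,
    .src.info = {
        .buffer   = sbuf,                  /* host pointer */
        .count    = scount,
        .datatype = UCC_DT_INT8,
        .mem_type = UCC_MEMORY_TYPE_HOST,  /* routes to ..._start_cpu */
    },
    .dst.info_v = {
        .buffer        = rbuf,
        .counts        = counts,
        .displacements = displs,
        .datatype      = UCC_DT_INT8,
        .mem_type      = UCC_MEMORY_TYPE_HOST,
    },
};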
@@ -37,6 +37,8 @@
 #include "utils/profile/ucc_profile_off.h"
 #endif

+#define UCC_TL_NCCL_SCRATCH_BUF_SIZE (1024 * 1024)

Review comment: why not a parameter?
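One way to address this, sketched under the assumption that the TL follows UCC's usual config-table pattern; the field name and the table entry below are illustrative, not from the PR:

/* hypothetical field in ucc_tl_nccl_context_config_t */
size_t scratch_buf_size;

/* hypothetical entry in the context config table, which would make the
 * size tunable via the UCC_TL_NCCL_SCRATCH_BUF_SIZE environment variable */
{"SCRATCH_BUF_SIZE", "1M",
 "Host staging buffer size for CPU collectives over NCCL",
 ucc_offsetof(ucc_tl_nccl_context_config_t, scratch_buf_size),
 UCC_CONFIG_TYPE_MEMUNITS},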
 #define UCC_TL_NCCL_PROFILE_FUNC        UCC_PROFILE_FUNC
 #define UCC_TL_NCCL_PROFILE_FUNC_VOID   UCC_PROFILE_FUNC_VOID
 #define UCC_TL_NCCL_PROFILE_REQUEST_NEW UCC_PROFILE_REQUEST_NEW
@@ -75,7 +77,8 @@ typedef struct ucc_tl_nccl_context {
     ucc_tl_context_t             super;
     ucc_tl_nccl_context_config_t cfg;
     ucc_mpool_t                  req_mp;
-    void                        *scratch_buf;
+    ucc_mpool_t                  cpu_staging_scratch_mp;
+    void                        *barrier_scratch;
 } ucc_tl_nccl_context_t;
 UCC_CLASS_DECLARE(ucc_tl_nccl_context_t, const ucc_base_context_params_t *,
                   const ucc_base_config_t *);

@@ -93,6 +96,8 @@ typedef struct ucc_tl_nccl_task {
     ucc_status_t            host_status;
     ucc_status_t           *dev_status;
     void                   *completed;
+    void                   *cpu_coll_scratch_buf;
+    int                     cpu_coll_round;
     union {
         struct {
             ucc_mc_buffer_header_t *scratch;

Review comment: remove empty lines.