Skip to content

Commit

Permalink
Pass the memory handler and filter config to exclude cuda transport
Browse files Browse the repository at this point in the history
  • Loading branch information
tvegas1 authored and bureddy committed Nov 16, 2023
1 parent 81c41e6 commit a0432ad
Showing 1 changed file with 58 additions and 6 deletions.
64 changes: 58 additions & 6 deletions src/ucx_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
} while(0)

NCCL_PARAM(UCXDisable, "UCX_DISABLE", 0);
/* Exclude cuda-related UCX transports */
NCCL_PARAM(UCXCudaDisable, "UCX_CUDA_DISABLE", 1);

extern ncclDebugLogger_t pluginLogFunction;
static const ucp_tag_t tag = 0x8a000000;
Expand Down Expand Up @@ -212,15 +214,59 @@ static void recv_handler_nbx(void *request, ucs_status_t status,
static union ncclSocketAddress nccl_ucx_if_addr;
static char if_name[MAX_IF_NAME_SIZE];

static ncclResult_t ucx_config_no_cuda(ucp_config_t *config) {
char tmp[PATH_MAX];
const char *ucx_tls;
ssize_t n;

ucx_tls = getenv("NCCL_UCX_TLS");
if (ucx_tls == NULL) {
ucx_tls = getenv("UCX_TLS");
}

if (ucx_tls == NULL) {
ucx_tls = "^cuda";
} else if (ucx_tls[0] == '^') {
/* Negative expression, make sure to keep cuda excluded */
n = snprintf(tmp, sizeof(tmp), "^cuda,%s", &ucx_tls[1]);
if (n >= sizeof(tmp)) {
return ncclInternalError;
}

ucx_tls = tmp;
} else {
/* Positive expression cannot allow cuda-like transports */
if ((strstr(ucx_tls, "cuda") != NULL) || (strstr(ucx_tls, "gdr") != NULL)) {
WARN("Cannot use cuda/gdr transports as part of specified UCX_TLS");
return ncclInternalError;
}
}

UCXCHECK(ucp_config_modify(config, "TLS", ucx_tls));
UCXCHECK(ucp_config_modify(config, "RNDV_THRESH", "0"));
UCXCHECK(ucp_config_modify(config, "RNDV_SCHEME", "get_zcopy"));
UCXCHECK(
ucp_config_modify(config, "MEMTYPE_REG_WHOLE_ALLOC_TYPES", "unknown"));
return ncclSuccess;
}

static ncclResult_t ucx_init_context(ucp_context_h *ctx, int dev) {
ucp_params_t ucp_params;
ucp_config_t *config;
char ucx_dev_name[PATH_MAX];
ncclResult_t result;

snprintf(ucx_dev_name, PATH_MAX, "%s:%d", ncclIbDevs[dev].devName, ncclIbDevs[dev].port);
UCXCHECK(ucp_config_read("NCCL", NULL, &config));
UCXCHECK(ucp_config_modify(config, "NET_DEVICES", ucx_dev_name));

if (ncclParamUCXCudaDisable()) {
result = ucx_config_no_cuda(config);
if (result != ncclSuccess) {
return result;
}
}

memset(&ucp_params, 0, sizeof(ucp_params));
ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES;
ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA;
Expand Down Expand Up @@ -276,6 +322,8 @@ static ncclResult_t ucx_get_ctx_and_worker(int dev, ucp_context_h *ctx,
nccl_ucx_worker_t **ucx_worker,
ucp_tag_t *newtag) {
pthread_mutex_lock(&nccl_ucx_lock);
ncclResult_t result;

if (ncclNIbDevs <= dev) {
WARN("Device index is too large");
goto err;
Expand All @@ -300,7 +348,10 @@ static ncclResult_t ucx_get_ctx_and_worker(int dev, ucp_context_h *ctx,
w->thread = pthread_self();
w->count = 0;

ucx_init_context(&w->ctx, dev);
result = ucx_init_context(&w->ctx, dev);
if (result != ncclSuccess) {
return result;
}
ucx_init_worker(w->ctx, &w->worker);
worker_count++;

Expand Down Expand Up @@ -771,10 +822,11 @@ static ncclResult_t nccl_ucx_isend(void *send_comm, void *data, int size,
params.cb.send = send_handler_nbx;
params.user_data = &req->pending;
if (mh) {
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.memory_type = mh->mem_type;
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
params.memh = mh->ucp_memh;
}


ucp_req = ucp_tag_send_nbx(comm->ep, data, size,
nccl_ucx_ucp_tag(comm->tag, tag), &params);
if (UCS_PTR_IS_ERR(ucp_req)) {
Expand Down Expand Up @@ -826,10 +878,10 @@ static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data,
ucx_request_add(req, sizes[i]);

if (mh[i]) {
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.memory_type = mh[i]->mem_type;
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
params.memh = mh[i]->ucp_memh;
} else {
params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMH;
}

ucp_req = ucp_tag_recv_nbx(comm->ucx_worker->worker, data[i], sizes[i],
Expand Down

0 comments on commit a0432ad

Please sign in to comment.