From a0432ad7e0f4fcaa95cab9a73934d88142d0e3d9 Mon Sep 17 00:00:00 2001
From: Thomas Vegas
Date: Wed, 25 Oct 2023 11:29:22 +0300
Subject: [PATCH] Pass the memory handler and filter config to exclude cuda transport

---
 src/ucx_plugin.c | 75 +++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 69 insertions(+), 6 deletions(-)

diff --git a/src/ucx_plugin.c b/src/ucx_plugin.c
index cab2ee15..8b148e15 100644
--- a/src/ucx_plugin.c
+++ b/src/ucx_plugin.c
@@ -34,6 +34,8 @@
 } while(0)
 
 NCCL_PARAM(UCXDisable, "UCX_DISABLE", 0);
+/* Exclude cuda-related UCX transports */
+NCCL_PARAM(UCXCudaDisable, "UCX_CUDA_DISABLE", 1);
 
 extern ncclDebugLogger_t pluginLogFunction;
 static const ucp_tag_t tag = 0x8a000000;
@@ -212,15 +214,68 @@ static void recv_handler_nbx(void *request, ucs_status_t status,
 static union ncclSocketAddress nccl_ucx_if_addr;
 static char if_name[MAX_IF_NAME_SIZE];
 
+/*
+ * Build the UCX "TLS" value so that cuda transports stay excluded and apply
+ * it together with the other config overrides below. The transport list is
+ * taken from NCCL_UCX_TLS, falling back to UCX_TLS. Returns
+ * ncclInternalError if the list names cuda/gdr or does not fit in tmp.
+ */
+static ncclResult_t ucx_config_no_cuda(ucp_config_t *config) {
+  char tmp[PATH_MAX];
+  const char *ucx_tls;
+  int n;
+
+  ucx_tls = getenv("NCCL_UCX_TLS");
+  if (ucx_tls == NULL) {
+    ucx_tls = getenv("UCX_TLS");
+  }
+
+  if (ucx_tls == NULL) {
+    ucx_tls = "^cuda";
+  } else if (ucx_tls[0] == '^') {
+    /* Negative expression, make sure to keep cuda excluded */
+    n = snprintf(tmp, sizeof(tmp), "^cuda,%s", &ucx_tls[1]);
+    /* snprintf returns int: negative on error, >= size on truncation */
+    if ((n < 0) || ((size_t)n >= sizeof(tmp))) {
+      return ncclInternalError;
+    }
+
+    ucx_tls = tmp;
+  } else {
+    /* Positive expression cannot allow cuda-like transports */
+    if ((strstr(ucx_tls, "cuda") != NULL) || (strstr(ucx_tls, "gdr") != NULL)) {
+      WARN("Cannot use cuda/gdr transports as part of specified UCX_TLS");
+      return ncclInternalError;
+    }
+  }
+
+  UCXCHECK(ucp_config_modify(config, "TLS", ucx_tls));
+  UCXCHECK(ucp_config_modify(config, "RNDV_THRESH", "0"));
+  UCXCHECK(ucp_config_modify(config, "RNDV_SCHEME", "get_zcopy"));
+  UCXCHECK(
+      ucp_config_modify(config, "MEMTYPE_REG_WHOLE_ALLOC_TYPES", "unknown"));
+  return ncclSuccess;
+}
+
 static ncclResult_t ucx_init_context(ucp_context_h *ctx, int dev) {
   ucp_params_t ucp_params;
   ucp_config_t *config;
   char ucx_dev_name[PATH_MAX];
+  ncclResult_t result;
 
   snprintf(ucx_dev_name, PATH_MAX, "%s:%d", ncclIbDevs[dev].devName, ncclIbDevs[dev].port);
   UCXCHECK(ucp_config_read("NCCL", NULL, &config));
   UCXCHECK(ucp_config_modify(config, "NET_DEVICES", ucx_dev_name));
 
+  if (ncclParamUCXCudaDisable()) {
+    result = ucx_config_no_cuda(config);
+    if (result != ncclSuccess) {
+      /* Free the config from ucp_config_read() before bailing out */
+      ucp_config_release(config);
+      return result;
+    }
+  }
+
   memset(&ucp_params, 0, sizeof(ucp_params));
   ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES;
   ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA;
@@ -276,6 +331,8 @@ static ncclResult_t ucx_get_ctx_and_worker(int dev, ucp_context_h *ctx,
                                            nccl_ucx_worker_t **ucx_worker,
                                            ucp_tag_t *newtag) {
   pthread_mutex_lock(&nccl_ucx_lock);
+  ncclResult_t result;
+
   if (ncclNIbDevs <= dev) {
     WARN("Device index is too large");
     goto err;
@@ -300,7 +357,12 @@ static ncclResult_t ucx_get_ctx_and_worker(int dev, ucp_context_h *ctx,
     w->thread = pthread_self();
     w->count = 0;
 
-    ucx_init_context(&w->ctx, dev);
+    result = ucx_init_context(&w->ctx, dev);
+    if (result != ncclSuccess) {
+      /* nccl_ucx_lock is still held here; release it before returning */
+      pthread_mutex_unlock(&nccl_ucx_lock);
+      return result;
+    }
     ucx_init_worker(w->ctx, &w->worker);
     worker_count++;
 
@@ -771,10 +833,11 @@ static ncclResult_t nccl_ucx_isend(void *send_comm, void *data, int size,
   params.cb.send = send_handler_nbx;
   params.user_data = &req->pending;
   if (mh) {
-    params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
-    params.memory_type = mh->mem_type;
+    params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
+    params.memh = mh->ucp_memh;
   }
 
+
   ucp_req = ucp_tag_send_nbx(comm->ep, data, size,
                              nccl_ucx_ucp_tag(comm->tag, tag), &params);
   if (UCS_PTR_IS_ERR(ucp_req)) {
@@ -826,10 +889,10 @@ static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data,
     ucx_request_add(req, sizes[i]);
 
     if (mh[i]) {
-      params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
-      params.memory_type = mh[i]->mem_type;
+      params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
+      params.memh = mh[i]->ucp_memh;
     } else {
-      params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMORY_TYPE;
+      params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMH;
     }
 
     ucp_req = ucp_tag_recv_nbx(comm->ucx_worker->worker, data[i], sizes[i],