From e5481decd193f7859968b8499faf1be29e203eb5 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Tue, 10 Sep 2024 14:11:22 +0000 Subject: [PATCH] UCT/CUDA: Advertise MNNVL inter-node capability with shm device type --- src/ucp/core/ucp_worker.c | 5 ++++- src/uct/api/uct.h | 3 +++ src/uct/base/uct_iface.h | 1 + src/uct/base/uct_md.c | 1 + src/uct/cuda/cuda_ipc/cuda_ipc_iface.c | 15 ++++++++++++--- 5 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/ucp/core/ucp_worker.c b/src/ucp/core/ucp_worker.c index 7ac0e2fc64e..e7e6a668f07 100644 --- a/src/ucp/core/ucp_worker.c +++ b/src/ucp/core/ucp_worker.c @@ -2961,7 +2961,10 @@ static ucs_status_t ucp_worker_address_pack(ucp_worker_h worker, if (address_flags & UCP_WORKER_ADDRESS_FLAG_NET_ONLY) { UCS_STATIC_BITMAP_RESET_ALL(&tl_bitmap); UCS_STATIC_BITMAP_FOR_EACH_BIT(tl_id, &worker->context->tl_bitmap) { - if (context->tl_rscs[tl_id].tl_rsc.dev_type == UCT_DEVICE_TYPE_NET) { + if ((context->tl_rscs[tl_id].tl_rsc.dev_type == + UCT_DEVICE_TYPE_NET) || + (context->tl_rscs[tl_id].tl_rsc.flags & + UCT_TL_RESOURCE_FLAG_INTER_NODE)) { UCS_STATIC_BITMAP_SET(&tl_bitmap, tl_id); } } diff --git a/src/uct/api/uct.h b/src/uct/api/uct.h index 34c5ff14a75..1da40146659 100644 --- a/src/uct/api/uct.h +++ b/src/uct/api/uct.h @@ -332,8 +332,11 @@ typedef struct uct_tl_resource_desc { (e.g. UCT_DEVICE_TYPE_NET for a network interface) */ ucs_sys_device_t sys_device; /**< The identifier associated with the device bus_id as captured in ucs_sys_bus_id_t struct */ + uint8_t flags; /**< Associated flags to the resource */ } uct_tl_resource_desc_t; +#define UCT_TL_RESOURCE_FLAG_INTER_NODE UCS_BIT(0) /**< Inter-node capability */ + #define UCT_TL_RESOURCE_DESC_FMT "%s/%s" #define UCT_TL_RESOURCE_DESC_ARG(_resource) (_resource)->tl_name, (_resource)->dev_name diff --git a/src/uct/base/uct_iface.h b/src/uct/base/uct_iface.h index b44d4306b8e..cfa8d0787d0 100644 --- a/src/uct/base/uct_iface.h +++ b/src/uct/base/uct_iface.h @@ -385,6 +385,7 @@ typedef struct uct_tl_device_resource { (e.g. UCT_DEVICE_TYPE_NET for a network interface) */ ucs_sys_device_t sys_device; /**< The identifier associated with the device bus_id as captured in ucs_sys_bus_id_t struct */ + uint8_t flags; /**< Associated flags to the resource */ } uct_tl_device_resource_t; diff --git a/src/uct/base/uct_md.c b/src/uct/base/uct_md.c index 61c0dfff10e..db63ba19b26 100644 --- a/src/uct/base/uct_md.c +++ b/src/uct/base/uct_md.c @@ -120,6 +120,7 @@ ucs_status_t uct_md_query_tl_resources(uct_md_h md, sizeof(tmp[num_resources + i].dev_name)); tmp[num_resources + i].dev_type = tl_devices[i].type; tmp[num_resources + i].sys_device = tl_devices[i].sys_device; + tmp[num_resources + i].flags = tl_devices[i].flags; } resources = tmp; diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c b/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c index 893d987267e..d0985c6ec64 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_iface.c @@ -611,16 +611,25 @@ uct_cuda_ipc_query_devices( uct_md_h uct_md, uct_tl_device_resource_t **tl_devices_p, unsigned *num_tl_devices_p) { + uint8_t flags = 0; uct_device_type_t dev_type = UCT_DEVICE_TYPE_SHM; + ucs_status_t status; + #if HAVE_CUDA_FABRIC uct_cuda_ipc_md_t *md = ucs_derived_of(uct_md, uct_cuda_ipc_md_t); if (uct_cuda_ipc_iface_is_mnnvl_supported(md)) { - dev_type = UCT_DEVICE_TYPE_NET; + flags = UCT_TL_RESOURCE_FLAG_INTER_NODE; } #endif - return uct_cuda_base_query_devices_common(uct_md, dev_type, - tl_devices_p, num_tl_devices_p); + status = uct_cuda_base_query_devices_common(uct_md, dev_type, tl_devices_p, + num_tl_devices_p); + if (status == UCS_OK) { + ucs_assert(*num_tl_devices_p == 1); + (*tl_devices_p)->flags = flags; + } + + return status; } UCS_CLASS_DEFINE(uct_cuda_ipc_iface_t, uct_cuda_iface_t);