diff --git a/include/Makefile.am b/include/Makefile.am index 1d38fcacd..f096c06de 100644 --- a/include/Makefile.am +++ b/include/Makefile.am @@ -37,6 +37,7 @@ noinst_HEADERS = \ nccl-headers/nvidia/net_v5.h \ nccl-headers/nvidia/net_v6.h \ nccl-headers/nvidia/net_v7.h \ + nccl-headers/nvidia/net_v8.h \ nccl-headers/nvidia/types.h \ nccl-headers/nvidia/tuner.h \ nccl-headers/neuron/net.h \ diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index 7890dbd3e..12c059600 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -179,6 +179,12 @@ typedef enum nccl_ofi_comm_stage { COMM_CONNECTED, } nccl_ofi_comm_stage_t; +/* Determines which object a provider associates MRs with */ +typedef enum nccl_ofi_mr_scope { + NCCL_OFI_MR_SCOPE_DOMAIN = 0, + NCCL_OFI_MR_SCOPE_ENDPOINT +} nccl_ofi_mr_scope_t; + typedef struct save_comm_state { nccl_net_ofi_comm_t *comm; nccl_net_ofi_req_t *req; @@ -222,6 +228,8 @@ typedef struct nccl_ofi_properties { unsigned int max_communicators; /** Maximum number of grouped receives */ unsigned int max_group_receives; + /** Scope of a memory region registered with a provider **/ + nccl_ofi_mr_scope_t mr_scope; } nccl_ofi_properties_t; /** diff --git a/src/nccl_ofi_interface_nvidia.c b/src/nccl_ofi_interface_nvidia.c index 4f842cd33..4d1d12300 100644 --- a/src/nccl_ofi_interface_nvidia.c +++ b/src/nccl_ofi_interface_nvidia.c @@ -7,6 +7,44 @@ #include "nccl_ofi.h" #include "nccl_ofi_api.h" +static ncclResult_t getProperties_v8(int dev_id, ncclNetProperties_v8_t* props) +{ + nccl_ofi_properties_t ofi_properties; + ncclResult_t ret = nccl_net_ofi_get_properties(dev_id, &ofi_properties); + if (ret != ncclSuccess) { + return ret; + } + + props->name = ofi_properties.name; + props->pciPath = ofi_properties.pci_path; + props->guid = ofi_properties.guid; + props->ptrSupport = NCCL_PTR_HOST; + if (ofi_properties.hmem_support) { + props->ptrSupport |= NCCL_PTR_CUDA; + } + if (ofi_properties.dmabuf_support) { + props->ptrSupport |= NCCL_PTR_DMABUF; + } + + /* + * NCCL uses regIsGlobal to determine support for User Registrations via + * the NCCL API. If providers tie MRs to endpoints, the plugin can not + * support this model (since NCCL maintains a per-domain registration + * cache which requires (domain-)global registrations. + */ + if (ofi_properties.mr_scope == NCCL_OFI_MR_SCOPE_DOMAIN) + props->regIsGlobal = 1; + + props->speed = ofi_properties.port_speed; + props->port = ofi_properties.port_number; + props->latency = ofi_properties.latency; + props->maxComms = ofi_properties.max_communicators; + props->maxRecvs = ofi_properties.max_group_receives; + props->netDeviceType = NCCL_NET_DEVICE_HOST; + props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + + return ncclSuccess; +} static ncclResult_t getProperties_v7(int dev_id, ncclNetProperties_v7_t *props) { @@ -248,3 +286,25 @@ const ncclNet_v7_t ncclNetPlugin_v7 = { .getDeviceMr = NULL, .irecvConsumed = NULL, }; + +const ncclNet_v8_t ncclNetPlugin_v8 = { + .name = "AWS Libfabric", + .init = nccl_net_ofi_init, + .devices = nccl_net_ofi_devices, + .getProperties = getProperties_v8, + .listen = nccl_net_ofi_listen, + .connect = connect_v7, + .accept = accept_v7, + .regMr = nccl_net_ofi_regMr, + .regMrDmaBuf = nccl_net_ofi_regMrDmaBuf, + .deregMr = nccl_net_ofi_deregMr, + .isend = nccl_net_ofi_isend, + .irecv = nccl_net_ofi_irecv, + .iflush = nccl_net_ofi_iflush, + .test = nccl_net_ofi_test, + .closeSend = nccl_net_ofi_closeSend, + .closeRecv = nccl_net_ofi_closeRecv, + .closeListen = nccl_net_ofi_closeListen, + .getDeviceMr = NULL, + .irecvConsumed = NULL, +}; diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index 8a9410f87..e714a7019 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -376,6 +376,18 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d dev_props.name = strdup(nic_info->device_attr->name); } + /* + * Determine the scope of MRs for providers to report global + * registration support to NCCL + */ + if (nic_prov->domain_attr->mr_mode & FI_MR_ENDPOINT) { + dev_props.mr_scope = NCCL_OFI_MR_SCOPE_ENDPOINT; + NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with endpoints"); + } else { + dev_props.mr_scope = NCCL_OFI_MR_SCOPE_DOMAIN; + NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with domains"); + } + /* Speed reported in Mbps */ dev_props.port_speed = nic_info->link_attr->speed / (1e6); diff --git a/tests/functional/nccl_connection.c b/tests/functional/nccl_connection.c index ae2433a67..34a202fac 100644 --- a/tests/functional/nccl_connection.c +++ b/tests/functional/nccl_connection.c @@ -21,7 +21,7 @@ int main(int argc, char* argv[]) nccl_net_ofi_send_comm_t *sComm = NULL; nccl_net_ofi_listen_comm_t *lComm = NULL; nccl_net_ofi_recv_comm_t *rComm = NULL; - ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore; + ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore; char src_handle[NCCL_NET_HANDLE_MAXSIZE] = {0}; char handle[NCCL_NET_HANDLE_MAXSIZE] = {0}; test_nccl_net_t *extNet = NULL; diff --git a/tests/functional/nccl_message_transfer.c b/tests/functional/nccl_message_transfer.c index 14a19bfa1..e401c683b 100644 --- a/tests/functional/nccl_message_transfer.c +++ b/tests/functional/nccl_message_transfer.c @@ -25,7 +25,7 @@ int main(int argc, char* argv[]) nccl_net_ofi_listen_comm_t *lComm = NULL; nccl_net_ofi_recv_comm_t *rComm = NULL; test_nccl_net_t *extNet = NULL; - ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore; + ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore; char src_handle[NCCL_NET_HANDLE_MAXSIZE] = {0}; ofi_log_function = logger; diff --git a/tests/functional/ring.c b/tests/functional/ring.c index d1f63ff7d..933f077a2 100644 --- a/tests/functional/ring.c +++ b/tests/functional/ring.c @@ -23,7 +23,7 @@ int main(int argc, char *argv[]) char handle[NCCL_NET_HANDLE_MAXSIZE] = {0}; char src_handle_prev[NCCL_NET_HANDLE_MAXSIZE] = {0}; char src_handle_next[NCCL_NET_HANDLE_MAXSIZE] = {0}; - ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore; + ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore; test_nccl_net_t *extNet = NULL; ofi_log_function = logger; diff --git a/tests/functional/test-common.h b/tests/functional/test-common.h index ea71cc99e..a46bfeb3b 100644 --- a/tests/functional/test-common.h +++ b/tests/functional/test-common.h @@ -53,10 +53,10 @@ } while(false); // Can be changed when porting new versions to the plugin -#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v7 +#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v8 -typedef ncclNet_v7_t test_nccl_net_t; -typedef ncclNetProperties_v7_t test_nccl_properties_t; +typedef ncclNet_v8_t test_nccl_net_t; +typedef ncclNetProperties_v8_t test_nccl_properties_t; void logger(ncclDebugLogLevel level, unsigned long flags, const char *filefunc, int line, const char *fmt, ...)