diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index 2184bf0d6..680ee4844 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -179,12 +179,6 @@ typedef enum nccl_ofi_comm_stage { COMM_CONNECTED, } nccl_ofi_comm_stage_t; -/* Determines which object a provider associates MRs with */ -typedef enum nccl_ofi_mr_scope { - NCCL_OFI_MR_SCOPE_DOMAIN = 0, - NCCL_OFI_MR_SCOPE_ENDPOINT -} nccl_ofi_mr_scope_t; - typedef struct save_comm_state { nccl_net_ofi_comm_t *comm; nccl_net_ofi_req_t *req; @@ -231,8 +225,8 @@ typedef struct nccl_ofi_properties { unsigned int max_communicators; /** Maximum number of grouped receives */ unsigned int max_group_receives; - /** Scope of a memory region registered with a provider **/ - nccl_ofi_mr_scope_t mr_scope; + /** regMr is global if it is not tied to a particular comm **/ + int regIsGlobal; } nccl_ofi_properties_t; /** diff --git a/src/nccl_ofi_interface_nvidia.c b/src/nccl_ofi_interface_nvidia.c index 0896ebb91..ac8eff7a4 100644 --- a/src/nccl_ofi_interface_nvidia.c +++ b/src/nccl_ofi_interface_nvidia.c @@ -26,52 +26,26 @@ static ncclResult_t getProperties_v8(int dev_id, ncclNetProperties_v8_t* props) props->ptrSupport |= NCCL_PTR_DMABUF; } - /* - * NCCL uses regIsGlobal to determine support for User Registrations via - * the NCCL API. If providers tie MRs to endpoints, the plugin can not - * support this model (since NCCL maintains a per-domain registration - * cache which requires (domain-)global registrations. + /** + * When net-plugin returns regIsGlobal=1 to NCCL (As part of + * net-plugin getProperties() API), it signals to NCCL that + * registered MRs are global, in the sense that they can be + * used by all communicators. In addition, it also signals to + * NCCL that the net-plugin has a fast MR cache such that + * calling regMr() on same buffer (address and size), will + * quickly return a previously globally registered MR on same + * buffer. 
+ * + * When user registers a buffer with NCCL by using + * ncclCommRegister() API, if net-plugin supports + * regIsGlobal=1, NCCL will register the buffer globally once + * (On each net device) with regMr() API. When the net + * proxy-thread starts to execute a communication task on a + * previously registered user buffer, it will call the + * net-plugin regMr() to quickly fetch the previously globally + * registered MR from the plugin managed MR cache. */ - if (ofi_properties.mr_scope == NCCL_OFI_MR_SCOPE_DOMAIN) { - /** - * TODO: - * When net-plugin returns regIsGlobal=1 to NCCL (As part of - * net-plugin getProperties() API), it signals to NCCL that - * registered MRs are global, in the sense that they can be - * used by all communicators. In addition, it also signals to - * NCCL that the net-plugin have a fast MR cache such that - * calling regMr() on same buffer (address and size), will - * quickly return a previously globally registered MR on same - * buffer. - * - * When user registers a buffer with NCCL by using - * ncclCommRegister() API, if net-plugin supports - * regIsGlobal=1, NCCL will register the buffer globally once - * (On each net device) with regMr() API. When the net - * proxy-thread starts to execute a communication task on a - * previously registered user buffer, it will call the - * net-plugin regMr() to quickly fetch the previously globally - * registered MR from the plugin managed MR cache. - * - * Even though when ofi_properties.mr_scope == NCCL_OFI_MR_SCOPE_DOMAIN, - * aws-ofi-nccl registers MRs globally (As MRs registered are - * not specific to a communicator), aws-ofi-nccl doesn't have - * such a fast MR cache yet. Therefore, it should return - * regIsGlobal=0 for now. We should re-enable this when we fix - * the perf problem. - */ - - /** - * TODO: - * In addtion to the above comment, SENDRECV protocol currently - * does not correctly handle the truncated send case (send size - * > recv size) which NCCL uses when regIsGlobal=1. 
So, before - * setting this to 1, we need to either fix SENDRECV protocol, - * or refactor this code to set this property in a protocol- - * specific way. - */ - props->regIsGlobal = 0; - } + props->regIsGlobal = ofi_properties.regIsGlobal; props->speed = ofi_properties.port_speed; props->port = ofi_properties.port_number; diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index 25e4fb6d4..355983118 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -386,15 +386,21 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d } /* - * Determine the scope of MRs for providers to report global - * registration support to NCCL + * Determine the scope of MRs for providers to report global registration + * support to NCCL. + * NCCL uses regIsGlobal to determine support for User Registrations via + * the NCCL API. If providers tie MRs to endpoints, the plugin cannot + * support this model (since NCCL maintains a per-domain registration + * cache which requires (domain-)global registrations). + * Also, if we have different domains for different threads, registrations + * are not reported as global even if they are tied to the domain. 
*/ - if (nic_prov->domain_attr->mr_mode & FI_MR_ENDPOINT) { - props->mr_scope = NCCL_OFI_MR_SCOPE_ENDPOINT; - NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with endpoints"); + if (nic_prov->domain_attr->mr_mode & FI_MR_ENDPOINT || domain_per_thread == 1) { + props->regIsGlobal = 0; + NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Global registrations are not supported"); } else { - props->mr_scope = NCCL_OFI_MR_SCOPE_DOMAIN; - NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with domains"); + props->regIsGlobal = 1; + NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Global registrations supported"); } /* Speed reported in Mbps */ diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c index 1284da2a0..517425cd6 100644 --- a/src/nccl_ofi_sendrecv.c +++ b/src/nccl_ofi_sendrecv.c @@ -53,6 +53,14 @@ static inline int get_properties(nccl_net_ofi_device_t *base_dev, props->max_communicators = NCCL_OFI_MIN(device->max_tag, INT_MAX); } + /** + * TODO: + * The SENDRECV protocol currently does not correctly handle the truncated + * send case (send size > recv size) which NCCL may use when regIsGlobal=1. + * Remove this line once that is fixed. + */ + props->regIsGlobal = 0; + return ret; }