Skip to content

Commit

Permalink
Support NCCL's v8 ext-net interface
Browse files Browse the repository at this point in the history
Two key changes:
- regMr size changed from int -> size_t
- A new `regIsGlobal` property which is used by NCCL to determine
  support for user registrations. The plugin now determines this via the
  mr_mode bit providers set to define the scope of a MR (domain-level or
  endpoint-level).

Signed-off-by: Raghu Raja <[email protected]>
Signed-off-by: Liran Alon <[email protected]>
Signed-off-by: Raghu Raja <[email protected]>
  • Loading branch information
rajachan committed Mar 21, 2024
1 parent 4d86c3e commit afd61d6
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 6 deletions.
1 change: 1 addition & 0 deletions include/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ noinst_HEADERS = \
nccl-headers/nvidia/net_v5.h \
nccl-headers/nvidia/net_v6.h \
nccl-headers/nvidia/net_v7.h \
nccl-headers/nvidia/net_v8.h \
nccl-headers/nvidia/types.h \
nccl-headers/nvidia/tuner.h \
nccl-headers/neuron/net.h \
Expand Down
8 changes: 8 additions & 0 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ typedef enum nccl_ofi_comm_stage {
COMM_CONNECTED,
} nccl_ofi_comm_stage_t;

/* Determines which object a provider associates MRs with */
typedef enum nccl_ofi_mr_scope {
NCCL_OFI_MR_SCOPE_DOMAIN = 0,
NCCL_OFI_MR_SCOPE_ENDPOINT
} nccl_ofi_mr_scope_t;

typedef struct save_comm_state {
nccl_net_ofi_comm_t *comm;
nccl_net_ofi_req_t *req;
Expand Down Expand Up @@ -222,6 +228,8 @@ typedef struct nccl_ofi_properties {
unsigned int max_communicators;
/** Maximum number of grouped receives */
unsigned int max_group_receives;
/** Scope of a memory region registered with a provider **/
nccl_ofi_mr_scope_t mr_scope;
} nccl_ofi_properties_t;

/**
Expand Down
60 changes: 60 additions & 0 deletions src/nccl_ofi_interface_nvidia.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,44 @@
#include "nccl_ofi.h"
#include "nccl_ofi_api.h"

static ncclResult_t getProperties_v8(int dev_id, ncclNetProperties_v8_t* props)
{
nccl_ofi_properties_t ofi_properties;
ncclResult_t ret = nccl_net_ofi_get_properties(dev_id, &ofi_properties);
if (ret != ncclSuccess) {
return ret;
}

props->name = ofi_properties.name;
props->pciPath = ofi_properties.pci_path;
props->guid = ofi_properties.guid;
props->ptrSupport = NCCL_PTR_HOST;
if (ofi_properties.hmem_support) {
props->ptrSupport |= NCCL_PTR_CUDA;
}
if (ofi_properties.dmabuf_support) {
props->ptrSupport |= NCCL_PTR_DMABUF;
}

/*
* NCCL uses regIsGlobal to determine support for User Registrations via
* the NCCL API. If providers tie MRs to endpoints, the plugin can not
* support this model (since NCCL maintains a per-domain registration
* cache which requires (domain-)global registrations.
*/
if (ofi_properties.mr_scope == NCCL_OFI_MR_SCOPE_DOMAIN)
props->regIsGlobal = 1;

props->speed = ofi_properties.port_speed;
props->port = ofi_properties.port_number;
props->latency = ofi_properties.latency;
props->maxComms = ofi_properties.max_communicators;
props->maxRecvs = ofi_properties.max_group_receives;
props->netDeviceType = NCCL_NET_DEVICE_HOST;
props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION;

return ncclSuccess;
}

static ncclResult_t getProperties_v7(int dev_id, ncclNetProperties_v7_t *props)
{
Expand Down Expand Up @@ -248,3 +286,25 @@ const ncclNet_v7_t ncclNetPlugin_v7 = {
.getDeviceMr = NULL,
.irecvConsumed = NULL,
};

const ncclNet_v8_t ncclNetPlugin_v8 = {
.name = "AWS Libfabric",
.init = nccl_net_ofi_init,
.devices = nccl_net_ofi_devices,
.getProperties = getProperties_v8,
.listen = nccl_net_ofi_listen,
.connect = connect_v7,
.accept = accept_v7,
.regMr = nccl_net_ofi_regMr,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
.closeRecv = nccl_net_ofi_closeRecv,
.closeListen = nccl_net_ofi_closeListen,
.getDeviceMr = NULL,
.irecvConsumed = NULL,
};
12 changes: 12 additions & 0 deletions src/nccl_ofi_net.c
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,18 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d
dev_props.name = strdup(nic_info->device_attr->name);
}

/*
* Determine the scope of MRs for providers to report global
* registration support to NCCL
*/
if (nic_prov->domain_attr->mr_mode & FI_MR_ENDPOINT) {
dev_props.mr_scope = NCCL_OFI_MR_SCOPE_ENDPOINT;
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with endpoints");
} else {
dev_props.mr_scope = NCCL_OFI_MR_SCOPE_DOMAIN;
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with domains");
}

/* Speed reported in Mbps */
dev_props.port_speed = nic_info->link_attr->speed / (1e6);

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/nccl_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int main(int argc, char* argv[])
nccl_net_ofi_send_comm_t *sComm = NULL;
nccl_net_ofi_listen_comm_t *lComm = NULL;
nccl_net_ofi_recv_comm_t *rComm = NULL;
ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore;
ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore;
char src_handle[NCCL_NET_HANDLE_MAXSIZE] = {0};
char handle[NCCL_NET_HANDLE_MAXSIZE] = {0};
test_nccl_net_t *extNet = NULL;
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/nccl_message_transfer.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ int main(int argc, char* argv[])
nccl_net_ofi_listen_comm_t *lComm = NULL;
nccl_net_ofi_recv_comm_t *rComm = NULL;
test_nccl_net_t *extNet = NULL;
ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore;
ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore;
char src_handle[NCCL_NET_HANDLE_MAXSIZE] = {0};

ofi_log_function = logger;
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ int main(int argc, char *argv[])
char handle[NCCL_NET_HANDLE_MAXSIZE] = {0};
char src_handle_prev[NCCL_NET_HANDLE_MAXSIZE] = {0};
char src_handle_next[NCCL_NET_HANDLE_MAXSIZE] = {0};
ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore;
ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore;
test_nccl_net_t *extNet = NULL;

ofi_log_function = logger;
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/test-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@
} while(false);

// Can be changed when porting new versions to the plugin
#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v7
#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v8

typedef ncclNet_v7_t test_nccl_net_t;
typedef ncclNetProperties_v7_t test_nccl_properties_t;
typedef ncclNet_v8_t test_nccl_net_t;
typedef ncclNetProperties_v8_t test_nccl_properties_t;

void logger(ncclDebugLogLevel level, unsigned long flags, const char *filefunc,
int line, const char *fmt, ...)
Expand Down

0 comments on commit afd61d6

Please sign in to comment.