Skip to content

Commit 8a6938e

Browse files
committed
Plugin v8
- reduce-scatter and allgather API
1 parent a0432ad commit 8a6938e

9 files changed

+451
-23
lines changed

include/net.h

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC
2222

2323
typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
2424

25+
#include "net_v8.h"
2526
#include "net_v7.h"
2627
#include "net_v6.h"
2728
#include "net_v5.h"

include/net_device.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ typedef struct {
2626
int needsProxyProgress;
2727
} ncclNetDeviceHandle_v7_t;
2828

29-
typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_t;
29+
typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_v8_t;
30+
typedef ncclNetDeviceHandle_v8_t ncclNetDeviceHandle_t;
3031

3132
#endif

include/net_v7.h

-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ typedef struct {
2222
int netDeviceVersion; // Version number for network offload
2323
} ncclNetProperties_v7_t;
2424

25-
typedef ncclNetProperties_v7_t ncclNetProperties_t;
26-
2725
typedef struct {
2826
// Name of the network (mainly for logs)
2927
const char* name;

include/net_v8.h

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright (c) 2017-2023, NVIDIA CORPORATION. All rights reserved.
3+
*/
4+
5+
#ifndef NCCL_NET_V8_H_
6+
#define NCCL_NET_V8_H_
7+
#include "net_device.h"
8+
9+
typedef struct {
10+
char* name; // Used mostly for logging.
11+
char* pciPath; // Path to the PCI device in /sys.
12+
uint64_t guid; // Unique identifier for the NIC chip. Important for
13+
// cards with multiple PCI functions (Physical or virtual).
14+
int ptrSupport; // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF]
15+
int regIsGlobal; // regMr is not tied to a particular comm
16+
int speed; // Port speed in Mbps.
17+
int port; // Port number.
18+
float latency; // Network latency
19+
int maxComms; // Maximum number of comms we can create
20+
int maxRecvs; // Maximum number of grouped receives.
21+
ncclNetDeviceType netDeviceType; // Network offload type
22+
int netDeviceVersion; // Version number for network offload
23+
} ncclNetProperties_v8_t;
24+
25+
typedef ncclNetProperties_v8_t ncclNetProperties_t;
26+
27+
typedef struct {
28+
// Name of the network (mainly for logs)
29+
const char* name;
30+
// Initialize the network.
31+
ncclResult_t (*init)(ncclDebugLogger_t logFunction);
32+
// Return the number of adapters.
33+
ncclResult_t (*devices)(int* ndev);
34+
// Get various device properties.
35+
ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props);
36+
// Create a receiving object and provide a handle to connect to it. The
37+
// handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
38+
// between ranks to create a connection.
39+
ncclResult_t (*listen)(int dev, void* handle, void** listenComm);
40+
// Connect to a handle and return a sending comm object for that peer.
41+
// This call must not block for the connection to be established, and instead
42+
// should return successfully with sendComm == NULL with the expectation that
43+
// it will be called again until sendComm != NULL.
44+
// If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection
45+
ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm);
46+
// Finalize connection establishment after remote peer has called connect.
47+
// This call must not block for the connection to be established, and instead
48+
// should return successfully with recvComm == NULL with the expectation that
49+
// it will be called again until recvComm != NULL.
50+
// If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection
51+
ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm);
52+
// Register/Deregister memory. Comm can be either a sendComm or a recvComm.
53+
// Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA.
54+
ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle);
55+
/* DMA-BUF support */
56+
ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
57+
ncclResult_t (*deregMr)(void* comm, void* mhandle);
58+
// Asynchronous send to a peer.
59+
// May return request == NULL if the call cannot be performed (or would block)
60+
ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request);
61+
// Asynchronous recv from a peer.
62+
// May return request == NULL if the call cannot be performed (or would block)
63+
ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request);
64+
// Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
65+
// visible to the GPU
66+
ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request);
67+
// Test whether a request is complete. If size is not NULL, it returns the
68+
// number of bytes sent/received.
69+
ncclResult_t (*test)(void* request, int* done, int* sizes);
70+
// Close and free send/recv comm objects
71+
ncclResult_t (*closeSend)(void* sendComm);
72+
ncclResult_t (*closeRecv)(void* recvComm);
73+
ncclResult_t (*closeListen)(void* listenComm);
74+
75+
// Copy the given mhandle to a dptr in a format usable by this plugin's device code
76+
ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle);
77+
78+
// Notify the plugin that a recv has completed by the device
79+
ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request);
80+
} ncclNet_v8_t;
81+
82+
83+
typedef struct {
84+
void* mhandle;
85+
void* address;
86+
uint32_t size;
87+
} ncclNetSGE_v8_t;
88+
89+
typedef struct {
90+
// Name of the collective network (mainly for logs)
91+
const char* name;
92+
// Initialize the collective network.
93+
ncclResult_t (*init)(ncclDebugLogger_t logFunction);
94+
// Return the number of adapters capable of doing collective operations.
95+
// If ndev returns 0, all other functions might be set to NULL.
96+
ncclResult_t (*devices)(int* ndev);
97+
// Get various device properties.
98+
ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props);
99+
// Create a receiving object and provide a handle to connect to it. The
100+
// handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
101+
// between ranks to create connections.
102+
ncclResult_t (*listen)(int dev, void* handle, void** listenComm);
103+
// Create a group for collective operations. handles have been created
104+
// using listen() above. rank indicates caller's rank in the collective network.
105+
ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm);
106+
// Returns whether a reduction operation on a data type is supported.
107+
// 1 for supported, 0 otherwise.
108+
ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported);
109+
// Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA.
110+
ncclResult_t (*regMr)(void* collComm, void* data, size_t size, int type, void** mhandle);
111+
/* DMA-BUF support */
112+
ncclResult_t (*regMrDmaBuf)(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
113+
ncclResult_t (*deregMr)(void* collComm, void* mhandle);
114+
// Performs an asynchronous allreduce operation on the collective group.
115+
// May return request == NULL if the call cannot be performed (or would block).
116+
ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, int count,
117+
ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request);
118+
ncclResult_t (*iallgather)(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v8_t* recvParts,
119+
size_t bytesPerRank, size_t windowOffset, size_t windowBytes,
120+
void* sendMhandle, void** request);
121+
ncclResult_t (*ireducescatter)(void* collComm, int nSendParts, ncclNetSGE_v8_t* sendParts, void* recvData,
122+
size_t bytesPerRank, size_t windowOffset, size_t windowBytes,
123+
ncclDataType_t dataType, ncclRedOp_t redOp,
124+
void* recvMhandle, void** request);
125+
// Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
126+
// visible to the GPU
127+
ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request);
128+
// Test whether a request is complete. If size is not NULL, it returns the
129+
// number of bytes sent/received.
130+
ncclResult_t (*test)(void* request, int* done, int* size);
131+
// Close and free collective comm objects
132+
ncclResult_t (*closeColl)(void* collComm);
133+
ncclResult_t (*closeListen)(void* listenComm);
134+
} ncclCollNet_v8_t;
135+
136+
137+
#endif // end include guard

src/ib_plugin.c

+49-5
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,25 @@ ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props)
7878
return nccl_p2p_ib_get_properties(ncclIbDevs, dev, props);
7979
}
8080

81+
ncclResult_t ncclIbGetProperties_v7(int dev, ncclNetProperties_v7_t* props_v7)
82+
{
83+
ncclNetProperties_t props;
84+
ncclResult_t ret = nccl_p2p_ib_get_properties(ncclIbDevs, dev, &props);
85+
if (ret != ncclSuccess) return ret;
86+
props_v7->name = props.name;
87+
props_v7->pciPath = props.pciPath;
88+
props_v7->guid = props.guid;
89+
props_v7->ptrSupport = props.ptrSupport;
90+
props_v7->speed = props.speed;
91+
props_v7->latency = props.latency;
92+
props_v7->port = props.port;
93+
props_v7->maxComms = props.maxComms;
94+
props_v7->maxRecvs = props.maxRecvs;
95+
props_v7->netDeviceType = props.netDeviceType;
96+
props_v7->netDeviceVersion = props.netDeviceVersion;
97+
return ncclSuccess;
98+
}
99+
81100
ncclResult_t ncclIbGetProperties_v6(int dev, ncclNetProperties_v6_t* props_v6)
82101
{
83102
ncclNetProperties_t props;
@@ -693,7 +712,10 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui
693712
return res;
694713
}
695714

696-
ncclResult_t ncclIbRegMr(void* comm, void* data, int size, int type, void** mhandle) {
715+
ncclResult_t ncclIbRegMr(void* comm, void* data, size_t size, int type, void** mhandle) {
716+
return ncclIbRegMrDmaBuf(comm, data, size, type, 0ULL, -1, mhandle);
717+
}
718+
ncclResult_t ncclIbRegMr_v7(void* comm, void* data, int size, int type, void** mhandle) {
697719
return ncclIbRegMrDmaBuf(comm, data, (size_t)size, type, 0ULL, -1, mhandle);
698720
}
699721

@@ -1124,8 +1146,8 @@ ncclResult_t ncclIbCloseListen(void* listenComm) {
11241146
return ncclSuccess;
11251147
}
11261148

1127-
const ncclNet_v7_t ibPlugin_v7 = {
1128-
.name = "IBext_v7",
1149+
const ncclNet_v8_t ibPlugin_v8 = {
1150+
.name = "IBext_v8",
11291151
.init = ncclIbInit,
11301152
.devices = ncclIbDevices,
11311153
.getProperties = ncclIbGetProperties,
@@ -1146,6 +1168,28 @@ const ncclNet_v7_t ibPlugin_v7 = {
11461168
NULL /* irecvConsumed */
11471169
};
11481170

1171+
const ncclNet_v7_t ibPlugin_v7 = {
1172+
.name = "IBext_v7",
1173+
.init = ncclIbInit,
1174+
.devices = ncclIbDevices,
1175+
.getProperties = ncclIbGetProperties_v7,
1176+
.listen = ncclIbListen,
1177+
.connect = ncclIbConnect,
1178+
.accept = ncclIbAccept,
1179+
.regMr = ncclIbRegMr_v7,
1180+
.regMrDmaBuf = ncclIbRegMrDmaBuf,
1181+
.deregMr = ncclIbDeregMr,
1182+
.isend = ncclIbIsend,
1183+
.irecv = ncclIbIrecv,
1184+
.iflush = ncclIbIflush,
1185+
.test = ncclIbTest,
1186+
.closeSend = ncclIbCloseSend,
1187+
.closeRecv = ncclIbCloseRecv,
1188+
.closeListen = ncclIbCloseListen,
1189+
NULL /* getDeviceMr */,
1190+
NULL /* irecvConsumed */
1191+
};
1192+
11491193
const ncclNet_v6_t ibPlugin_v6 = {
11501194
.name = "IBext_v6",
11511195
.init = ncclIbInit,
@@ -1154,7 +1198,7 @@ const ncclNet_v6_t ibPlugin_v6 = {
11541198
.listen = ncclIbListen,
11551199
.connect = ncclIbConnect_v6,
11561200
.accept = ncclIbAccept_v6,
1157-
.regMr = ncclIbRegMr,
1201+
.regMr = ncclIbRegMr_v7,
11581202
.regMrDmaBuf = ncclIbRegMrDmaBuf,
11591203
.deregMr = ncclIbDeregMr,
11601204
.isend = ncclIbIsend,
@@ -1174,7 +1218,7 @@ const ncclNet_v5_t ibPlugin_v5 = {
11741218
.listen = ncclIbListen,
11751219
.connect = ncclIbConnect_v6,
11761220
.accept = ncclIbAccept_v6,
1177-
.regMr = ncclIbRegMr,
1221+
.regMr = ncclIbRegMr_v7,
11781222
.deregMr = ncclIbDeregMr,
11791223
.isend = ncclIbIsend,
11801224
.irecv = ncclIbIrecv,

src/p2p_plugin.c

+20
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515
#include "p2p_plugin.h"
1616

1717
#ifdef HAVE_UCX_PLUGIN
18+
extern ncclNet_v8_t ucxPlugin_v8;
1819
extern ncclNet_v7_t ucxPlugin_v7;
1920
extern ncclNet_v6_t ucxPlugin_v6;
2021
extern ncclNet_v5_t ucxPlugin_v5;
22+
extern ncclNet_v8_t ucxRmaPlugin_v8;
2123
extern ncclNet_v7_t ucxRmaPlugin_v7;
2224
extern ncclNet_v6_t ucxRmaPlugin_v6;
2325
extern ncclNet_v5_t ucxRmaPlugin_v5;
2426
#endif
2527

28+
extern ncclNet_v8_t ibPlugin_v8;
2629
extern ncclNet_v7_t ibPlugin_v7;
2730
extern ncclNet_v6_t ibPlugin_v6;
2831
extern ncclNet_v5_t ibPlugin_v5;
@@ -40,10 +43,16 @@ extern int ncclIbRelaxedOrderingEnabled;
4043
NCCL_PARAM(SharpMaxComms, "SHARP_MAX_COMMS", 1);
4144
NCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2);
4245

46+
ncclResult_t pluginInit_v8(ncclDebugLogger_t logFunction);
4347
ncclResult_t pluginInit_v7(ncclDebugLogger_t logFunction);
4448
ncclResult_t pluginInit_v6(ncclDebugLogger_t logFunction);
4549
ncclResult_t pluginInit_v5(ncclDebugLogger_t logFunction);
4650

51+
ncclNet_v8_t ncclNetPlugin_v8 = {
52+
"NCCL RDMA Plugin v8",
53+
pluginInit_v8,
54+
};
55+
4756
ncclNet_v7_t ncclNetPlugin_v7 = {
4857
"NCCL RDMA Plugin v7",
4958
pluginInit_v7,
@@ -85,17 +94,20 @@ static void pluginSetup()
8594
switch (p2p_plugin) {
8695
#ifdef HAVE_UCX_PLUGIN
8796
case NCCL_P2P_UCX:
97+
ncclNetPlugin_v8 = ucxPlugin_v8;
8898
ncclNetPlugin_v7 = ucxPlugin_v7;
8999
ncclNetPlugin_v6 = ucxPlugin_v6;
90100
ncclNetPlugin_v5 = ucxPlugin_v5;
91101
break;
92102
case NCCL_P2P_UCX_RMA:
103+
ncclNetPlugin_v8 = ucxRmaPlugin_v8;
93104
ncclNetPlugin_v7 = ucxRmaPlugin_v7;
94105
ncclNetPlugin_v6 = ucxRmaPlugin_v6;
95106
ncclNetPlugin_v5 = ucxRmaPlugin_v5;
96107
break;
97108
#endif
98109
default:
110+
ncclNetPlugin_v8 = ibPlugin_v8;
99111
ncclNetPlugin_v7 = ibPlugin_v7;
100112
ncclNetPlugin_v6 = ibPlugin_v6;
101113
ncclNetPlugin_v5 = ibPlugin_v5;
@@ -104,6 +116,13 @@ static void pluginSetup()
104116

105117
}
106118

119+
ncclResult_t pluginInit_v8(ncclDebugLogger_t logFunction) {
120+
pluginLogFunction = logFunction;
121+
pluginSetup();
122+
INFO(NCCL_INIT|NCCL_NET, "P2P plugin %s", ncclNetPlugin_v8.name);
123+
return ncclNetPlugin_v8.init(logFunction);
124+
}
125+
107126
ncclResult_t pluginInit_v7(ncclDebugLogger_t logFunction) {
108127
pluginLogFunction = logFunction;
109128
pluginSetup();
@@ -176,6 +195,7 @@ ncclResult_t nccl_p2p_ib_get_properties(nccl_ib_dev_t *devs, int dev, ncclNetPro
176195
props->ptrSupport |= NCCL_PTR_CUDA; // GDR support via nv_peermem
177196
INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (nvidia-peermem) enabled for HCA %d '%s", dev, devs[dev].devName);
178197
}
198+
props->regIsGlobal = 1;
179199
if (p2p_plugin == NCCL_P2P_IB && nccl_p2p_dmabuf_support(dev) == ncclSuccess) {
180200
props->ptrSupport |= NCCL_PTR_DMABUF; // GDR support via DMA-BUF
181201
INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (DMABUF) enabled for HCA %d '%s", dev, devs[dev].devName);

0 commit comments

Comments
 (0)