Skip to content

Commit 7b9afc6

Browse files
committed
Sync IB changes with NCCL v2.21.5-1
1 parent 9cdd572 commit 7b9afc6

File tree

3 files changed

+237
-21
lines changed

3 files changed

+237
-21
lines changed

include/p2p_plugin.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ struct ncclIbRequest {
7777
struct ncclIbGidInfo {
7878
uint8_t link_layer;
7979
union ibv_gid localGid;
80+
int32_t localGidIndex;
8081
};
8182

8283
typedef struct ncclIbNetCommDevBase {
8384
int ibDevN;
8485
struct ibv_pd* pd;
8586
struct ibv_cq* cq;
86-
uint64_t pad[1];
87+
uint64_t pad[2];
8788
struct ncclIbGidInfo gidInfo;
8889
} ncclIbNetCommDevBase;
8990

src/ib_plugin.c

+225-13
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ static int ncclNIbDevs = -1;
3535
pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER;
3636
int ncclIbRelaxedOrderingEnabled = 0;
3737

38-
NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", 0);
38+
NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1);
39+
NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2);
3940
NCCL_PARAM(IbIsGlobal, "IB_IS_GLOBAL", 0);
4041
NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18);
4142
NCCL_PARAM(IbRetryCnt, "IB_RETRY_CNT", 7);
@@ -63,6 +64,211 @@ int ncclIbRelaxedOrderingCapable(void) {
6364
return 1;
6465
}
6566

67+
static sa_family_t envIbAddrFamily(void) {
68+
sa_family_t family = AF_INET;
69+
const char* env = ncclGetEnv("NCCL_IB_ADDR_FAMILY");
70+
if (env == NULL || strlen(env) == 0) {
71+
return family;
72+
}
73+
74+
INFO(NCCL_ENV, "NCCL_IB_ADDR_FAMILY set by environment to %s", env);
75+
76+
if (strcmp(env, "AF_INET") == 0) {
77+
family = AF_INET;
78+
} else if (strcmp(env, "AF_INET6") == 0) {
79+
family = AF_INET6;
80+
}
81+
82+
return family;
83+
}
84+
85+
static void* envIbAddrRange(sa_family_t af, int* mask) {
86+
*mask = 0;
87+
static struct in_addr addr;
88+
static struct in6_addr addr6;
89+
void *ret = (af == AF_INET) ? (void *)&addr : (void *)&addr6;
90+
91+
const char* env = ncclGetEnv("NCCL_IB_ADDR_RANGE");
92+
if (NULL == env || strlen(env) == 0) {
93+
return NULL;
94+
}
95+
96+
INFO(NCCL_ENV, "NCCL_IB_ADDR_RANGE set by environment to %s", env);
97+
98+
char addrString[128] = { 0 };
99+
snprintf(addrString, 128, "%s", env);
100+
char *addrStrPtr = addrString;
101+
char *maskStrPtr = strstr(addrString, "/") + 1;
102+
if (NULL == maskStrPtr) {
103+
return NULL;
104+
}
105+
*(maskStrPtr - 1) = '\0';
106+
107+
if (inet_pton(af, addrStrPtr, ret) == 0) {
108+
WARN("NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6");
109+
return NULL;
110+
}
111+
112+
*mask = (int)strtol(maskStrPtr, NULL, 10);
113+
if (af == AF_INET && *mask > 32) {
114+
WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6");
115+
*mask = 0;
116+
ret = NULL;
117+
} else if (af == AF_INET6 && *mask > 128) {
118+
WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6");
119+
*mask = 0;
120+
ret = NULL;
121+
}
122+
123+
return ret;
124+
}
125+
126+
static sa_family_t getGidAddrFamily(union ibv_gid* gid) {
127+
const struct in6_addr *a = (struct in6_addr *)gid->raw;
128+
bool isIpV4Mapped = ((a->s6_addr32[0] | a->s6_addr32[1]) | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL;
129+
bool isIpV4MappedMulticast = (a->s6_addr32[0] == htonl(0xff0e0000) && ((a->s6_addr32[1] | (a->s6_addr32[2] ^ htonl(0x0000ffff))) == 0UL));
130+
return (isIpV4Mapped || isIpV4MappedMulticast) ? AF_INET : AF_INET6;
131+
}
132+
133+
static bool matchGidAddrPrefix(sa_family_t af, void* prefix, int prefixlen, union ibv_gid* gid) {
134+
struct in_addr *base = NULL;
135+
struct in6_addr *base6 = NULL;
136+
struct in6_addr *addr6 = NULL;;
137+
if (af == AF_INET) {
138+
base = (struct in_addr *)prefix;
139+
} else {
140+
base6 = (struct in6_addr *)prefix;
141+
}
142+
addr6 = (struct in6_addr *)gid->raw;
143+
144+
#define NETMASK(bits) (htonl(0xffffffff ^ ((1 << (32 - bits)) - 1)))
145+
146+
int i = 0;
147+
while (prefixlen > 0 && i < 4) {
148+
if (af == AF_INET) {
149+
int mask = NETMASK(prefixlen);
150+
if ((base->s_addr & mask) ^ (addr6->s6_addr32[3] & mask)) {
151+
break;
152+
}
153+
prefixlen = 0;
154+
break;
155+
} else {
156+
if (prefixlen >= 32) {
157+
if (base6->s6_addr32[i] ^ addr6->s6_addr32[i]) {
158+
break;
159+
}
160+
prefixlen -= 32;
161+
++i;
162+
} else {
163+
int mask = NETMASK(prefixlen);
164+
if ((base6->s6_addr32[i] & mask) ^ (addr6->s6_addr32[i] & mask)) {
165+
break;
166+
}
167+
prefixlen = 0;
168+
}
169+
}
170+
}
171+
172+
return (prefixlen == 0) ? true : false;
173+
}
174+
175+
static bool configuredGid(union ibv_gid* gid) {
176+
const struct in6_addr *a = (struct in6_addr *)gid->raw;
177+
int trailer = (a->s6_addr32[1] | a->s6_addr32[2] | a->s6_addr32[3]);
178+
if (((a->s6_addr32[0] | trailer) == 0UL) || ((a->s6_addr32[0] == htonl(0xfe800000)) && (trailer == 0UL))) {
179+
return false;
180+
}
181+
return true;
182+
}
183+
184+
static bool linkLocalGid(union ibv_gid* gid) {
185+
const struct in6_addr *a = (struct in6_addr *)gid->raw;
186+
if (a->s6_addr32[0] == htonl(0xfe800000) && a->s6_addr32[1] == 0UL) {
187+
return true;
188+
}
189+
return false;
190+
}
191+
192+
static bool validGid(union ibv_gid* gid) {
193+
return (configuredGid(gid) && !linkLocalGid(gid));
194+
}
195+
196+
static ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) {
197+
char gidRoceVerStr[16] = { 0 };
198+
char roceTypePath[PATH_MAX] = { 0 };
199+
sprintf(roceTypePath, "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex);
200+
201+
int fd = open(roceTypePath, O_RDONLY);
202+
if (fd == -1) {
203+
return ncclSystemError;
204+
}
205+
int ret = read(fd, gidRoceVerStr, 15);
206+
close(fd);
207+
208+
if (ret == -1) {
209+
return ncclSystemError;
210+
}
211+
212+
if (strlen(gidRoceVerStr)) {
213+
if (strncmp(gidRoceVerStr, "IB/RoCE v1", strlen("IB/RoCE v1")) == 0 || strncmp(gidRoceVerStr, "RoCE v1", strlen("RoCE v1")) == 0) {
214+
*version = 1;
215+
} else if (strncmp(gidRoceVerStr, "RoCE v2", strlen("RoCE v2")) == 0) {
216+
*version = 2;
217+
}
218+
}
219+
220+
return ncclSuccess;
221+
}
222+
223+
static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t portNum, sa_family_t af, void* prefix, int prefixlen, int roceVer, int gidIndexCandidate, int* gidIndex) {
224+
union ibv_gid gid, gidCandidate;
225+
NCCLCHECK(wrap_ibv_query_gid(context, portNum, *gidIndex, &gid));
226+
NCCLCHECK(wrap_ibv_query_gid(context, portNum, gidIndexCandidate, &gidCandidate));
227+
228+
sa_family_t usrFam = af;
229+
sa_family_t gidFam = getGidAddrFamily(&gid);
230+
sa_family_t gidCandidateFam = getGidAddrFamily(&gidCandidate);
231+
bool gidCandidateMatchSubnet = matchGidAddrPrefix(usrFam, prefix, prefixlen, &gidCandidate);
232+
233+
if (gidCandidateFam != gidFam && gidCandidateFam == usrFam && gidCandidateMatchSubnet) {
234+
*gidIndex = gidIndexCandidate;
235+
} else {
236+
if (gidCandidateFam != usrFam || !validGid(&gidCandidate) || !gidCandidateMatchSubnet) {
237+
return ncclSuccess;
238+
}
239+
int usrRoceVer = roceVer;
240+
int gidRoceVerNum, gidRoceVerNumCandidate;
241+
const char* deviceName = wrap_ibv_get_device_name(context->device);
242+
NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, *gidIndex, &gidRoceVerNum));
243+
NCCLCHECK(ncclIbRoceGetVersionNum(deviceName, portNum, gidIndexCandidate, &gidRoceVerNumCandidate));
244+
if ((gidRoceVerNum != gidRoceVerNumCandidate || !validGid(&gid)) && gidRoceVerNumCandidate == usrRoceVer) {
245+
*gidIndex = gidIndexCandidate;
246+
}
247+
}
248+
249+
return ncclSuccess;
250+
}
251+
252+
static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, int gidTblLen, int *gidIndex) {
253+
*gidIndex = ncclParamIbGidIndex();
254+
if (*gidIndex >= 0) {
255+
return ncclSuccess;
256+
}
257+
258+
sa_family_t userAddrFamily = envIbAddrFamily();
259+
int userRoceVersion = ncclParamIbRoceVersionNum();
260+
int prefixlen;
261+
void *prefix = envIbAddrRange(userAddrFamily, &prefixlen);
262+
263+
*gidIndex = 0;
264+
for (int gidIndexNext = 1; gidIndexNext < gidTblLen; ++gidIndexNext) {
265+
NCCLCHECK(ncclUpdateGidIndex(context, portNum, userAddrFamily, prefix, prefixlen, userRoceVersion, gidIndexNext, gidIndex));
266+
}
267+
268+
return ncclSuccess;
269+
}
270+
271+
66272
NCCL_PARAM(IbDisable, "IBEXT_DISABLE", 0);
67273
NCCL_PARAM(IbMergeVfs, "IB_MERGE_VFS", 1);
68274
NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1);
@@ -373,7 +579,7 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
373579
return ncclSuccess;
374580
}
375581

376-
ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
582+
ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
377583
struct ibv_qp_attr qpAttr;
378584
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
379585
qpAttr.qp_state = IBV_QPS_RTR;
@@ -392,7 +598,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbD
392598
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
393599
qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid;
394600
qpAttr.ah_attr.grh.flow_label = 0;
395-
qpAttr.ah_attr.grh.sgid_index = ncclParamIbGidIndex();
601+
qpAttr.ah_attr.grh.sgid_index = sGidIndex;
396602
qpAttr.ah_attr.grh.hop_limit = 255;
397603
qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc();
398604
}
@@ -514,7 +720,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
514720
);
515721

516722
if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) {
517-
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &commDev->base.gidInfo.localGid));
723+
724+
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &commDev->base.gidInfo.localGidIndex));
725+
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid));
518726
devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix;
519727
devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id;
520728
}
@@ -532,9 +740,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
532740
// Print just the QPs for this dev
533741
if (comm->base.qps[q].devIndex == i)
534742
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x",
535-
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
536-
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, ncclParamIbGidIndex(),
537-
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
743+
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
744+
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex,
745+
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
538746
}
539747
}
540748
}
@@ -602,12 +810,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
602810

603811
// Assign per-QP remDev
604812
comm->base.qps[q].remDevIdx = remQpInfo->devIndex;
813+
int devIndex = comm->base.qps[q].devIndex;
814+
ncclIbSendCommDev* commDev = comm->devs + devIndex;
815+
uint8_t gidIndex = commDev->base.gidInfo.localGidIndex;
605816

606817
struct ibv_qp* qp = comm->base.qps[q].qp;
607818
if (remQpInfo->ece_supported && remQpInfo->ece_supported)
608819
NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported));
609820

610-
NCCLCHECK(ncclIbRtrQp(qp, remQpInfo->qpn, remDevInfo));
821+
NCCLCHECK(ncclIbRtrQp(qp, gidIndex, remQpInfo->qpn, remDevInfo));
611822
NCCLCHECK(ncclIbRtsQp(qp));
612823
}
613824

@@ -707,7 +918,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
707918
ibDevN = mergedDev->devs[i];
708919
NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base));
709920
ibDev = ncclIbDevs + ibDevN;
710-
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, ncclParamIbGidIndex(), &rCommDev->base.gidInfo.localGid));
921+
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &rCommDev->base.gidInfo.localGidIndex));
922+
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid));
711923
}
712924

713925
// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
@@ -745,7 +957,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
745957
if (meta.qpInfo[q].ece_supported)
746958
NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported));
747959
}
748-
NCCLCHECK(ncclIbRtrQp(qp->qp, remMeta.qpInfo[q].qpn, remDevInfo));
960+
NCCLCHECK(ncclIbRtrQp(qp->qp, rCommDev->base.gidInfo.localGidIndex, remMeta.qpInfo[q].qpn, remDevInfo));
749961
NCCLCHECK(ncclIbRtsQp(qp->qp));
750962
}
751963

@@ -783,8 +995,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
783995
#endif
784996
);
785997
devInfo.mtu = ibDev->portAttr.active_mtu;
786-
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
787-
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
998+
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->base.gidInfo.localGidIndex, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
999+
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
7881000
}
7891001

7901002
// Fill Handle
@@ -1431,7 +1643,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
14311643
return ncclInternalError;
14321644
}
14331645
if (req->nreqs == 1) {
1434-
req->recv.sizes[0] += wc->imm_data;
1646+
req->recv.sizes[0] = wc->imm_data;
14351647
}
14361648
}
14371649
req->events[i]--;

src/p2p_plugin.c

+10-7
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ int ncclIbFindMatchingDev(int dev) {
276276
ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction)
277277
{
278278
int ncclNIbDevs = *num_devs;
279-
279+
ncclResult_t ret;
280280
pluginLogFunction = logFunction;
281281
if (ncclNIbDevs == -1) {
282282
pthread_mutex_lock(&nccl_p2p_lock);
@@ -287,7 +287,8 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
287287
ncclNSharpDevs = 0;
288288
if (ncclFindInterfaces(ncclIbIfName, ncclIbIfAddr, MAX_IF_NAME_SIZE, 1) != 1) {
289289
WARN("NET/IB : No IP interface found.");
290-
return ncclInternalError;
290+
ret = ncclInternalError;
291+
goto fail;
291292
}
292293

293294
// Detect IB cards
@@ -302,7 +303,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
302303
if (searchExact) userIbEnv++;
303304
int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS);
304305

305-
if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) return ncclInternalError;
306+
if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; }
306307

307308
for (int d=0; d<nIbDevs; d++) {
308309
struct ibv_context * context;
@@ -314,7 +315,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
314315
struct ibv_device_attr devAttr;
315316
if (ncclSuccess != wrap_ibv_query_device(context, &devAttr)) {
316317
WARN("NET/IB : Unable to query device %s", devices[d]->name);
317-
if (ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; }
318+
if (ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; }
318319
continue;
319320
}
320321
for (int port_num = 1; port_num <= devAttr.phys_port_cnt; port_num++) {
@@ -394,9 +395,9 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
394395
ncclNIbDevs++;
395396
nPorts++;
396397
}
397-
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; }
398+
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; }
398399
}
399-
if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { return ncclInternalError; };
400+
if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { ret = ncclInternalError; goto fail; };
400401
}
401402
if (ncclNIbDevs == 0) {
402403
INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found.");
@@ -444,7 +445,9 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
444445
pthread_mutex_unlock(&nccl_p2p_lock);
445446
}
446447
return ncclSuccess;
447-
448+
fail:
449+
pthread_mutex_unlock(&nccl_p2p_lock);
450+
return ret;
448451
}
449452

450453
ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, char** path, int* real_port)

0 commit comments

Comments
 (0)