@@ -35,7 +35,8 @@ static int ncclNIbDevs = -1;
35
35
pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER ;
36
36
int ncclIbRelaxedOrderingEnabled = 0 ;
37
37
38
- NCCL_PARAM (IbGidIndex , "IB_GID_INDEX" , 0 );
38
+ NCCL_PARAM (IbGidIndex , "IB_GID_INDEX" , -1 );
39
+ NCCL_PARAM (IbRoceVersionNum , "IB_ROCE_VERSION_NUM" , 2 );
39
40
NCCL_PARAM (IbIsGlobal , "IB_IS_GLOBAL" , 0 );
40
41
NCCL_PARAM (IbTimeout , "IB_TIMEOUT" , 18 );
41
42
NCCL_PARAM (IbRetryCnt , "IB_RETRY_CNT" , 7 );
@@ -63,6 +64,211 @@ int ncclIbRelaxedOrderingCapable(void) {
63
64
return 1 ;
64
65
}
65
66
67
+ static sa_family_t envIbAddrFamily (void ) {
68
+ sa_family_t family = AF_INET ;
69
+ const char * env = ncclGetEnv ("NCCL_IB_ADDR_FAMILY" );
70
+ if (env == NULL || strlen (env ) == 0 ) {
71
+ return family ;
72
+ }
73
+
74
+ INFO (NCCL_ENV , "NCCL_IB_ADDR_FAMILY set by environment to %s" , env );
75
+
76
+ if (strcmp (env , "AF_INET" ) == 0 ) {
77
+ family = AF_INET ;
78
+ } else if (strcmp (env , "AF_INET6" ) == 0 ) {
79
+ family = AF_INET6 ;
80
+ }
81
+
82
+ return family ;
83
+ }
84
+
85
+ static void * envIbAddrRange (sa_family_t af , int * mask ) {
86
+ * mask = 0 ;
87
+ static struct in_addr addr ;
88
+ static struct in6_addr addr6 ;
89
+ void * ret = (af == AF_INET ) ? (void * )& addr : (void * )& addr6 ;
90
+
91
+ const char * env = ncclGetEnv ("NCCL_IB_ADDR_RANGE" );
92
+ if (NULL == env || strlen (env ) == 0 ) {
93
+ return NULL ;
94
+ }
95
+
96
+ INFO (NCCL_ENV , "NCCL_IB_ADDR_RANGE set by environment to %s" , env );
97
+
98
+ char addrString [128 ] = { 0 };
99
+ snprintf (addrString , 128 , "%s" , env );
100
+ char * addrStrPtr = addrString ;
101
+ char * maskStrPtr = strstr (addrString , "/" ) + 1 ;
102
+ if (NULL == maskStrPtr ) {
103
+ return NULL ;
104
+ }
105
+ * (maskStrPtr - 1 ) = '\0' ;
106
+
107
+ if (inet_pton (af , addrStrPtr , ret ) == 0 ) {
108
+ WARN ("NET/IB: Ip address '%s' is invalid for family %s, ignoring address" , addrStrPtr , (af == AF_INET ) ? "AF_INET" : "AF_INET6" );
109
+ return NULL ;
110
+ }
111
+
112
+ * mask = (int )strtol (maskStrPtr , NULL , 10 );
113
+ if (af == AF_INET && * mask > 32 ) {
114
+ WARN ("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask" , * mask , (af == AF_INET ) ? "AF_INET" : "AF_INET6" );
115
+ * mask = 0 ;
116
+ ret = NULL ;
117
+ } else if (af == AF_INET6 && * mask > 128 ) {
118
+ WARN ("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask" , * mask , (af == AF_INET ) ? "AF_INET" : "AF_INET6" );
119
+ * mask = 0 ;
120
+ ret = NULL ;
121
+ }
122
+
123
+ return ret ;
124
+ }
125
+
126
+ static sa_family_t getGidAddrFamily (union ibv_gid * gid ) {
127
+ const struct in6_addr * a = (struct in6_addr * )gid -> raw ;
128
+ bool isIpV4Mapped = ((a -> s6_addr32 [0 ] | a -> s6_addr32 [1 ]) | (a -> s6_addr32 [2 ] ^ htonl (0x0000ffff ))) == 0UL ;
129
+ bool isIpV4MappedMulticast = (a -> s6_addr32 [0 ] == htonl (0xff0e0000 ) && ((a -> s6_addr32 [1 ] | (a -> s6_addr32 [2 ] ^ htonl (0x0000ffff ))) == 0UL ));
130
+ return (isIpV4Mapped || isIpV4MappedMulticast ) ? AF_INET : AF_INET6 ;
131
+ }
132
+
133
+ static bool matchGidAddrPrefix (sa_family_t af , void * prefix , int prefixlen , union ibv_gid * gid ) {
134
+ struct in_addr * base = NULL ;
135
+ struct in6_addr * base6 = NULL ;
136
+ struct in6_addr * addr6 = NULL ;;
137
+ if (af == AF_INET ) {
138
+ base = (struct in_addr * )prefix ;
139
+ } else {
140
+ base6 = (struct in6_addr * )prefix ;
141
+ }
142
+ addr6 = (struct in6_addr * )gid -> raw ;
143
+
144
+ #define NETMASK (bits ) (htonl(0xffffffff ^ ((1 << (32 - bits)) - 1)))
145
+
146
+ int i = 0 ;
147
+ while (prefixlen > 0 && i < 4 ) {
148
+ if (af == AF_INET ) {
149
+ int mask = NETMASK (prefixlen );
150
+ if ((base -> s_addr & mask ) ^ (addr6 -> s6_addr32 [3 ] & mask )) {
151
+ break ;
152
+ }
153
+ prefixlen = 0 ;
154
+ break ;
155
+ } else {
156
+ if (prefixlen >= 32 ) {
157
+ if (base6 -> s6_addr32 [i ] ^ addr6 -> s6_addr32 [i ]) {
158
+ break ;
159
+ }
160
+ prefixlen -= 32 ;
161
+ ++ i ;
162
+ } else {
163
+ int mask = NETMASK (prefixlen );
164
+ if ((base6 -> s6_addr32 [i ] & mask ) ^ (addr6 -> s6_addr32 [i ] & mask )) {
165
+ break ;
166
+ }
167
+ prefixlen = 0 ;
168
+ }
169
+ }
170
+ }
171
+
172
+ return (prefixlen == 0 ) ? true : false;
173
+ }
174
+
175
+ static bool configuredGid (union ibv_gid * gid ) {
176
+ const struct in6_addr * a = (struct in6_addr * )gid -> raw ;
177
+ int trailer = (a -> s6_addr32 [1 ] | a -> s6_addr32 [2 ] | a -> s6_addr32 [3 ]);
178
+ if (((a -> s6_addr32 [0 ] | trailer ) == 0UL ) || ((a -> s6_addr32 [0 ] == htonl (0xfe800000 )) && (trailer == 0UL ))) {
179
+ return false;
180
+ }
181
+ return true;
182
+ }
183
+
184
+ static bool linkLocalGid (union ibv_gid * gid ) {
185
+ const struct in6_addr * a = (struct in6_addr * )gid -> raw ;
186
+ if (a -> s6_addr32 [0 ] == htonl (0xfe800000 ) && a -> s6_addr32 [1 ] == 0UL ) {
187
+ return true;
188
+ }
189
+ return false;
190
+ }
191
+
192
+ static bool validGid (union ibv_gid * gid ) {
193
+ return (configuredGid (gid ) && !linkLocalGid (gid ));
194
+ }
195
+
196
+ static ncclResult_t ncclIbRoceGetVersionNum (const char * deviceName , int portNum , int gidIndex , int * version ) {
197
+ char gidRoceVerStr [16 ] = { 0 };
198
+ char roceTypePath [PATH_MAX ] = { 0 };
199
+ sprintf (roceTypePath , "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d" , deviceName , portNum , gidIndex );
200
+
201
+ int fd = open (roceTypePath , O_RDONLY );
202
+ if (fd == -1 ) {
203
+ return ncclSystemError ;
204
+ }
205
+ int ret = read (fd , gidRoceVerStr , 15 );
206
+ close (fd );
207
+
208
+ if (ret == -1 ) {
209
+ return ncclSystemError ;
210
+ }
211
+
212
+ if (strlen (gidRoceVerStr )) {
213
+ if (strncmp (gidRoceVerStr , "IB/RoCE v1" , strlen ("IB/RoCE v1" )) == 0 || strncmp (gidRoceVerStr , "RoCE v1" , strlen ("RoCE v1" )) == 0 ) {
214
+ * version = 1 ;
215
+ } else if (strncmp (gidRoceVerStr , "RoCE v2" , strlen ("RoCE v2" )) == 0 ) {
216
+ * version = 2 ;
217
+ }
218
+ }
219
+
220
+ return ncclSuccess ;
221
+ }
222
+
223
+ static ncclResult_t ncclUpdateGidIndex (struct ibv_context * context , uint8_t portNum , sa_family_t af , void * prefix , int prefixlen , int roceVer , int gidIndexCandidate , int * gidIndex ) {
224
+ union ibv_gid gid , gidCandidate ;
225
+ NCCLCHECK (wrap_ibv_query_gid (context , portNum , * gidIndex , & gid ));
226
+ NCCLCHECK (wrap_ibv_query_gid (context , portNum , gidIndexCandidate , & gidCandidate ));
227
+
228
+ sa_family_t usrFam = af ;
229
+ sa_family_t gidFam = getGidAddrFamily (& gid );
230
+ sa_family_t gidCandidateFam = getGidAddrFamily (& gidCandidate );
231
+ bool gidCandidateMatchSubnet = matchGidAddrPrefix (usrFam , prefix , prefixlen , & gidCandidate );
232
+
233
+ if (gidCandidateFam != gidFam && gidCandidateFam == usrFam && gidCandidateMatchSubnet ) {
234
+ * gidIndex = gidIndexCandidate ;
235
+ } else {
236
+ if (gidCandidateFam != usrFam || !validGid (& gidCandidate ) || !gidCandidateMatchSubnet ) {
237
+ return ncclSuccess ;
238
+ }
239
+ int usrRoceVer = roceVer ;
240
+ int gidRoceVerNum , gidRoceVerNumCandidate ;
241
+ const char * deviceName = wrap_ibv_get_device_name (context -> device );
242
+ NCCLCHECK (ncclIbRoceGetVersionNum (deviceName , portNum , * gidIndex , & gidRoceVerNum ));
243
+ NCCLCHECK (ncclIbRoceGetVersionNum (deviceName , portNum , gidIndexCandidate , & gidRoceVerNumCandidate ));
244
+ if ((gidRoceVerNum != gidRoceVerNumCandidate || !validGid (& gid )) && gidRoceVerNumCandidate == usrRoceVer ) {
245
+ * gidIndex = gidIndexCandidate ;
246
+ }
247
+ }
248
+
249
+ return ncclSuccess ;
250
+ }
251
+
252
+ static ncclResult_t ncclIbGetGidIndex (struct ibv_context * context , uint8_t portNum , int gidTblLen , int * gidIndex ) {
253
+ * gidIndex = ncclParamIbGidIndex ();
254
+ if (* gidIndex >= 0 ) {
255
+ return ncclSuccess ;
256
+ }
257
+
258
+ sa_family_t userAddrFamily = envIbAddrFamily ();
259
+ int userRoceVersion = ncclParamIbRoceVersionNum ();
260
+ int prefixlen ;
261
+ void * prefix = envIbAddrRange (userAddrFamily , & prefixlen );
262
+
263
+ * gidIndex = 0 ;
264
+ for (int gidIndexNext = 1 ; gidIndexNext < gidTblLen ; ++ gidIndexNext ) {
265
+ NCCLCHECK (ncclUpdateGidIndex (context , portNum , userAddrFamily , prefix , prefixlen , userRoceVersion , gidIndexNext , gidIndex ));
266
+ }
267
+
268
+ return ncclSuccess ;
269
+ }
270
+
271
+
66
272
NCCL_PARAM (IbDisable , "IBEXT_DISABLE" , 0 );
67
273
NCCL_PARAM (IbMergeVfs , "IB_MERGE_VFS" , 1 );
68
274
NCCL_PARAM (IbMergeNics , "IB_MERGE_NICS" , 1 );
@@ -373,7 +579,7 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
373
579
return ncclSuccess ;
374
580
}
375
581
376
- ncclResult_t ncclIbRtrQp (struct ibv_qp * qp , uint32_t dest_qp_num , struct ncclIbDevInfo * info ) {
582
+ ncclResult_t ncclIbRtrQp (struct ibv_qp * qp , uint8_t sGidIndex , uint32_t dest_qp_num , struct ncclIbDevInfo * info ) {
377
583
struct ibv_qp_attr qpAttr ;
378
584
memset (& qpAttr , 0 , sizeof (struct ibv_qp_attr ));
379
585
qpAttr .qp_state = IBV_QPS_RTR ;
@@ -392,7 +598,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint32_t dest_qp_num, struct ncclIbD
392
598
qpAttr .ah_attr .grh .dgid .global .subnet_prefix = info -> spn ;
393
599
qpAttr .ah_attr .grh .dgid .global .interface_id = info -> iid ;
394
600
qpAttr .ah_attr .grh .flow_label = 0 ;
395
- qpAttr .ah_attr .grh .sgid_index = ncclParamIbGidIndex () ;
601
+ qpAttr .ah_attr .grh .sgid_index = sGidIndex ;
396
602
qpAttr .ah_attr .grh .hop_limit = 255 ;
397
603
qpAttr .ah_attr .grh .traffic_class = ncclParamIbTc ();
398
604
}
@@ -514,7 +720,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
514
720
);
515
721
516
722
if (devInfo -> link_layer == IBV_LINK_LAYER_ETHERNET || devInfo -> is_global ) {
517
- NCCLCHECK (wrap_ibv_query_gid (ibDev -> context , ibDev -> portNum , ncclParamIbGidIndex (), & commDev -> base .gidInfo .localGid ));
723
+
724
+ NCCLCHECK (ncclIbGetGidIndex (ibDev -> context , ibDev -> portNum , ibDev -> portAttr .gid_tbl_len , & commDev -> base .gidInfo .localGidIndex ));
725
+ NCCLCHECK (wrap_ibv_query_gid (ibDev -> context , ibDev -> portNum , commDev -> base .gidInfo .localGidIndex , & commDev -> base .gidInfo .localGid ));
518
726
devInfo -> spn = commDev -> base .gidInfo .localGid .global .subnet_prefix ;
519
727
devInfo -> iid = commDev -> base .gidInfo .localGid .global .interface_id ;
520
728
}
@@ -532,9 +740,9 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
532
740
// Print just the QPs for this dev
533
741
if (comm -> base .qps [q ].devIndex == i )
534
742
INFO (NCCL_NET ,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x" ,
535
- comm -> base .ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev" , dev ,
536
- commDev -> base .ibDevN , ibDev -> portNum , meta .qpInfo [q ].qpn , devInfo -> mtu , meta .qpInfo [q ].ece_supported , meta .qpInfo [q ].ece .vendor_id , meta .qpInfo [q ].ece .options , meta .qpInfo [q ].ece .comp_mask , ncclParamIbGidIndex () ,
537
- devInfo -> spn , devInfo -> iid , devInfo -> fifoRkey , commDev -> fifoMr -> lkey );
743
+ comm -> base .ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev" , dev ,
744
+ commDev -> base .ibDevN , ibDev -> portNum , meta .qpInfo [q ].qpn , devInfo -> mtu , meta .qpInfo [q ].ece_supported , meta .qpInfo [q ].ece .vendor_id , meta .qpInfo [q ].ece .options , meta .qpInfo [q ].ece .comp_mask , ( int64_t ) commDev -> base . gidInfo . localGidIndex ,
745
+ devInfo -> spn , devInfo -> iid , devInfo -> fifoRkey , commDev -> fifoMr -> lkey );
538
746
}
539
747
}
540
748
}
@@ -602,12 +810,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
602
810
603
811
// Assign per-QP remDev
604
812
comm -> base .qps [q ].remDevIdx = remQpInfo -> devIndex ;
813
+ int devIndex = comm -> base .qps [q ].devIndex ;
814
+ ncclIbSendCommDev * commDev = comm -> devs + devIndex ;
815
+ uint8_t gidIndex = commDev -> base .gidInfo .localGidIndex ;
605
816
606
817
struct ibv_qp * qp = comm -> base .qps [q ].qp ;
607
818
if (remQpInfo -> ece_supported && remQpInfo -> ece_supported )
608
819
NCCLCHECK (wrap_ibv_set_ece (qp , & remQpInfo -> ece , & remQpInfo -> ece_supported ));
609
820
610
- NCCLCHECK (ncclIbRtrQp (qp , remQpInfo -> qpn , remDevInfo ));
821
+ NCCLCHECK (ncclIbRtrQp (qp , gidIndex , remQpInfo -> qpn , remDevInfo ));
611
822
NCCLCHECK (ncclIbRtsQp (qp ));
612
823
}
613
824
@@ -707,7 +918,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
707
918
ibDevN = mergedDev -> devs [i ];
708
919
NCCLCHECK (ncclIbInitCommDevBase (ibDevN , & rCommDev -> base ));
709
920
ibDev = ncclIbDevs + ibDevN ;
710
- NCCLCHECK (wrap_ibv_query_gid (ibDev -> context , ibDev -> portNum , ncclParamIbGidIndex (), & rCommDev -> base .gidInfo .localGid ));
921
+ NCCLCHECK (ncclIbGetGidIndex (ibDev -> context , ibDev -> portNum , ibDev -> portAttr .gid_tbl_len , & rCommDev -> base .gidInfo .localGidIndex ));
922
+ NCCLCHECK (wrap_ibv_query_gid (ibDev -> context , ibDev -> portNum , rCommDev -> base .gidInfo .localGidIndex , & rCommDev -> base .gidInfo .localGid ));
711
923
}
712
924
713
925
// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
@@ -745,7 +957,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
745
957
if (meta .qpInfo [q ].ece_supported )
746
958
NCCLCHECK (wrap_ibv_query_ece (qp -> qp , & meta .qpInfo [q ].ece , & meta .qpInfo [q ].ece_supported ));
747
959
}
748
- NCCLCHECK (ncclIbRtrQp (qp -> qp , remMeta .qpInfo [q ].qpn , remDevInfo ));
960
+ NCCLCHECK (ncclIbRtrQp (qp -> qp , rCommDev -> base . gidInfo . localGidIndex , remMeta .qpInfo [q ].qpn , remDevInfo ));
749
961
NCCLCHECK (ncclIbRtsQp (qp -> qp ));
750
962
}
751
963
@@ -783,8 +995,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
783
995
#endif
784
996
);
785
997
devInfo .mtu = ibDev -> portAttr .active_mtu ;
786
- NCCLCHECK (ncclIbRtrQp (rCommDev -> gpuFlush .qp .qp , rCommDev -> gpuFlush .qp .qp -> qp_num , & devInfo ));
787
- NCCLCHECK (ncclIbRtsQp (rCommDev -> gpuFlush .qp .qp ));
998
+ NCCLCHECK (ncclIbRtrQp (rCommDev -> gpuFlush .qp .qp , rCommDev -> base . gidInfo . localGidIndex , rCommDev -> gpuFlush .qp .qp -> qp_num , & devInfo ));
999
+ NCCLCHECK (ncclIbRtsQp (rCommDev -> gpuFlush .qp .qp ));
788
1000
}
789
1001
790
1002
// Fill Handle
@@ -1431,7 +1643,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
1431
1643
return ncclInternalError ;
1432
1644
}
1433
1645
if (req -> nreqs == 1 ) {
1434
- req -> recv .sizes [0 ] + = wc -> imm_data ;
1646
+ req -> recv .sizes [0 ] = wc -> imm_data ;
1435
1647
}
1436
1648
}
1437
1649
req -> events [i ]-- ;
0 commit comments