diff --git a/src/ib_plugin.c b/src/ib_plugin.c index 8040f09..9f56eea 100644 --- a/src/ib_plugin.c +++ b/src/ib_plugin.c @@ -36,6 +36,7 @@ pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER; int ncclIbRelaxedOrderingEnabled = 0; NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1); +NCCL_PARAM(IbRoutableFlidIbGidIndex, "IB_ROUTABLE_FLID_GID_INDEX", 1); NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2); NCCL_PARAM(IbIsGlobal, "IB_IS_GLOBAL", 0); NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18); @@ -46,6 +47,7 @@ NCCL_PARAM(IbSl, "IB_SL", 0); NCCL_PARAM(IbTc, "IB_TC", 0); NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192); NCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2); +NCCL_PARAM(IbFifoTc, "IB_FIFO_TC", 0); static pthread_t ncclIbAsyncThread; @@ -249,7 +251,38 @@ static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t port return ncclSuccess; } -static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, int gidTblLen, int *gidIndex) { +// GID Format +// global: | 64b - subnet-prefix | 64b - EUI | +// raw : | 10b fixed | 22b 0 | 16b FLID | 16b subnet-prefix | 64b - EUI | +static uint16_t ncclIbExtractLocalSubnetPrefix(uint64_t subnet_prefix) +{ + return (be64toh(subnet_prefix) & 0xffff); +} + +static int ncclIbExtractFlid (union ibv_gid *gid) +{ + return ntohs(*((uint16_t*)((uintptr_t)(gid->raw) + 4))); +} + +static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, struct ibv_port_attr* portAttr, int *gidIndex) { + int gidTblLen = portAttr->gid_tbl_len; + + //for IB, choose GID Index that will have routable FLID if present + if (portAttr->link_layer == IBV_LINK_LAYER_INFINIBAND) { + union ibv_gid gid; + int routableGidIndex = ncclParamIbRoutableFlidIbGidIndex(); + if (routableGidIndex < gidTblLen) { + NCCLCHECK(wrap_ibv_query_gid(context, portNum, routableGidIndex, &gid)); + if (ncclIbExtractFlid(&gid) != 0) { + *gidIndex = routableGidIndex; + return ncclSuccess; + } + } + *gidIndex = 0; + return ncclSuccess; + } + + //for ROCE *gidIndex = ncclParamIbGidIndex(); if (*gidIndex >= 0) { return ncclSuccess; @@ -342,12 +375,13 @@ typedef struct ncclIbDevInfo { uint8_t link_layer; uint8_t is_global; - // For RoCE and IB GRH - uint64_t spn; - uint64_t iid; + //For RoCE and IB GRH & Rounter + union ibv_gid gid; // FIFO RDMA info uint32_t fifoRkey; + + //remote dev info union ibv_gid remoteGid; } ncclIbDevInfo; @@ -579,8 +613,9 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, return ncclSuccess; } -ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_num, struct ncclIbDevInfo* info) { +ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool override_tc) { struct ibv_qp_attr qpAttr; + int same_subnet; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_RTR; qpAttr.path_mtu = info->mtu; @@ -588,20 +623,43 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_ qpAttr.rq_psn = 0; qpAttr.max_dest_rd_atomic = 1; qpAttr.min_rnr_timer = 12; - qpAttr.ah_attr.is_global = 0; - qpAttr.ah_attr.dlid = info->lid; - qpAttr.ah_attr.sl = ncclParamIbSl(); - qpAttr.ah_attr.src_path_bits = 0; - qpAttr.ah_attr.port_num = info->ib_port; - if (info->link_layer == IBV_LINK_LAYER_ETHERNET || info->is_global) { + if (info->link_layer == IBV_LINK_LAYER_ETHERNET) { qpAttr.ah_attr.is_global = 1; - qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn; - qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid; + qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix; + qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id; qpAttr.ah_attr.grh.flow_label = 0; - qpAttr.ah_attr.grh.sgid_index = sGidIndex; + qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex; qpAttr.ah_attr.grh.hop_limit = 255; - qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc(); + if(ncclParamIbFifoTc() && override_tc) { + qpAttr.ah_attr.grh.traffic_class = ncclParamIbFifoTc(); + } else { + qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc(); + } + } else { + same_subnet = (ncclIbExtractLocalSubnetPrefix(sGidInfo->localGid.global.subnet_prefix) == + ncclIbExtractLocalSubnetPrefix(info->gid.global.subnet_prefix)); + qpAttr.ah_attr.is_global = 0; + qpAttr.ah_attr.dlid = info->lid; + if (!same_subnet || info->is_global) { + if (!same_subnet) { + uint16_t flid = ncclIbExtractFlid(&info->gid); + if (flid == 0) { + WARN("Warning: remote FLID configured as zero even when endpoints are on different subnets, using dlid as fallback"); + qpAttr.ah_attr.dlid = info->lid; + } else { + qpAttr.ah_attr.dlid = ncclIbExtractFlid(&info->gid); + } + } + qpAttr.ah_attr.is_global = 1; + qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix; + qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id; + qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex; + qpAttr.ah_attr.grh.hop_limit = 255; + } } + qpAttr.ah_attr.sl = ncclParamIbSl(); + qpAttr.ah_attr.src_path_bits = 0; + qpAttr.ah_attr.port_num = info->ib_port; NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER)); return ncclSuccess; } @@ -711,29 +769,28 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet NCCLCHECK(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ)); devInfo->fifoRkey = commDev->fifoMr->rkey; - // RoCE support - devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; devInfo->is_global = (ncclParamIbIsGlobal() #if HAVE_DECL_IBV_QPF_GRH_REQUIRED - || (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED) + || (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED) #endif ); + + // Pack local GID info + devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &commDev->base.gidInfo.localGidIndex)); + NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid)); + devInfo->gid.global.subnet_prefix = commDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo->gid.global.interface_id = commDev->base.gidInfo.localGid.global.interface_id; - if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) { - - NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &commDev->base.gidInfo.localGidIndex)); - NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid)); - devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix; - devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id; - } - + // info logging if (devInfo->link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB for (int q = 0; q < comm->base.nqps; q++) { // Print just the QPs for this dev if (comm->base.qps[q].devIndex == i) - INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d fifoRkey=0x%x fifoLkey=0x%x", + INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d subnet-prefix %lu FLID %d fifoRkey=0x%x fifoLkey=0x%x", comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", - dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, devInfo->fifoRkey, commDev->fifoMr->lkey); + dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, + devInfo->gid.global.subnet_prefix, ncclIbExtractFlid(&devInfo->gid), devInfo->fifoRkey, commDev->fifoMr->lkey); } } else { // RoCE for (int q = 0; q < comm->base.nqps; q++) { @@ -742,7 +799,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x", comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex, - devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey); + devInfo->gid.global.subnet_prefix, devInfo->gid.global.interface_id, devInfo->fifoRkey, commDev->fifoMr->lkey); } } } @@ -792,8 +849,8 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. for (int i = 0; i < remMeta.ndevs; i++) { comm->base.remDevs[i] = remMeta.devs[i]; - comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].iid; - comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].spn; + comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].gid.global.interface_id; + comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].gid.global.subnet_prefix; // Retain remote sizes fifo info and prepare RDMA ops comm->remSizesFifo.rkeys[i] = remMeta.devs[i].fifoRkey; @@ -812,13 +869,12 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet comm->base.qps[q].remDevIdx = remQpInfo->devIndex; int devIndex = comm->base.qps[q].devIndex; ncclIbSendCommDev* commDev = comm->devs + devIndex; - uint8_t gidIndex = commDev->base.gidInfo.localGidIndex; struct ibv_qp* qp = comm->base.qps[q].qp; if (remQpInfo->ece_supported && remQpInfo->ece_supported) NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported)); - NCCLCHECK(ncclIbRtrQp(qp, gidIndex, remQpInfo->qpn, remDevInfo)); + NCCLCHECK(ncclIbRtrQp(qp, &commDev->base.gidInfo, remQpInfo->qpn, remDevInfo, false)); NCCLCHECK(ncclIbRtsQp(qp)); } @@ -918,15 +974,15 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl ibDevN = mergedDev->devs[i]; NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base)); ibDev = ncclIbDevs + ibDevN; - NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &rCommDev->base.gidInfo.localGidIndex)); + NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &rCommDev->base.gidInfo.localGidIndex)); NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid)); } // Copy remDevInfo for things like remGidInfo, remFifoAddr, etc. for (int i = 0; i < remMeta.ndevs; i++) { rComm->base.remDevs[i] = remMeta.devs[i]; - rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].iid; - rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].spn; + rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].gid.global.interface_id; + rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].gid.global.subnet_prefix; } // Stripe QP creation across merged devs @@ -957,7 +1013,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl if (meta.qpInfo[q].ece_supported) NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported)); } - NCCLCHECK(ncclIbRtrQp(qp->qp, rCommDev->base.gidInfo.localGidIndex, remMeta.qpInfo[q].qpn, remDevInfo)); + bool override_tc = (q == 0) ? true : false; + NCCLCHECK(ncclIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, override_tc)); NCCLCHECK(ncclIbRtsQp(qp->qp)); } @@ -987,24 +1044,24 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl devInfo.lid = ibDev->portAttr.lid; devInfo.link_layer = ibDev->portAttr.link_layer; devInfo.ib_port = ibDev->portNum; - devInfo.spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix; - devInfo.iid = rCommDev->base.gidInfo.localGid.global.interface_id; + devInfo.gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + devInfo.gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id; devInfo.is_global = (ncclParamIbIsGlobal() #if HAVE_DECL_IBV_QPF_GRH_REQUIRED || (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED) #endif ); devInfo.mtu = ibDev->portAttr.active_mtu; - NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->base.gidInfo.localGidIndex, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo)); - NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp)); + NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, &rCommDev->base.gidInfo, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo, false)); + NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp)); } // Fill Handle meta.devs[i].lid = ibDev->portAttr.lid; meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; meta.devs[i].ib_port = ibDev->portNum; - meta.devs[i].spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix; - meta.devs[i].iid = rCommDev->base.gidInfo.localGid.global.interface_id; + meta.devs[i].gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix; + meta.devs[i].gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id; meta.devs[i].is_global = (ncclParamIbIsGlobal() #if HAVE_DECL_IBV_QPF_GRH_REQUIRED || (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED) @@ -1612,9 +1669,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { } char line[SOCKET_NAME_MAXLEN+1]; - WARN("NET/IB : Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s", + char *hcaName = r->devBases[i]->pd->context->device->name; + WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s hca %s", ncclSocketToString(&addr, line, 1), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type], - localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString); + localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString, hcaName); return ncclRemoteError; } @@ -1624,7 +1682,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; - TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%d r=%p type=%d events={%d,%d}, i=%d", + TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%ld r=%p type=%d events={%d,%d}, i=%d", ncclSocketToString(&addr, line, 1), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], i); #endif if (req->type == NCCL_NET_IB_REQ_SEND) { diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c index 78074db..500150f 100644 --- a/src/p2p_plugin.c +++ b/src/p2p_plugin.c @@ -332,7 +332,10 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS); if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; } - + // Should NCCL merge multi-port devices into one? + int mergeNics; + mergeNics = ncclParamIbMergeNics(); +build_ib_list: for (int d=0; dabortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) { + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) { INFO(NCCL_NET, "socketProgressOpt: abort called"); return ncclInternalError; } @@ -624,12 +624,12 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { do { NCCLCHECK(socketProgressState(sock)); } while (sock->asyncFlag == 0 && - (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) && + (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) && (sock->state == ncclSocketStateConnecting || sock->state == ncclSocketStateConnectPolling || sock->state == ncclSocketStateConnected)); - if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError; switch (sock->state) { case ncclSocketStateConnecting: @@ -671,11 +671,11 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen do { NCCLCHECKGOTO(socketProgressState(sock), ret, exit); } while (sock->asyncFlag == 0 && - (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) && + (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) && (sock->state == ncclSocketStateAccepting || sock->state == ncclSocketStateAccepted)); - if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError; + if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError; switch (sock->state) { case ncclSocketStateAccepting: