Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy committed Jun 12, 2024
1 parent 424090b commit 72a90a0
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 53 deletions.
150 changes: 104 additions & 46 deletions src/ib_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER;
int ncclIbRelaxedOrderingEnabled = 0;

NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1);
NCCL_PARAM(IbRoutableFlidIbGidIndex, "IB_ROUTABLE_FLID_GID_INDEX", 1);
NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2);
NCCL_PARAM(IbIsGlobal, "IB_IS_GLOBAL", 0);
NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18);
Expand All @@ -46,6 +47,7 @@ NCCL_PARAM(IbSl, "IB_SL", 0);
NCCL_PARAM(IbTc, "IB_TC", 0);
NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192);
NCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2);
NCCL_PARAM(IbFifoTc, "IB_FIFO_TC", 0);

static pthread_t ncclIbAsyncThread;

Expand Down Expand Up @@ -249,7 +251,38 @@ static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t port
return ncclSuccess;
}

static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, int gidTblLen, int *gidIndex) {
// GID Format
// global: | 64b - subnet-prefix | 64b - EUI |
// raw : | 10b fixed | 22b 0 | 16b FLID | 16b subnet-prefix | 64b - EUI |
static uint16_t ncclIbExtractLocalSubnetPrefix(uint64_t subnet_prefix)
{
return (be64toh(subnet_prefix) & 0xffff);
}

static int ncclIbExtractFlid (union ibv_gid *gid)
{
return ntohs(*((uint16_t*)((uintptr_t)(gid->raw) + 4)));
}

static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, struct ibv_port_attr* portAttr, int *gidIndex) {
int gidTblLen = portAttr->gid_tbl_len;

//for IB, choose GID Index that will have routable FLID if present
if (portAttr->link_layer == IBV_LINK_LAYER_INFINIBAND) {
union ibv_gid gid;
int routableGidIndex = ncclParamIbRoutableFlidIbGidIndex();
if (routableGidIndex < gidTblLen) {
NCCLCHECK(wrap_ibv_query_gid(context, portNum, routableGidIndex, &gid));
if (ncclIbExtractFlid(&gid) != 0) {
*gidIndex = routableGidIndex;
return ncclSuccess;
}
}
*gidIndex = 0;
return ncclSuccess;
}

//for ROCE
*gidIndex = ncclParamIbGidIndex();
if (*gidIndex >= 0) {
return ncclSuccess;
Expand Down Expand Up @@ -342,12 +375,13 @@ typedef struct ncclIbDevInfo {
uint8_t link_layer;
uint8_t is_global;

// For RoCE and IB GRH
uint64_t spn;
uint64_t iid;
//For RoCE and IB GRH & Rounter
union ibv_gid gid;

// FIFO RDMA info
uint32_t fifoRkey;

//remote dev info
union ibv_gid remoteGid;
} ncclIbDevInfo;

Expand Down Expand Up @@ -579,29 +613,53 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
return ncclSuccess;
}

ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool override_tc) {
struct ibv_qp_attr qpAttr;
int same_subnet;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_RTR;
qpAttr.path_mtu = info->mtu;
qpAttr.dest_qp_num = dest_qp_num;
qpAttr.rq_psn = 0;
qpAttr.max_dest_rd_atomic = 1;
qpAttr.min_rnr_timer = 12;
qpAttr.ah_attr.is_global = 0;
qpAttr.ah_attr.dlid = info->lid;
qpAttr.ah_attr.sl = ncclParamIbSl();
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info->ib_port;
if (info->link_layer == IBV_LINK_LAYER_ETHERNET || info->is_global) {
if (info->link_layer == IBV_LINK_LAYER_ETHERNET) {
qpAttr.ah_attr.is_global = 1;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id;
qpAttr.ah_attr.grh.flow_label = 0;
qpAttr.ah_attr.grh.sgid_index = sGidIndex;
qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex;
qpAttr.ah_attr.grh.hop_limit = 255;
qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc();
if(ncclParamIbFifoTc() && override_tc) {
qpAttr.ah_attr.grh.traffic_class = ncclParamIbFifoTc();
} else {
qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc();
}
} else {
same_subnet = (ncclIbExtractLocalSubnetPrefix(sGidInfo->localGid.global.subnet_prefix) ==
ncclIbExtractLocalSubnetPrefix(info->gid.global.subnet_prefix));
qpAttr.ah_attr.is_global = 0;
qpAttr.ah_attr.dlid = info->lid;
if (!same_subnet || info->is_global) {
if (!same_subnet) {
uint16_t flid = ncclIbExtractFlid(&info->gid);
if (flid == 0) {
WARN("Warning: remote FLID configured as zero even when endpoints are on different subnets, using dlid as fallback");
qpAttr.ah_attr.dlid = info->lid;
} else {
qpAttr.ah_attr.dlid = ncclIbExtractFlid(&info->gid);
}
}
qpAttr.ah_attr.is_global = 1;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id;
qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex;
qpAttr.ah_attr.grh.hop_limit = 255;
}
}
qpAttr.ah_attr.sl = ncclParamIbSl();
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info->ib_port;
NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER));
return ncclSuccess;
}
Expand Down Expand Up @@ -711,29 +769,28 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
NCCLCHECK(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ));
devInfo->fifoRkey = commDev->fifoMr->rkey;

// RoCE support
devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer;
devInfo->is_global = (ncclParamIbIsGlobal()
#if HAVE_DECL_IBV_QPF_GRH_REQUIRED
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
#endif
);

// Pack local GID info
devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer;
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &commDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid));
devInfo->gid.global.subnet_prefix = commDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo->gid.global.interface_id = commDev->base.gidInfo.localGid.global.interface_id;

if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) {

NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &commDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid));
devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id;
}

// info logging
if (devInfo->link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB
for (int q = 0; q < comm->base.nqps; q++) {
// Print just the QPs for this dev
if (comm->base.qps[q].devIndex == i)
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d fifoRkey=0x%x fifoLkey=0x%x",
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d subnet-prefix %lu FLID %d fifoRkey=0x%x fifoLkey=0x%x",
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev",
dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, devInfo->fifoRkey, commDev->fifoMr->lkey);
dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid,
devInfo->gid.global.subnet_prefix, ncclIbExtractFlid(&devInfo->gid), devInfo->fifoRkey, commDev->fifoMr->lkey);
}
} else { // RoCE
for (int q = 0; q < comm->base.nqps; q++) {
Expand All @@ -742,7 +799,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x",
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex,
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
devInfo->gid.global.subnet_prefix, devInfo->gid.global.interface_id, devInfo->fifoRkey, commDev->fifoMr->lkey);
}
}
}
Expand Down Expand Up @@ -792,8 +849,8 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
for (int i = 0; i < remMeta.ndevs; i++) {
comm->base.remDevs[i] = remMeta.devs[i];
comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].iid;
comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].spn;
comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].gid.global.interface_id;
comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].gid.global.subnet_prefix;

// Retain remote sizes fifo info and prepare RDMA ops
comm->remSizesFifo.rkeys[i] = remMeta.devs[i].fifoRkey;
Expand All @@ -812,13 +869,12 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
comm->base.qps[q].remDevIdx = remQpInfo->devIndex;
int devIndex = comm->base.qps[q].devIndex;
ncclIbSendCommDev* commDev = comm->devs + devIndex;
uint8_t gidIndex = commDev->base.gidInfo.localGidIndex;

struct ibv_qp* qp = comm->base.qps[q].qp;
if (remQpInfo->ece_supported && remQpInfo->ece_supported)
NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported));

NCCLCHECK(ncclIbRtrQp(qp, gidIndex, remQpInfo->qpn, remDevInfo));
NCCLCHECK(ncclIbRtrQp(qp, &commDev->base.gidInfo, remQpInfo->qpn, remDevInfo, false));
NCCLCHECK(ncclIbRtsQp(qp));
}

Expand Down Expand Up @@ -918,15 +974,15 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
ibDevN = mergedDev->devs[i];
NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base));
ibDev = ncclIbDevs + ibDevN;
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &rCommDev->base.gidInfo.localGidIndex));
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &rCommDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid));
}

// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
for (int i = 0; i < remMeta.ndevs; i++) {
rComm->base.remDevs[i] = remMeta.devs[i];
rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].iid;
rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].spn;
rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].gid.global.interface_id;
rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].gid.global.subnet_prefix;
}

// Stripe QP creation across merged devs
Expand Down Expand Up @@ -957,7 +1013,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
if (meta.qpInfo[q].ece_supported)
NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported));
}
NCCLCHECK(ncclIbRtrQp(qp->qp, rCommDev->base.gidInfo.localGidIndex, remMeta.qpInfo[q].qpn, remDevInfo));
bool override_tc = (q == 0) ? true : false;
NCCLCHECK(ncclIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, override_tc));
NCCLCHECK(ncclIbRtsQp(qp->qp));
}

Expand Down Expand Up @@ -987,24 +1044,24 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
devInfo.lid = ibDev->portAttr.lid;
devInfo.link_layer = ibDev->portAttr.link_layer;
devInfo.ib_port = ibDev->portNum;
devInfo.spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo.iid = rCommDev->base.gidInfo.localGid.global.interface_id;
devInfo.gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo.gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id;
devInfo.is_global = (ncclParamIbIsGlobal()
#if HAVE_DECL_IBV_QPF_GRH_REQUIRED
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
#endif
);
devInfo.mtu = ibDev->portAttr.active_mtu;
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->base.gidInfo.localGidIndex, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, &rCommDev->base.gidInfo, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo, false));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
}

// Fill Handle
meta.devs[i].lid = ibDev->portAttr.lid;
meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer;
meta.devs[i].ib_port = ibDev->portNum;
meta.devs[i].spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
meta.devs[i].iid = rCommDev->base.gidInfo.localGid.global.interface_id;
meta.devs[i].gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
meta.devs[i].gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id;
meta.devs[i].is_global = (ncclParamIbIsGlobal()
#if HAVE_DECL_IBV_QPF_GRH_REQUIRED
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
Expand Down Expand Up @@ -1612,9 +1669,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
}

char line[SOCKET_NAME_MAXLEN+1];
WARN("NET/IB : Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s",
char *hcaName = r->devBases[i]->pd->context->device->name;
WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s hca %s",
ncclSocketToString(&addr, line, 1), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type],
localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString);
localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString, hcaName);
return ncclRemoteError;
}

Expand All @@ -1624,7 +1682,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {

#ifdef ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN+1];
TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%d r=%p type=%d events={%d,%d}, i=%d",
TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%ld r=%p type=%d events={%d,%d}, i=%d",
ncclSocketToString(&addr, line, 1), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], i);
#endif
if (req->type == NCCL_NET_IB_REQ_SEND) {
Expand Down
22 changes: 20 additions & 2 deletions src/p2p_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,10 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS);

if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; }

// Should NCCL merge multi-port devices into one?
int mergeNics;
mergeNics = ncclParamIbMergeNics();
build_ib_list:
for (int d=0; d<nIbDevs; d++) {
struct ibv_context * context;
if (ncclSuccess != wrap_ibv_open_device(&context, devices[d]) || context == NULL) {
Expand Down Expand Up @@ -398,7 +401,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
}

int mergedDev = ncclNMergedIbDevs;
if (ncclParamIbMergeNics()) {
if (mergeNics) {
mergedDev = ncclIbFindMatchingDev(ncclNIbDevs);
}

Expand All @@ -425,6 +428,21 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
}
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; }
}

// Detect if there are both multi-port and single-port NICs in the system. If so, disable port merging and build the list again
if (mergeNics) {
for (int d = 0; d < ncclNMergedIbDevs; d++) {
if (ncclIbMergedDevs[d].ndevs != ncclIbMergedDevs[0].ndevs) {
INFO(NCCL_NET, "Detected a mix of single and multiple-port NICs. Force-disabling NCCL_IB_MERGE_NICS");
mergeNics = 0;
ncclNIbDevs = 0;
ncclNMergedIbDevs = 0;
memset(ncclIbMergedDevs, 0, sizeof(ncclIbMergedDevs));
goto build_ib_list;
}
}
}

if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { ret = ncclInternalError; goto fail; };
}
if (ncclNIbDevs == 0) {
Expand Down
10 changes: 5 additions & 5 deletions src/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr
}
}
(*offset) += bytes;
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) {
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) {
INFO(NCCL_NET, "socketProgressOpt: abort called");
return ncclInternalError;
}
Expand Down Expand Up @@ -624,12 +624,12 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) {
do {
NCCLCHECK(socketProgressState(sock));
} while (sock->asyncFlag == 0 &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) &&
(sock->state == ncclSocketStateConnecting ||
sock->state == ncclSocketStateConnectPolling ||
sock->state == ncclSocketStateConnected));

if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError;
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError;

switch (sock->state) {
case ncclSocketStateConnecting:
Expand Down Expand Up @@ -671,11 +671,11 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen
do {
NCCLCHECKGOTO(socketProgressState(sock), ret, exit);
} while (sock->asyncFlag == 0 &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) &&
(sock->state == ncclSocketStateAccepting ||
sock->state == ncclSocketStateAccepted));

if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError;
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError;

switch (sock->state) {
case ncclSocketStateAccepting:
Expand Down

0 comments on commit 72a90a0

Please sign in to comment.