Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update some files to enhance robustness. #1164

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
29 changes: 22 additions & 7 deletions src/transport/net_ib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
ncclNMergedIbDevs = 0;
if (ncclFindInterfaces(ncclIbIfName, &ncclIbIfAddr, MAX_IF_NAME_SIZE, 1) != 1) {
WARN("NET/IB : No IP interface found.");
pthread_mutex_unlock(&ncclIbLock);
return ncclInternalError;
}

Expand All @@ -211,7 +212,10 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
if (searchExact) userIbEnv++;
int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS);

if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) return ncclInternalError;
if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) {
pthread_mutex_unlock(&ncclIbLock);
return ncclInternalError;
}

for (int d=0; d<nIbDevs && ncclNIbDevs<MAX_IB_DEVS; d++) {
struct ibv_context * context;
Expand All @@ -224,7 +228,10 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
memset(&devAttr, 0, sizeof(devAttr));
if (ncclSuccess != wrap_ibv_query_device(context, &devAttr)) {
WARN("NET/IB : Unable to query device %s", devices[d]->name);
if (ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; }
if (ncclSuccess != wrap_ibv_close_device(context)) {
pthread_mutex_unlock(&ncclIbLock);
return ncclInternalError;
}
continue;
}
for (int port_num = 1; port_num <= devAttr.phys_port_cnt; port_num++) {
Expand Down Expand Up @@ -295,9 +302,15 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) {
ncclNIbDevs++;
nPorts++;
}
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { return ncclInternalError; }
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) {
pthread_mutex_unlock(&ncclIbLock);
return ncclInternalError;
}
}
if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { return ncclInternalError; };
if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) {
pthread_mutex_unlock(&ncclIbLock);
return ncclInternalError;
};
}
if (ncclNIbDevs == 0) {
INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found.");
Expand Down Expand Up @@ -1677,8 +1690,11 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
if (r->events[i]) {
NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, 4, wcs, &wrDone));
totalWrDone += wrDone;
if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); }
if (wrDone == 0) continue;
if (wrDone == 0) {
TIME_CANCEL(3);
return ncclSuccess;
}
TIME_STOP(3);
for (int w=0; w<wrDone; w++) {
struct ibv_wc *wc = wcs+w;
if (wc->status != IBV_WC_SUCCESS) {
Expand Down Expand Up @@ -1811,4 +1827,3 @@ ncclNet_t ncclNetIb = {
NULL /* getDeviceMr */,
NULL /* irecvConsumed */
};

4 changes: 4 additions & 0 deletions src/transport/shm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ static ncclResult_t shmSendFree(struct ncclConnector* send) {
NCCLCHECK(ncclShmClose(resources->hostHandle));
NCCLCHECK(ncclShmClose(resources->remHandle));
free(resources);
send->transportResources = NULL;
}
return ncclSuccess;
}
Expand All @@ -220,6 +221,7 @@ static ncclResult_t shmRecvFree(struct ncclConnector* recv) {
NCCLCHECK(ncclShmClose(resources->hostHandle));
NCCLCHECK(ncclShmClose(resources->remHandle));
free(resources);
recv->transportResources = NULL;
}
return ncclSuccess;
}
Expand Down Expand Up @@ -271,6 +273,7 @@ static ncclResult_t shmSendProxyFree(struct ncclProxyConnection* connection, str
CUDACHECK(cudaEventDestroy(resources->events[i]));
}
free(connection->transportResources);
connection->transportResources = NULL;
}
return ncclSuccess;
}
Expand All @@ -286,6 +289,7 @@ static ncclResult_t shmRecvProxyFree(struct ncclProxyConnection* connection, str
CUDACHECK(cudaEventDestroy(resources->events[i]));
}
free(connection->transportResources);
connection->transportResources = NULL;
}
return ncclSuccess;
}
Expand Down