Skip to content

Commit

Permalink
support User Buffes
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy committed Jan 9, 2024
1 parent 17e4f1e commit b4f48dd
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
4 changes: 2 additions & 2 deletions include/p2p_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ typedef enum nccl_p2p_plugin {

struct ncclIbMr {
uintptr_t addr;
int pages;
size_t pages;
int refs;
struct ibv_mr *mr;
};
Expand All @@ -59,7 +59,7 @@ struct ncclIbRequest {
int offset;
} send;
struct {
int sizes[NCCL_NET_IB_MAX_RECVS];
int* sizes;
} recv;
};
};
Expand Down
80 changes: 52 additions & 28 deletions src/ib_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,20 @@ struct ncclIbSendFifo {
uint64_t idx;
};

struct ncclIbRemSizesFifo {
int elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
uint64_t fifoTail;
uint64_t addr;
uint32_t rkey;
uint32_t flags;
struct ibv_mr* mr;
struct ibv_sge sge;
};

struct ncclIbSendComm {
struct ncclIbVerbs verbs;
struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
struct ncclIbRemSizesFifo remSizesFifo;
uint64_t fifoHead;
struct ncclIbRequest* fifoReqs[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS+1];
Expand Down Expand Up @@ -225,12 +236,14 @@ struct ncclIbRemFifo {
struct ncclIbRecvComm {
struct ncclIbVerbs verbs;
struct ncclIbRemFifo remFifo;
int sizesFifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
struct ncclSocket sock;
int ready;
struct ibv_qp* qps[NCCL_IB_MAX_QPS];
int nqps;
int qpIndex;
struct ncclIbGpuFlush gpuFlush;
struct ibv_mr* sizesFifoMr;
struct ncclIbGidInfo gidInfo;
};

Expand Down Expand Up @@ -463,6 +476,12 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
NCCLCHECK(ncclIbRtsQp(qp));
}

// Retain remote sizes fifo info and prepare RDMA ops
comm->remSizesFifo.rkey = remQpInfo.fifoRkey;
comm->remSizesFifo.addr = remQpInfo.fifoAddr;
NCCLCHECK(wrap_ibv_reg_mr(&comm->remSizesFifo.mr, comm->verbs.pd, &comm->remSizesFifo.elems, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ));
comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mr->lkey;

comm->ready = 1;
stage->state = ncclIbCommStateConnected;
stage->offset = 0;
Expand Down Expand Up @@ -599,6 +618,11 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
);
qpInfo.mtu=remQpInfo.mtu;

// Prepare sizes fifo
NCCLCHECK(wrap_ibv_reg_mr(&rComm->sizesFifoMr, rComm->verbs.pd, rComm->sizesFifo, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ));
qpInfo.fifoRkey = rComm->sizesFifoMr->rkey;
qpInfo.fifoAddr = (uint64_t)rComm->sizesFifo;

stage->state = ncclIbCommStateSend;
stage->offset = 0;
if (stage->buffer) free(stage->buffer);
Expand Down Expand Up @@ -669,7 +693,7 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui
ncclResult_t res;
pthread_mutex_lock(&ncclIbDevs[verbs->dev].lock);
for (int slot=0; /*true*/; slot++) {
if (slot == cache->population) { // didn't find in cache
if (slot == cache->population || addr < cache->slots[slot].addr) { // didn't find in cache
if (cache->population == cache->capacity) { // must grow cache
cache->capacity = cache->capacity < 32 ? 32 : 2*cache->capacity;
NCCLCHECKGOTO(ncclRealloc((void **)&cache->slots, sizeof(struct ncclIbMr)*cache->population, sizeof(struct ncclIbMr)*cache->capacity), res, returning);
Expand All @@ -691,16 +715,17 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui
}
}
TRACE(NCCL_INIT,"regAddr %llx size %lld rkey %x fd %d", (unsigned long long)addr, (long long)pages*pageSize, mr->rkey, fd);
cache->population += 1;
if (slot != cache->population) memmove(cache->slots+slot+1, cache->slots+slot, (cache->population-slot)*sizeof(struct ncclIbMr));
cache->slots[slot].addr = addr;
cache->slots[slot].pages = pages;
cache->slots[slot].refs = 1;
cache->slots[slot].mr = mr;
cache->population += 1;
*mhandle = (void*)mr;
res = ncclSuccess;
goto returning;
}
else if (cache->slots[slot].addr == addr && cache->slots[slot].pages == pages) {
} else if ((addr >= cache->slots[slot].addr) &&
((addr-cache->slots[slot].addr)/pageSize+pages) <= cache->slots[slot].pages) {
cache->slots[slot].refs += 1;
*mhandle = (void*)cache->slots[slot].mr;
res = ncclSuccess;
Expand Down Expand Up @@ -778,13 +803,10 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
if (nreqs == 1) {
immData = reqs[0]->send.size;
} else {
if (nreqs > 32) {
WARN("Cannot store sizes of %d requests in a 32-bits field", nreqs);
return ncclInternalError;
}
for (int r=0; r<nreqs; r++) {
immData |= (reqs[r]->send.size ? 1 : 0) << r;
}
int* sizes = comm->remSizesFifo.elems[slot];
for (int r=0; r<nreqs; r++) sizes[r] = reqs[r]->send.size;
comm->remSizesFifo.sge.addr = (uint64_t)sizes;
comm->remSizesFifo.sge.length = nreqs*sizeof(int);
}

struct ibv_send_wr* lastWr = comm->wrs+nreqs-1;
Expand All @@ -794,6 +816,13 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
// completion.
lastWr++;
memset(lastWr, 0, sizeof(struct ibv_send_wr));
if (nreqs > 1) {
// Write remote sizes Fifo
lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int);
lastWr->wr.rdma.rkey = comm->remSizesFifo.rkey;
lastWr->num_sge = 1;
lastWr->sg_list = &comm->remSizesFifo.sge;
}
}
lastWr->wr_id = wr_id;
lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
Expand Down Expand Up @@ -855,16 +884,9 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh
for (int r=0; r<nreqs; r++) {
if (reqs[r] != NULL || slots[r].tag != tag) continue;

// Sanity checks to catch user collective call count/size mismatches
if (size > slots[r].size) {
char line[SOCKET_NAME_MAXLEN + 1];
union ncclSocketAddress addr;
ncclSocketGetAddr(&comm->sock, &addr);
WARN("NET/IB : req %d/%d tag %x peer %s collective mismatch error, local size %d remote size %d",
r, nreqs, tag, ncclSocketToString(&comm->sock.addr, line, 1), size, slots[r].size);
return ncclInvalidUsage;
} // plus any potential programming errors
else if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkey == 0) {
if (size > slots[r].size) size = slots[r].size;
// Sanity checks
if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkey == 0) {
char line[SOCKET_NAME_MAXLEN + 1];
union ncclSocketAddress addr;
ncclSocketGetAddr(&comm->sock, &addr);
Expand Down Expand Up @@ -912,6 +934,8 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int
memset(&wr, 0, sizeof(wr));

int slot = comm->remFifo.fifoTail%MAX_REQUESTS;
req->recv.sizes = comm->sizesFifo[slot];
for (int i=0; i<n; i++) req->recv.sizes[i] = 0;
struct ncclIbSendFifo* localElem = comm->remFifo.elems[slot];

for (int i=0; i<n; i++) {
Expand Down Expand Up @@ -979,7 +1003,6 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta
req->sock = &comm->sock;
req->nreqs = n;
if (comm->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) req->gidInfo = &comm->gidInfo;
for (int i=0; i<n; i++) req->recv.sizes[i] = 0;

struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
Expand Down Expand Up @@ -1051,6 +1074,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
if (sizes && r->type == NCCL_NET_IB_REQ_RECV) {
for (int i=0; i<r->nreqs; i++) sizes[i] = r->recv.sizes[i];
}
if (sizes && r->type == NCCL_NET_IB_REQ_SEND) {
sizes[0] = r->send.size;
}

NCCLCHECK(ncclIbFreeRequest(r));
return ncclSuccess;
}
Expand Down Expand Up @@ -1091,12 +1118,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
} else {
if (req && wc->opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
if (req->type != NCCL_NET_IB_REQ_RECV) return ncclInternalError;
if (req->nreqs > 1) {
// In the case of a multi recv, we only set sizes to 0 or 1.
for (int i=0; i<req->nreqs; i++) {
req->recv.sizes[i] = (wc->imm_data >> i) & 0x1;
}
} else {
if (req->nreqs == 1) {
req->recv.sizes[0] += wc->imm_data;
}
}
Expand All @@ -1113,6 +1135,7 @@ ncclResult_t ncclIbCloseSend(void* sendComm) {
for (int q=0; q<comm->nqps; q++)
if (comm->qps[q] != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->qps[q]));
if (comm->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->fifoMr));
if (comm->remSizesFifo.mr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remSizesFifo.mr));
NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs));
free(comm);
}
Expand All @@ -1131,6 +1154,7 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) {
if (comm->gpuFlush.hostMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->gpuFlush.hostMr));
}
if (comm->remFifo.mr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remFifo.mr));
if (comm->sizesFifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->sizesFifoMr));
NCCLCHECK(ncclIbDestroyVerbs(&comm->verbs));
free(comm);
}
Expand Down

0 comments on commit b4f48dd

Please sign in to comment.