Skip to content

Commit

Permalink
sendrecv: add MR cache to SENDRECV protocol
Browse files Browse the repository at this point in the history
This commit is making the SENDRECV protocol use the MR cache for memory
registrations.

Signed-off-by: Amedeo Sapio <[email protected]>
  • Loading branch information
AmedeoSapio committed Jul 9, 2024
1 parent c867732 commit b45a27f
Showing 1 changed file with 95 additions and 37 deletions.
132 changes: 95 additions & 37 deletions src/nccl_ofi_sendrecv.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "nccl_ofi_tracepoint.h"
#include "nccl_ofi_math.h"
#include "nccl_ofi_pthread.h"
#include "nccl_ofi_mr.h"


static int selected_api_version = 0;
Expand Down Expand Up @@ -686,6 +687,56 @@ static int reg_mr_base(struct fid_domain *domain, struct fid_ep *ep,
(struct fid_mr **)mhandle);
}

static int dereg_mr_base_comm(struct fid_mr *mr_handle,
nccl_ofi_idpool_t *key_pool,
nccl_ofi_mr_cache_t *mr_cache)
{
int ret = 0;

if (OFI_LIKELY(mr_handle == NULL)) {
NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Null MR handle provided. Skipping deregisteration.");
goto exit;
}

if (mr_cache) {
/*
* Depending on the number of references on this handle and the
* cache itself, this call would either just decrement the
* refcnt, or delete the entry for this handle.
*/
nccl_net_ofi_mutex_lock(&mr_cache->lock);
ret = nccl_ofi_mr_cache_del_entry(mr_cache, (void *)mr_handle);
nccl_net_ofi_mutex_unlock(&mr_cache->lock);
if (OFI_UNLIKELY(ret < 0)) {
NCCL_OFI_WARN("Failed to delete MR cache entry");
} else if (ret == 0) {
/* Entry must not be deregistered */
return ret;
}
}

if (key_pool->ids) {
uint64_t key = fi_mr_key(mr_handle);
if (OFI_UNLIKELY(key == FI_KEY_NOTAVAIL)) {
NCCL_OFI_WARN("Error retrieving MR key, leaking key");
} else {
ret = nccl_ofi_idpool_free_id(key_pool, key);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Error freeing MR key %"PRIu64", leaking key", key);
}
}
}

ret = fi_close((fid_t)mr_handle);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Unable to de-register memory. RC: %d, Error: %s",
ret, fi_strerror(-ret));
}

exit:
return ret;
}

static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
size_t size, int type, void **mhandle)
{
Expand All @@ -706,11 +757,51 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
}
int dev_id = device->base.dev_id;

int ret = 0;
nccl_ofi_mr_cache_t *mr_cache = device->base.mr_cache;
void *ret_handle = NULL;

/*
* MR cache is locked between lookup and insert, to be sure we
* insert a missing entry
*/
nccl_net_ofi_mutex_lock(&mr_cache->lock);
ret_handle = nccl_ofi_mr_cache_lookup_entry(mr_cache, data, size);
if (ret_handle) {
/* Cache hit */
goto exit;
}
/* Cache miss */

nccl_ofi_idpool_t *key_pool = &device->key_pool;
struct fid_domain *domain;
domain = get_domain_from_endpoint(ep);
return reg_mr_base(domain, ep->ofi_ep, key_pool,
dev_id, data, size, type, mhandle);
ret = reg_mr_base(domain, ep->ofi_ep, key_pool,
dev_id, data, size, type, &ret_handle);
if (OFI_UNLIKELY(ret != 0)) {
ret_handle = NULL;
goto exit;
}

ret = nccl_ofi_mr_cache_insert_entry(mr_cache,
data,
size,
ret_handle);
if (OFI_UNLIKELY(ret != 0)) {
/* MR cache insert failed. Deregister memory region without
* trying to delete MR cache entry.
*/
if (dereg_mr_base_comm((struct fid_mr *)ret_handle, key_pool, NULL) != 0) {
NCCL_OFI_WARN("Error deregistering memory region for addr %p", data);
}
ret_handle = NULL;
goto exit;
}

exit:
nccl_net_ofi_mutex_unlock(&mr_cache->lock);
*mhandle = ret_handle;
return ret;
}

static int reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, void *data,
Expand All @@ -725,39 +816,6 @@ static int reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, void *data,
return reg_mr_base_comm(&recv_comm->base, data, size, type, mhandle);
}

static int dereg_mr_base_comm(struct fid_mr *mr_handle,
nccl_ofi_idpool_t *key_pool,
int dev_id)
{
int ret = 0;

if (OFI_LIKELY(mr_handle == NULL)) {
NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Null MR handle provided. Skipping deregisteration.");
goto exit;
}

if (key_pool->ids) {
uint64_t key = fi_mr_key(mr_handle);
if (OFI_UNLIKELY(key == FI_KEY_NOTAVAIL)) {
NCCL_OFI_WARN("Error retrieving MR key, leaking key");
} else {
ret = nccl_ofi_idpool_free_id(key_pool, key);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Error freeing MR key %"PRIu64", leaking key", key);
}
}
}

ret = fi_close((fid_t)mr_handle);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Unable to de-register memory. RC: %d, Error: %s",
ret, fi_strerror(-ret));
}

exit:
return ret;
}

static int dereg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm,
nccl_net_ofi_mr_handle_t *mhandle)
{
Expand All @@ -777,7 +835,7 @@ static int dereg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm,
return -EINVAL;
}
struct fid_mr *mr_handle = (struct fid_mr *)mhandle;
return dereg_mr_base_comm(mr_handle, &device->key_pool, recv_comm->base.dev_id);
return dereg_mr_base_comm(mr_handle, &device->key_pool, device->base.mr_cache);
}

/*
Expand Down Expand Up @@ -1581,7 +1639,7 @@ static int dereg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm,

struct fid_mr *mr_handle = (struct fid_mr *)mhandle;
return dereg_mr_base_comm(mr_handle, &device->key_pool,
send_comm->base.dev_id);
device->base.mr_cache);
}

static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
Expand Down

0 comments on commit b45a27f

Please sign in to comment.