From f3afc3176318d404b2066d8fc5073a191181a3ae Mon Sep 17 00:00:00 2001
From: Amedeo Sapio
Date: Thu, 4 Jul 2024 01:27:50 +0000
Subject: [PATCH] sendrecv: add MR cache to SENDRECV protocol

This commit is making the SENDRECV protocol use the MR cache for memory
registrations.

The device MR cache is treated as optional on both the registration and
deregistration paths, and a NULL handle is never inserted in the cache.

Signed-off-by: Amedeo Sapio
---
 src/nccl_ofi_sendrecv.c | 153 ++++++++++++++++++++++++++++++----------
 1 file changed, 116 insertions(+), 37 deletions(-)

diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c
index 930152b80..1284da2a0 100644
--- a/src/nccl_ofi_sendrecv.c
+++ b/src/nccl_ofi_sendrecv.c
@@ -23,6 +23,7 @@
 #include "nccl_ofi_tracepoint.h"
 #include "nccl_ofi_math.h"
 #include "nccl_ofi_pthread.h"
+#include "nccl_ofi_mr.h"
 
 static int selected_api_version = 0;
 
@@ -686,6 +687,67 @@ static int reg_mr_base(struct fid_domain *domain, struct fid_ep *ep,
 				 (struct fid_mr **)mhandle);
 }
 
+/*
+ * Deregister a memory region, consulting the MR cache when one is given.
+ *
+ * mr_handle: libfabric MR handle to release; NULL is accepted and skipped
+ * key_pool:  ID pool the MR key is returned to when key_pool->ids is set
+ * mr_cache:  cache that may hold mr_handle, or NULL to bypass the cache
+ *
+ * Returns 0 when nothing had to be closed (NULL handle, or the cache
+ * reported that the entry must stay registered); otherwise the result of
+ * fi_close() on the handle.
+ */
+static int dereg_mr_base_comm(struct fid_mr *mr_handle,
+			      nccl_ofi_idpool_t *key_pool,
+			      nccl_ofi_mr_cache_t *mr_cache)
+{
+	int ret = 0;
+
+	if (OFI_LIKELY(mr_handle == NULL)) {
+		NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Null MR handle provided. Skipping deregisteration.");
+		goto exit;
+	}
+
+	if (mr_cache) {
+		/*
+		 * Depending on the number of references on this handle and the
+		 * cache itself, this call would either just decrement the
+		 * refcnt, or delete the entry for this handle.
+		 */
+		nccl_net_ofi_mutex_lock(&mr_cache->lock);
+		ret = nccl_ofi_mr_cache_del_entry(mr_cache, (void *)mr_handle);
+		nccl_net_ofi_mutex_unlock(&mr_cache->lock);
+		if (OFI_UNLIKELY(ret < 0)) {
+			NCCL_OFI_WARN("Failed to delete MR cache entry");
+		} else if (ret == 0) {
+			/* Entry must not be deregistered */
+			return ret;
+		}
+	}
+
+	if (key_pool->ids) {
+		uint64_t key = fi_mr_key(mr_handle);
+		if (OFI_UNLIKELY(key == FI_KEY_NOTAVAIL)) {
+			NCCL_OFI_WARN("Error retrieving MR key, leaking key");
+		} else {
+			ret = nccl_ofi_idpool_free_id(key_pool, key);
+			if (OFI_UNLIKELY(ret != 0)) {
+				NCCL_OFI_WARN("Error freeing MR key %"PRIu64", leaking key", key);
+			}
+		}
+	}
+
+	ret = fi_close((fid_t)mr_handle);
+	if (OFI_UNLIKELY(ret != 0)) {
+		NCCL_OFI_WARN("Unable to de-register memory. RC: %d, Error: %s",
+			      ret, fi_strerror(-ret));
+	}
+
+ exit:
+	return ret;
+}
+
 static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
 					 size_t size, int type, void **mhandle)
 {
@@ -706,11 +768,61 @@ static int reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
 	}
 
 	int dev_id = device->base.dev_id;
+	int ret = 0;
+	nccl_ofi_mr_cache_t *mr_cache = device->base.mr_cache;
+	void *ret_handle = NULL;
+
+	if (mr_cache) {
+		/*
+		 * MR cache is locked between lookup and insert, to be sure we
+		 * insert a missing entry
+		 */
+		nccl_net_ofi_mutex_lock(&mr_cache->lock);
+		ret_handle = nccl_ofi_mr_cache_lookup_entry(mr_cache, data, size);
+		if (ret_handle) {
+			/* Cache hit */
+			goto exit;
+		}
+		/* Cache miss */
+	}
+
 	nccl_ofi_idpool_t *key_pool = &device->key_pool;
 	struct fid_domain *domain;
 	domain = get_domain_from_endpoint(ep);
-	return reg_mr_base(domain, ep->ofi_ep, key_pool,
-			   dev_id, data, size, type, mhandle);
+	ret = reg_mr_base(domain, ep->ofi_ep, key_pool,
+			  dev_id, data, size, type, &ret_handle);
+	if (OFI_UNLIKELY(ret != 0)) {
+		ret_handle = NULL;
+		goto exit;
+	}
+
+	/*
+	 * Only cache real registrations: a NULL handle could not be told
+	 * apart from a cache miss on the next lookup.
+	 */
+	if (mr_cache && ret_handle) {
+		ret = nccl_ofi_mr_cache_insert_entry(mr_cache,
+						     data,
+						     size,
+						     ret_handle);
+		if (OFI_UNLIKELY(ret != 0)) {
+			/* MR cache insert failed. Deregister memory region without
+			 * trying to delete MR cache entry.
+			 */
+			if (dereg_mr_base_comm((struct fid_mr *)ret_handle, key_pool, NULL) != 0) {
+				NCCL_OFI_WARN("Error deregistering memory region for addr %p", data);
+			}
+			ret_handle = NULL;
+			goto exit;
+		}
+	}
+
+exit:
+	if (mr_cache) {
+		nccl_net_ofi_mutex_unlock(&mr_cache->lock);
+	}
+	*mhandle = ret_handle;
+	return ret;
 }
 
 static int reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, void *data,
@@ -725,39 +837,6 @@ static int reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, void *data,
 	return reg_mr_base_comm(&recv_comm->base, data, size, type, mhandle);
 }
 
-static int dereg_mr_base_comm(struct fid_mr *mr_handle,
-			      nccl_ofi_idpool_t *key_pool,
-			      int dev_id)
-{
-	int ret = 0;
-
-	if (OFI_LIKELY(mr_handle == NULL)) {
-		NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Null MR handle provided. Skipping deregisteration.");
-		goto exit;
-	}
-
-	if (key_pool->ids) {
-		uint64_t key = fi_mr_key(mr_handle);
-		if (OFI_UNLIKELY(key == FI_KEY_NOTAVAIL)) {
-			NCCL_OFI_WARN("Error retrieving MR key, leaking key");
-		} else {
-			ret = nccl_ofi_idpool_free_id(key_pool, key);
-			if (OFI_UNLIKELY(ret != 0)) {
-				NCCL_OFI_WARN("Error freeing MR key %"PRIu64", leaking key", key);
-			}
-		}
-	}
-
-	ret = fi_close((fid_t)mr_handle);
-	if (OFI_UNLIKELY(ret != 0)) {
-		NCCL_OFI_WARN("Unable to de-register memory. RC: %d, Error: %s",
-			      ret, fi_strerror(-ret));
-	}
-
- exit:
-	return ret;
-}
-
 static int dereg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm,
 					       nccl_net_ofi_mr_handle_t *mhandle)
 {
@@ -777,7 +856,7 @@ static int dereg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm,
 		return -EINVAL;
 	}
 	struct fid_mr *mr_handle = (struct fid_mr *)mhandle;
-	return dereg_mr_base_comm(mr_handle, &device->key_pool, recv_comm->base.dev_id);
+	return dereg_mr_base_comm(mr_handle, &device->key_pool, device->base.mr_cache);
 }
 
 /*
@@ -1581,7 +1660,7 @@ static int dereg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm,
 
 	struct fid_mr *mr_handle = (struct fid_mr *)mhandle;
 	return dereg_mr_base_comm(mr_handle, &device->key_pool,
-				  send_comm->base.dev_id);
+				  device->base.mr_cache);
 }
 
 static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,