Skip to content

Commit

Permalink
prov/efa: Implement FI_CONTEXT2 in EFA Direct
Browse files Browse the repository at this point in the history
Store the completion flags and peer address in FI_CONTEXT2 and
retrieve later when writing cq.

Signed-off-by: Jessie Yang <[email protected]>
  • Loading branch information
jiaxiyan committed Jan 21, 2025
1 parent a9ebef2 commit 46651bd
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 17 deletions.
35 changes: 35 additions & 0 deletions prov/efa/src/efa.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,41 @@ struct efa_fabric {
#endif
};

struct efa_context {
uint64_t completion_flags;
fi_addr_t addr;
};

#if defined(static_assert)
static_assert(sizeof(struct efa_context) <= sizeof(struct fi_context2),
"efa_context must not be larger than fi_context2");
#endif

/**
* Prepare and return a pointer to an EFA context structure.
*
* @param context Pointer to the msg context.
* @param addr Peer address associated with the operation.
* @param flags Operation flags (e.g., FI_COMPLETION).
* @param completion_flags Completion flags reported in the cq entry.
* @return A pointer to an initialized EFA context structure,
* or NULL if context is invalid or FI_COMPLETION is not set.
*/
static inline struct efa_context *efa_fill_context(const void *context,
fi_addr_t addr,
uint64_t flags,
uint64_t completion_flags)
{
if (!context || !(flags & FI_COMPLETION))
return NULL;

struct efa_context *efa_context = (struct efa_context *) context;
efa_context->completion_flags = completion_flags;
efa_context->addr = addr;

return efa_context;
}

static inline
int efa_str_to_ep_addr(const char *node, const char *service, struct efa_ep_addr *addr)
{
Expand Down
8 changes: 5 additions & 3 deletions prov/efa/src/efa_cq.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ static void efa_cq_construct_cq_entry(struct ibv_cq_ex *ibv_cqx,
struct fi_cq_tagged_entry *entry)
{
entry->op_context = (void *)ibv_cqx->wr_id;
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
if (ibv_cqx->wr_id)
entry->flags = ((struct efa_context *) ibv_cqx->wr_id)->completion_flags;
else
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
entry->len = ibv_wc_read_byte_len(ibv_cqx);
entry->buf = NULL;
entry->data = 0;
Expand Down Expand Up @@ -81,8 +84,7 @@ static void efa_cq_handle_error(struct efa_base_ep *base_ep,
err_entry.prov_errno = prov_errno;

if (is_tx)
// TODO: get correct peer addr for TX operation
addr = FI_ADDR_NOTAVAIL;
addr = ibv_cq_ex->wr_id ? ((struct efa_context *)ibv_cq_ex->wr_id)->addr : FI_ADDR_NOTAVAIL;
else
addr = efa_av_reverse_lookup(base_ep->av,
ibv_wc_read_slid(ibv_cq_ex),
Expand Down
6 changes: 4 additions & 2 deletions prov/efa/src/efa_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi
wr = &base_ep->efa_recv_wr_vec[wr_index].wr;
wr->num_sge = msg->iov_count;
wr->sg_list = base_ep->efa_recv_wr_vec[wr_index].sge;
wr->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
wr->wr_id = (uintptr_t) efa_fill_context(msg->context, msg->addr, flags,
FI_RECV | FI_MSG);

for (i = 0; i < msg->iov_count; i++) {
addr = (uintptr_t)msg->msg_iov[i].iov_base;
Expand Down Expand Up @@ -224,7 +225,8 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi
base_ep->is_wr_started = true;
}

qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_SEND | FI_MSG);

if (flags & FI_REMOTE_CQ_DATA) {
ibv_wr_send_imm(qp->ibv_qp_ex, msg->data);
Expand Down
8 changes: 6 additions & 2 deletions prov/efa/src/efa_rma.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep,
ibv_wr_start(qp->ibv_qp_ex);
base_ep->is_wr_started = true;
}
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);

qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_RMA | FI_READ);

/* ep->domain->info->tx_attr->rma_iov_limit is set to 1 */
ibv_wr_rdma_read(qp->ibv_qp_ex, msg->rma_iov[0].key, msg->rma_iov[0].addr);
Expand Down Expand Up @@ -225,7 +227,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep,
ibv_wr_start(qp->ibv_qp_ex);
base_ep->is_wr_started = true;
}
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);

qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context(
msg->context, msg->addr, flags, FI_RMA | FI_WRITE);

if (flags & FI_REMOTE_CQ_DATA) {
ibv_wr_rdma_write_imm(qp->ibv_qp_ex, msg->rma_iov[0].key,
Expand Down
29 changes: 19 additions & 10 deletions prov/efa/test/efa_unit_test_cq.c
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,8 @@ void test_ibv_cq_ex_read_ignore_removed_peer()
#endif

static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
int ibv_wc_opcode, int status, int vendor_error)
int ibv_wc_opcode, int status, int vendor_error,
struct efa_context *ctx)
{
int ret;
size_t raw_addr_len = sizeof(struct efa_ep_addr);
Expand Down Expand Up @@ -847,7 +848,9 @@ static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
if (ibv_wc_opcode == IBV_WC_RECV) {
ibv_cqx = container_of(base_ep->util_ep.rx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex;
ibv_cqx->start_poll = &efa_mock_ibv_start_poll_return_mock;
ibv_cqx->wr_id = (uintptr_t)12345;
ctx->completion_flags = FI_RECV | FI_MSG;
ctx->addr = 0x12345678;
ibv_cqx->wr_id = (uintptr_t) ctx;
will_return(efa_mock_ibv_start_poll_return_mock, 0);
ibv_cqx->status = status;
} else {
Expand Down Expand Up @@ -894,16 +897,18 @@ void test_efa_cq_read_send_success(struct efa_resource **state)
{
struct efa_resource *resource = *state;
struct efa_unit_test_buff send_buff;
struct fi_context2 ctx;
struct fi_cq_data_entry cq_entry;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0);
test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0,
(struct efa_context *) &ctx);
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);

assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
fi_mr_desc(send_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);

Expand All @@ -924,14 +929,16 @@ void test_efa_cq_read_recv_success(struct efa_resource **state)
struct efa_resource *resource = *state;
struct efa_unit_test_buff recv_buff;
struct fi_cq_data_entry cq_entry;
struct fi_context2 ctx;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0);
test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0,
(struct efa_context *) &ctx);
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);

ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
fi_mr_desc(recv_buff.mr), addr, NULL);
fi_mr_desc(recv_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
Expand Down Expand Up @@ -974,16 +981,17 @@ void test_efa_cq_read_send_failure(struct efa_resource **state)
struct efa_resource *resource = *state;
struct efa_unit_test_buff send_buff;
struct fi_cq_data_entry cq_entry;
struct fi_context2 ctx;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_GENERAL_ERR,
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx);
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);

assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
fi_mr_desc(send_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);

Expand Down Expand Up @@ -1011,15 +1019,16 @@ void test_efa_cq_read_recv_failure(struct efa_resource **state)
struct efa_resource *resource = *state;
struct efa_unit_test_buff recv_buff;
struct fi_cq_data_entry cq_entry;
struct fi_context2 ctx;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_GENERAL_ERR,
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx);
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);

ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
fi_mr_desc(recv_buff.mr), addr, NULL);
fi_mr_desc(recv_buff.mr), addr, &ctx);
assert_int_equal(ret, 0);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
Expand Down

0 comments on commit 46651bd

Please sign in to comment.