Skip to content

Commit

Permalink
prov/efa: Implement FI_CONTEXT2 in EFA Direct
Browse files Browse the repository at this point in the history
Store the completion flags and peer address in FI_CONTEXT2 and
retrieve them later when writing the CQ entry.

Signed-off-by: Jessie Yang <[email protected]>
  • Loading branch information
jiaxiyan committed Jan 17, 2025
1 parent a93beca commit d60a366
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 17 deletions.
27 changes: 27 additions & 0 deletions prov/efa/src/efa.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,33 @@ struct efa_fabric {
#endif
};

/*
 * Per-operation state that get_efa_context() writes into the context
 * pointer supplied with a posted operation. The CQ code later reads it
 * back through the work request id (wr_id).
 */
struct efa_context {
/* Flags reported in the CQ entry when wr_id is non-zero. */
uint64_t completion_flags;
/* Address given when the operation was posted; used as the address for TX errors. */
fi_addr_t addr;
};

/*
 * get_efa_context() casts the caller's context pointer to
 * struct efa_context, so that struct must fit inside struct fi_context2.
 * The check is compiled only when static_assert is defined as a macro and
 * the target is x86_64; on other builds it is silently omitted.
 */
#if defined(static_assert) && defined(__x86_64__)
static_assert(sizeof(struct efa_context) <= sizeof(struct fi_context2),
"efa_context must not be larger than fi_context2");
#endif

/**
 * @brief Stash completion metadata in the caller's context and return it
 *        as a work request id.
 *
 * When a context is supplied and FI_COMPLETION is set in @p flags, the
 * memory behind @p context is treated as a struct efa_context and filled
 * with @p addr and @p completion_flags. Note that this writes through the
 * pointer even though the parameter is const-qualified.
 *
 * @param[in,out] context           context pointer from the posted operation;
 *                                  may be NULL
 * @param[in]     addr              address to store in the context
 * @param[in]     flags             operation flags; only FI_COMPLETION is
 *                                  examined here
 * @param[in]     completion_flags  flags to store for the later CQ entry
 *
 * @return @p context converted to uintptr_t when the metadata was stored,
 *         otherwise (uintptr_t) NULL and nothing is written.
 */
static inline uintptr_t get_efa_context(const void *context,
					const fi_addr_t addr,
					const uint64_t flags,
					uint64_t completion_flags)
{
	struct efa_context *ctx = (struct efa_context *) context;

	if (ctx != NULL && (flags & FI_COMPLETION) != 0) {
		ctx->addr = addr;
		ctx->completion_flags = completion_flags;
		return (uintptr_t) ctx;
	}

	/* No context given or no completion requested: leave wr_id empty. */
	return (uintptr_t) NULL;
}

static inline
int efa_str_to_ep_addr(const char *node, const char *service, struct efa_ep_addr *addr)
{
Expand Down
8 changes: 5 additions & 3 deletions prov/efa/src/efa_cq.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ static void efa_cq_construct_cq_entry(struct ibv_cq_ex *ibv_cqx,
struct fi_cq_tagged_entry *entry)
{
entry->op_context = (void *)ibv_cqx->wr_id;
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
if (!ibv_cqx->wr_id)
entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx));
else
entry->flags = ((struct efa_context *) ibv_cqx->wr_id)->completion_flags;
entry->len = ibv_wc_read_byte_len(ibv_cqx);
entry->buf = NULL;
entry->data = 0;
Expand Down Expand Up @@ -81,8 +84,7 @@ static void efa_cq_handle_error(struct efa_base_ep *base_ep,
err_entry.prov_errno = prov_errno;

if (is_tx)
// TODO: get correct peer addr for TX operation
addr = FI_ADDR_NOTAVAIL;
addr = ibv_cq_ex->wr_id ? ((struct efa_context *)ibv_cq_ex->wr_id)->addr : FI_ADDR_NOTAVAIL;
else
addr = efa_av_reverse_lookup(base_ep->av,
ibv_wc_read_slid(ibv_cq_ex),
Expand Down
4 changes: 2 additions & 2 deletions prov/efa/src/efa_msg.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi
wr = &base_ep->efa_recv_wr_vec[wr_index].wr;
wr->num_sge = msg->iov_count;
wr->sg_list = base_ep->efa_recv_wr_vec[wr_index].sge;
wr->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
wr->wr_id = get_efa_context(msg->context, msg->addr, flags, FI_RECV | FI_MSG);

for (i = 0; i < msg->iov_count; i++) {
addr = (uintptr_t)msg->msg_iov[i].iov_base;
Expand Down Expand Up @@ -224,7 +224,7 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi
base_ep->is_wr_started = true;
}

qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);
qp->ibv_qp_ex->wr_id = get_efa_context(msg->context, msg->addr, flags, FI_SEND | FI_MSG);

if (flags & FI_REMOTE_CQ_DATA) {
ibv_wr_send_imm(qp->ibv_qp_ex, msg->data);
Expand Down
10 changes: 8 additions & 2 deletions prov/efa/src/efa_rma.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep,
ibv_wr_start(qp->ibv_qp_ex);
base_ep->is_wr_started = true;
}
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);

qp->ibv_qp_ex->wr_id = get_efa_context(msg->context, msg->addr, flags, FI_RMA | FI_READ);

/* ep->domain->info->tx_attr->rma_iov_limit is set to 1 */
ibv_wr_rdma_read(qp->ibv_qp_ex, msg->rma_iov[0].key, msg->rma_iov[0].addr);
Expand Down Expand Up @@ -225,7 +226,12 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep,
ibv_wr_start(qp->ibv_qp_ex);
base_ep->is_wr_started = true;
}
qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL);

qp->ibv_qp_ex->wr_id = get_efa_context(
msg->context, msg->addr, flags,
flags & FI_REMOTE_CQ_DATA ?
FI_REMOTE_CQ_DATA | FI_RMA | FI_REMOTE_WRITE :
FI_RMA | FI_WRITE);

if (flags & FI_REMOTE_CQ_DATA) {
ibv_wr_rdma_write_imm(qp->ibv_qp_ex, msg->rma_iov[0].key,
Expand Down
39 changes: 29 additions & 10 deletions prov/efa/test/efa_unit_test_cq.c
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,8 @@ void test_ibv_cq_ex_read_ignore_removed_peer()
#endif

static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
int ibv_wc_opcode, int status, int vendor_error)
int ibv_wc_opcode, int status, int vendor_error,
struct efa_context *ctx)
{
int ret;
size_t raw_addr_len = sizeof(struct efa_ep_addr);
Expand Down Expand Up @@ -847,7 +848,9 @@ static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr,
if (ibv_wc_opcode == IBV_WC_RECV) {
ibv_cqx = container_of(base_ep->util_ep.rx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex;
ibv_cqx->start_poll = &efa_mock_ibv_start_poll_return_mock;
ibv_cqx->wr_id = (uintptr_t)12345;
ctx->completion_flags = FI_RECV | FI_MSG;
ctx->addr = 0x12345678;
ibv_cqx->wr_id = (uintptr_t) ctx;
will_return(efa_mock_ibv_start_poll_return_mock, 0);
ibv_cqx->status = status;
} else {
Expand Down Expand Up @@ -894,16 +897,19 @@ void test_efa_cq_read_send_success(struct efa_resource **state)
{
struct efa_resource *resource = *state;
struct efa_unit_test_buff send_buff;
struct fi_context2 *ctx;
struct fi_cq_data_entry cq_entry;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0);
ctx = malloc(sizeof(struct fi_context2));
assert_non_null(ctx);
test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0, (struct efa_context *)ctx);
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);

assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
fi_mr_desc(send_buff.mr), addr, ctx);
assert_int_equal(ret, 0);
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);

Expand All @@ -913,6 +919,7 @@ void test_efa_cq_read_send_success(struct efa_resource **state)
assert_int_equal(ret, 1);

efa_unit_test_buff_destruct(&send_buff);
free(ctx);
}

/**
Expand All @@ -924,20 +931,24 @@ void test_efa_cq_read_recv_success(struct efa_resource **state)
struct efa_resource *resource = *state;
struct efa_unit_test_buff recv_buff;
struct fi_cq_data_entry cq_entry;
struct fi_context2 *ctx;
fi_addr_t addr;
int ret;

test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0);
ctx = malloc(sizeof(struct fi_context2));
assert_non_null(ctx);
test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0, (struct efa_context *)ctx);
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);

ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
fi_mr_desc(recv_buff.mr), addr, NULL);
fi_mr_desc(recv_buff.mr), addr, ctx);
assert_int_equal(ret, 0);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
assert_int_equal(ret, 1);

efa_unit_test_buff_destruct(&recv_buff);
free(ctx);
}

static void efa_cq_check_cq_err_entry(struct efa_resource *resource, int vendor_error) {
Expand Down Expand Up @@ -974,16 +985,19 @@ void test_efa_cq_read_send_failure(struct efa_resource **state)
struct efa_resource *resource = *state;
struct efa_unit_test_buff send_buff;
struct fi_cq_data_entry cq_entry;
struct fi_context2 *ctx;
fi_addr_t addr;
int ret;

ctx = malloc(sizeof(struct fi_context2));
assert_non_null(ctx);
test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_GENERAL_ERR,
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *)ctx);
efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */);

assert_int_equal(g_ibv_submitted_wr_id_cnt, 0);
ret = fi_send(resource->ep, send_buff.buff, send_buff.size,
fi_mr_desc(send_buff.mr), addr, (void *) 12345);
fi_mr_desc(send_buff.mr), addr, ctx);
assert_int_equal(ret, 0);
assert_int_equal(g_ibv_submitted_wr_id_cnt, 1);

Expand All @@ -996,6 +1010,7 @@ void test_efa_cq_read_send_failure(struct efa_resource **state)
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);

efa_unit_test_buff_destruct(&send_buff);
free(ctx);
}

/**
Expand All @@ -1011,15 +1026,18 @@ void test_efa_cq_read_recv_failure(struct efa_resource **state)
struct efa_resource *resource = *state;
struct efa_unit_test_buff recv_buff;
struct fi_cq_data_entry cq_entry;
struct fi_context2 *ctx;
fi_addr_t addr;
int ret;

ctx = malloc(sizeof(struct fi_context2));
assert_non_null(ctx);
test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_GENERAL_ERR,
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *)ctx);
efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */);

ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size,
fi_mr_desc(recv_buff.mr), addr, NULL);
fi_mr_desc(recv_buff.mr), addr, ctx);
assert_int_equal(ret, 0);

ret = fi_cq_read(resource->cq, &cq_entry, 1);
Expand All @@ -1029,4 +1047,5 @@ void test_efa_cq_read_recv_failure(struct efa_resource **state)
EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE);

efa_unit_test_buff_destruct(&recv_buff);
free(ctx);
}

0 comments on commit d60a366

Please sign in to comment.