diff --git a/prov/efa/src/efa.h b/prov/efa/src/efa.h index 4d8e982355c..e1cf716cb05 100644 --- a/prov/efa/src/efa.h +++ b/prov/efa/src/efa.h @@ -107,6 +107,41 @@ struct efa_fabric { #endif }; +struct efa_context { + uint64_t completion_flags; + fi_addr_t addr; +}; + +#if defined(static_assert) +static_assert(sizeof(struct efa_context) <= sizeof(struct fi_context2), + "efa_context must not be larger than fi_context2"); +#endif + +/** + * Prepare and return a pointer to an EFA context structure. + * + * @param context Pointer to the msg context. + * @param addr Peer address associated with the operation. + * @param flags Operation flags (e.g., FI_COMPLETION). + * @param completion_flags Completion flags reported in the cq entry. + * @return A pointer to an initialized EFA context structure, + * or NULL if context is invalid or FI_COMPLETION is not set. + */ +static inline struct efa_context *efa_fill_context(const void *context, + fi_addr_t addr, + uint64_t flags, + uint64_t completion_flags) +{ + if (!context || !(flags & FI_COMPLETION)) + return NULL; + + struct efa_context *efa_context = (struct efa_context *) context; + efa_context->completion_flags = completion_flags; + efa_context->addr = addr; + + return efa_context; +} + static inline int efa_str_to_ep_addr(const char *node, const char *service, struct efa_ep_addr *addr) { diff --git a/prov/efa/src/efa_cq.c b/prov/efa/src/efa_cq.c index a5b737d89ac..eeffe60cf3f 100644 --- a/prov/efa/src/efa_cq.c +++ b/prov/efa/src/efa_cq.c @@ -36,7 +36,10 @@ static void efa_cq_construct_cq_entry(struct ibv_cq_ex *ibv_cqx, struct fi_cq_tagged_entry *entry) { entry->op_context = (void *)ibv_cqx->wr_id; - entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx)); + if (ibv_cqx->wr_id) + entry->flags = ((struct efa_context *) ibv_cqx->wr_id)->completion_flags; + else + entry->flags = efa_cq_opcode_to_fi_flags(ibv_wc_read_opcode(ibv_cqx)); entry->len = ibv_wc_read_byte_len(ibv_cqx); entry->buf = NULL; entry->data = 0; @@ -81,8 +84,7 @@ static void efa_cq_handle_error(struct efa_base_ep *base_ep, err_entry.prov_errno = prov_errno; if (is_tx) - // TODO: get correct peer addr for TX operation - addr = FI_ADDR_NOTAVAIL; + addr = ibv_cq_ex->wr_id ? ((struct efa_context *)ibv_cq_ex->wr_id)->addr : FI_ADDR_NOTAVAIL; else addr = efa_av_reverse_lookup(base_ep->av, ibv_wc_read_slid(ibv_cq_ex), diff --git a/prov/efa/src/efa_msg.c b/prov/efa/src/efa_msg.c index c2af757e112..5d5768c8ff1 100644 --- a/prov/efa/src/efa_msg.c +++ b/prov/efa/src/efa_msg.c @@ -101,7 +101,8 @@ static inline ssize_t efa_post_recv(struct efa_base_ep *base_ep, const struct fi wr = &base_ep->efa_recv_wr_vec[wr_index].wr; wr->num_sge = msg->iov_count; wr->sg_list = base_ep->efa_recv_wr_vec[wr_index].sge; - wr->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + wr->wr_id = (uintptr_t) efa_fill_context(msg->context, msg->addr, flags, + FI_RECV | FI_MSG); for (i = 0; i < msg->iov_count; i++) { addr = (uintptr_t)msg->msg_iov[i].iov_base; @@ -224,7 +225,8 @@ static inline ssize_t efa_post_send(struct efa_base_ep *base_ep, const struct fi base_ep->is_wr_started = true; } - qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context( + msg->context, msg->addr, flags, FI_SEND | FI_MSG); if (flags & FI_REMOTE_CQ_DATA) { ibv_wr_send_imm(qp->ibv_qp_ex, msg->data); diff --git a/prov/efa/src/efa_rma.c b/prov/efa/src/efa_rma.c index 8fee3a2021b..da33b44350f 100644 --- a/prov/efa/src/efa_rma.c +++ b/prov/efa/src/efa_rma.c @@ -90,7 +90,9 @@ static inline ssize_t efa_rma_post_read(struct efa_base_ep *base_ep, ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; } - qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + + qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context( + msg->context, msg->addr, flags, FI_RMA | FI_READ); /* ep->domain->info->tx_attr->rma_iov_limit is set to 1 */ ibv_wr_rdma_read(qp->ibv_qp_ex, msg->rma_iov[0].key, msg->rma_iov[0].addr); @@ -225,7 +227,9 @@ static inline ssize_t efa_rma_post_write(struct efa_base_ep *base_ep, ibv_wr_start(qp->ibv_qp_ex); base_ep->is_wr_started = true; } - qp->ibv_qp_ex->wr_id = (uintptr_t) ((flags & FI_COMPLETION) ? msg->context : NULL); + + qp->ibv_qp_ex->wr_id = (uintptr_t) efa_fill_context( + msg->context, msg->addr, flags, FI_RMA | FI_WRITE); if (flags & FI_REMOTE_CQ_DATA) { ibv_wr_rdma_write_imm(qp->ibv_qp_ex, msg->rma_iov[0].key, diff --git a/prov/efa/test/efa_unit_test_cq.c b/prov/efa/test/efa_unit_test_cq.c index e69fb8b432e..d793a66715b 100644 --- a/prov/efa/test/efa_unit_test_cq.c +++ b/prov/efa/test/efa_unit_test_cq.c @@ -813,7 +813,8 @@ void test_ibv_cq_ex_read_ignore_removed_peer() #endif static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr, - int ibv_wc_opcode, int status, int vendor_error) + int ibv_wc_opcode, int status, int vendor_error, + struct efa_context *ctx) { int ret; size_t raw_addr_len = sizeof(struct efa_ep_addr); @@ -847,7 +848,9 @@ static void test_efa_cq_read(struct efa_resource *resource, fi_addr_t *addr, if (ibv_wc_opcode == IBV_WC_RECV) { ibv_cqx = container_of(base_ep->util_ep.rx_cq, struct efa_cq, util_cq)->ibv_cq.ibv_cq_ex; ibv_cqx->start_poll = &efa_mock_ibv_start_poll_return_mock; - ibv_cqx->wr_id = (uintptr_t)12345; + ctx->completion_flags = FI_RECV | FI_MSG; + ctx->addr = 0x12345678; + ibv_cqx->wr_id = (uintptr_t) ctx; will_return(efa_mock_ibv_start_poll_return_mock, 0); ibv_cqx->status = status; } else { @@ -894,16 +897,18 @@ void test_efa_cq_read_send_success(struct efa_resource **state) { struct efa_resource *resource = *state; struct efa_unit_test_buff send_buff; + struct fi_context2 ctx; struct fi_cq_data_entry cq_entry; fi_addr_t addr; int ret; - test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0); + test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_SUCCESS, 0, + (struct efa_context *) &ctx); efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */); assert_int_equal(g_ibv_submitted_wr_id_cnt, 0); ret = fi_send(resource->ep, send_buff.buff, send_buff.size, - fi_mr_desc(send_buff.mr), addr, (void *) 12345); + fi_mr_desc(send_buff.mr), addr, &ctx); assert_int_equal(ret, 0); assert_int_equal(g_ibv_submitted_wr_id_cnt, 1); @@ -924,14 +929,16 @@ void test_efa_cq_read_recv_success(struct efa_resource **state) struct efa_resource *resource = *state; struct efa_unit_test_buff recv_buff; struct fi_cq_data_entry cq_entry; + struct fi_context2 ctx; fi_addr_t addr; int ret; - test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0); + test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_SUCCESS, 0, + (struct efa_context *) &ctx); efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */); ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size, - fi_mr_desc(recv_buff.mr), addr, NULL); + fi_mr_desc(recv_buff.mr), addr, &ctx); assert_int_equal(ret, 0); ret = fi_cq_read(resource->cq, &cq_entry, 1); @@ -974,16 +981,17 @@ void test_efa_cq_read_send_failure(struct efa_resource **state) struct efa_resource *resource = *state; struct efa_unit_test_buff send_buff; struct fi_cq_data_entry cq_entry; + struct fi_context2 ctx; fi_addr_t addr; int ret; test_efa_cq_read(resource, &addr, IBV_WC_SEND, IBV_WC_GENERAL_ERR, - EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE); + EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx); efa_unit_test_buff_construct(&send_buff, resource, 4096 /* buff_size */); assert_int_equal(g_ibv_submitted_wr_id_cnt, 0); ret = fi_send(resource->ep, send_buff.buff, send_buff.size, - fi_mr_desc(send_buff.mr), addr, (void *) 12345); + fi_mr_desc(send_buff.mr), addr, &ctx); assert_int_equal(ret, 0); assert_int_equal(g_ibv_submitted_wr_id_cnt, 1); @@ -1011,15 +1019,16 @@ void test_efa_cq_read_recv_failure(struct efa_resource **state) struct efa_resource *resource = *state; struct efa_unit_test_buff recv_buff; struct fi_cq_data_entry cq_entry; + struct fi_context2 ctx; fi_addr_t addr; int ret; test_efa_cq_read(resource, &addr, IBV_WC_RECV, IBV_WC_GENERAL_ERR, - EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE); + EFA_IO_COMP_STATUS_LOCAL_ERROR_UNRESP_REMOTE, (struct efa_context *) &ctx); efa_unit_test_buff_construct(&recv_buff, resource, 4096 /* buff_size */); ret = fi_recv(resource->ep, recv_buff.buff, recv_buff.size, - fi_mr_desc(recv_buff.mr), addr, NULL); + fi_mr_desc(recv_buff.mr), addr, &ctx); assert_int_equal(ret, 0); ret = fi_cq_read(resource->cq, &cq_entry, 1);