From cef74eb5fe1492dd80e2b3f6fe4e2c45744818f8 Mon Sep 17 00:00:00 2001 From: Mikhail Brinskii Date: Wed, 18 Sep 2024 09:28:57 +0300 Subject: [PATCH] UCP: Fix ppln with cuda + minor fixes --- src/ucp/proto/proto_common.c | 23 ++++++++++++----------- src/ucp/proto/proto_common.h | 12 +++++++----- src/ucp/proto/proto_multi.c | 9 +++++---- src/ucp/proto/proto_multi.h | 1 + src/ucp/proto/proto_select.h | 2 -- src/ucp/proto/proto_single.c | 3 ++- src/ucp/rndv/proto_rndv.c | 15 +++++++-------- src/ucp/rndv/rndv_get.c | 9 +++++---- src/ucp/rndv/rndv_put.c | 13 +++++++------ src/ucp/rndv/rndv_rtr.c | 2 +- 10 files changed, 47 insertions(+), 42 deletions(-) diff --git a/src/ucp/proto/proto_common.c b/src/ucp/proto/proto_common.c index b93e01dc1e4..cd4771e8648 100644 --- a/src/ucp/proto/proto_common.c +++ b/src/ucp/proto/proto_common.c @@ -409,7 +409,8 @@ ucp_lane_index_t ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params, uct_ep_operation_t memtype_op, unsigned flags, ptrdiff_t max_iov_offs, size_t min_iov, - ucp_lane_type_t lane_type, uint64_t tl_cap_flags, + ucp_lane_type_t lane_type, + ucs_memory_type_t mem_type, uint64_t tl_cap_flags, ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes) { @@ -494,25 +495,24 @@ ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params, } /* Check memory registration capabilities for zero-copy case */ - if (flags & UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY) { + if (flags & (UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | + UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY)) { if (md_attr->flags & UCT_MD_FLAG_NEED_MEMH) { /* Memory domain must support registration on the relevant * memory type */ - if (!(context->reg_md_map[select_param->mem_type] & - UCS_BIT(md_index))) { + if (!(context->reg_md_map[mem_type] & UCS_BIT(md_index))) { ucs_trace("%s: md %s cannot register %s memory", lane_desc, context->tl_mds[md_index].rsc.md_name, - ucs_memory_type_names[select_param->mem_type]); + ucs_memory_type_names[mem_type]); continue; } - } else if (!(md_attr->access_mem_types & - UCS_BIT(select_param->mem_type))) { + } else if (!(md_attr->access_mem_types & UCS_BIT(mem_type))) { /* * Memory domain which does not require a registration for zero * copy operation must be able to access the relevant memory type */ ucs_trace("%s: no access to mem type %s", lane_desc, - ucs_memory_type_names[select_param->mem_type]); + ucs_memory_type_names[mem_type]); continue; } } @@ -598,8 +598,9 @@ ucp_proto_common_reg_md_map(const ucp_proto_common_init_params_t *params, ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag( const ucp_proto_common_init_params_t *params, ucp_lane_type_t lane_type, - uint64_t tl_cap_flags, ucp_lane_index_t max_lanes, - ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes) + ucs_memory_type_t mem_type, uint64_t tl_cap_flags, + ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map, + ucp_lane_index_t *lanes) { ucp_lane_index_t lane_index, lane, num_lanes, num_valid_lanes; const uct_iface_attr_t *iface_attr; @@ -608,7 +609,7 @@ ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag( num_lanes = ucp_proto_common_find_lanes(¶ms->super, params->memtype_op, params->flags, params->max_iov_offs, params->min_iov, lane_type, - tl_cap_flags, max_lanes, + mem_type, tl_cap_flags, max_lanes, exclude_map, lanes); num_valid_lanes = 0; diff --git a/src/ucp/proto/proto_common.h b/src/ucp/proto/proto_common.h index 6492904530b..10aacc07a6d 100644 --- a/src/ucp/proto/proto_common.h +++ b/src/ucp/proto/proto_common.h @@ -61,7 +61,8 @@ typedef enum { /* Supports starting the request when its datatype iterator offset is > 0 */ UCP_PROTO_COMMON_INIT_FLAG_RESUME = UCS_BIT(10), - UCP_PROTO_COMMON_KEEP_MD_MAP = UCS_BIT(11) + UCP_PROTO_COMMON_INIT_FLAG_KEEP_MD_MAP = UCS_BIT(11), + UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY = UCS_BIT(12) } ucp_proto_common_init_flags_t; @@ -248,16 +249,17 @@ ucp_proto_common_get_lane_perf(const ucp_proto_common_init_params_t *params, /* @return number of lanes found */ ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag( const ucp_proto_common_init_params_t *params, ucp_lane_type_t lane_type, - uint64_t tl_cap_flags, ucp_lane_index_t max_lanes, - ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes); + ucs_memory_type_t mem_type, uint64_t tl_cap_flags, + ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map, + ucp_lane_index_t *lanes); ucp_lane_index_t ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params, uct_ep_operation_t memtype_op, unsigned flags, ptrdiff_t max_iov_offs, size_t min_iov, - ucp_lane_type_t lane_type, uint64_t tl_cap_flags, - ucp_lane_index_t max_lanes, + ucp_lane_type_t lane_type, ucs_memory_type_t mem_type, + uint64_t tl_cap_flags, ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes); diff --git a/src/ucp/proto/proto_multi.c b/src/ucp/proto/proto_multi.c index 0689866c2a3..f0e2cd139ad 100644 --- a/src/ucp/proto/proto_multi.c +++ b/src/ucp/proto/proto_multi.c @@ -22,6 +22,7 @@ ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params, ucp_proto_caps_t *caps, + ucs_memory_type_t mem_type, ucp_proto_multi_priv_t *mpriv) { ucp_context_h context = params->super.super.worker->context; @@ -55,8 +56,8 @@ ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params, /* Find first lane */ num_lanes = ucp_proto_common_find_lanes_with_min_frag( - ¶ms->super, params->first.lane_type, params->first.tl_cap_flags, - 1, 0, lanes); + ¶ms->super, params->first.lane_type, mem_type, + params->first.tl_cap_flags, 1, 0, lanes); if (num_lanes == 0) { ucs_trace("no lanes for %s", ucp_proto_id_field(params->super.super.proto_id, name)); @@ -65,7 +66,7 @@ ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params, /* Find rest of the lanes */ num_lanes += ucp_proto_common_find_lanes_with_min_frag( - ¶ms->super, params->middle.lane_type, + ¶ms->super, params->middle.lane_type, mem_type, params->middle.tl_cap_flags, UCP_PROTO_MAX_LANES - 1, UCS_BIT(lanes[0]), lanes + 1); @@ -264,7 +265,7 @@ void ucp_proto_multi_probe(const ucp_proto_multi_init_params_t *params) ucp_proto_caps_t caps; ucs_status_t status; - status = ucp_proto_multi_init(params, &caps, &mpriv); + status = ucp_proto_multi_init(params, &caps, UCS_MEMORY_TYPE_HOST, &mpriv); if (status != UCS_OK) { return; } diff --git a/src/ucp/proto/proto_multi.h b/src/ucp/proto/proto_multi.h index 8e6dff043c7..828ceeac272 100644 --- a/src/ucp/proto/proto_multi.h +++ b/src/ucp/proto/proto_multi.h @@ -160,6 +160,7 @@ typedef ucs_status_t (*ucp_proto_multi_lane_send_func_t)(ucp_request_t *req, ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params, ucp_proto_caps_t *caps, + ucs_memory_type_t mem_type, ucp_proto_multi_priv_t *mpriv); diff --git a/src/ucp/proto/proto_select.h b/src/ucp/proto/proto_select.h index c3371d4a33f..9441c3c0044 100644 --- a/src/ucp/proto/proto_select.h +++ b/src/ucp/proto/proto_select.h @@ -82,8 +82,6 @@ struct ucp_proto_select_param { uint8_t sys_dev; /* Reply buffer system device */ } UCS_S_PACKED reply; - uint8_t mem_type_flags; - /* Align struct size to uint64_t */ uint8_t padding[2]; diff --git a/src/ucp/proto/proto_single.c b/src/ucp/proto/proto_single.c index ddd754e86a2..07cb74eb7ff 100644 --- a/src/ucp/proto/proto_single.c +++ b/src/ucp/proto/proto_single.c @@ -30,7 +30,8 @@ ucs_status_t ucp_proto_single_init(const ucp_proto_single_init_params_t *params, ucs_status_t status; num_lanes = ucp_proto_common_find_lanes_with_min_frag( - ¶ms->super, params->lane_type, params->tl_cap_flags, 1, + ¶ms->super, params->lane_type, + params->super.super.select_param->mem_type, params->tl_cap_flags, 1, params->super.exclude_map, &lane); if (num_lanes == 0) { ucs_trace("no lanes for %s", diff --git a/src/ucp/rndv/proto_rndv.c b/src/ucp/rndv/proto_rndv.c index bdf3b37288c..524fedcc715 100644 --- a/src/ucp/rndv/proto_rndv.c +++ b/src/ucp/rndv/proto_rndv.c @@ -269,7 +269,7 @@ ucp_proto_rndv_ctrl_init_priv(const ucp_proto_rndv_ctrl_init_params_t *params, /* Use only memory domains for which the unpacking of the remote key was * successful */ if ((init_params->rkey_config_key != NULL) && - !(params->super.flags & UCP_PROTO_COMMON_KEEP_MD_MAP)) { + !(params->super.flags & UCP_PROTO_COMMON_INIT_FLAG_KEEP_MD_MAP)) { rpriv->md_map &= ~init_params->rkey_config_key->unreachable_md_map; } @@ -519,12 +519,10 @@ ucp_proto_rndv_find_ctrl_lane(const ucp_proto_init_params_t *params) { ucp_lane_index_t lane, num_lanes; - num_lanes = ucp_proto_common_find_lanes(params, UCT_EP_OP_LAST, - UCP_PROTO_COMMON_INIT_FLAG_HDR_ONLY, - UCP_PROTO_COMMON_OFFSET_INVALID, 1, - UCP_LANE_TYPE_AM, - UCT_IFACE_FLAG_AM_BCOPY, 1, 0, - &lane); + num_lanes = ucp_proto_common_find_lanes( + params, UCT_EP_OP_LAST, UCP_PROTO_COMMON_INIT_FLAG_HDR_ONLY, + UCP_PROTO_COMMON_OFFSET_INVALID, 1, UCP_LANE_TYPE_AM, + UCS_MEMORY_TYPE_HOST, UCT_IFACE_FLAG_AM_BCOPY, 1, 0, &lane); if (num_lanes == 0) { ucs_debug("no active message lane for %s", ucp_proto_id_field(params->proto_id, name)); @@ -709,7 +707,8 @@ ucp_proto_rndv_bulk_init(const ucp_proto_multi_init_params_t *init_params, ucp_proto_caps_t multi_caps; ucs_status_t status; - status = ucp_proto_multi_init(init_params, &multi_caps, mpriv); + status = ucp_proto_multi_init(init_params, &multi_caps, + rpriv->frag_mem_type, mpriv); if (status != UCS_OK) { return status; } diff --git a/src/ucp/rndv/rndv_get.c b/src/ucp/rndv/rndv_get.c index fc850c30095..1ec5d331b29 100644 --- a/src/ucp/rndv/rndv_get.c +++ b/src/ucp/rndv/rndv_get.c @@ -71,13 +71,13 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params, return; } + rpriv.frag_mem_type = frag_mem_type; status = ucp_proto_rndv_bulk_init(¶ms, UCP_PROTO_RNDV_GET_DESC, UCP_PROTO_RNDV_ATS_NAME, &rpriv, &caps); if (status != UCS_OK) { return; } - rpriv.frag_mem_type = frag_mem_type; priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, mpriv); ucp_proto_common_add_proto(¶ms.super, &caps, &rpriv, priv_size); } @@ -127,7 +127,7 @@ ucp_proto_rndv_get_zcopy_probe(const ucp_proto_init_params_t *init_params) UCT_EP_OP_LAST, UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, - 0, 0, UCS_MEMORY_TYPE_HOST); + 0, 0, init_params->select_param->mem_type); } static void @@ -322,12 +322,13 @@ ucp_proto_rndv_get_mtype_probe(const ucp_proto_init_params_t *init_params) status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type, &mdesc_md_map, &frag_size); if (status != UCS_OK) { - return; + continue; } ucp_proto_rndv_get_common_probe(init_params, UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE), - frag_size, UCT_EP_OP_PUT_ZCOPY, 0, + frag_size, UCT_EP_OP_PUT_ZCOPY, + UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY, mdesc_md_map, 1, frag_mem_type); } } diff --git a/src/ucp/rndv/rndv_put.c b/src/ucp/rndv/rndv_put.c index 4fddf4a8216..639a77885e0 100644 --- a/src/ucp/rndv/rndv_put.c +++ b/src/ucp/rndv/rndv_put.c @@ -278,6 +278,7 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params, return; } + rpriv.bulk.frag_mem_type = frag_mem_type; status = ucp_proto_rndv_bulk_init(¶ms, UCP_PROTO_RNDV_PUT_DESC, UCP_PROTO_RNDV_ATP_NAME, &rpriv.bulk, &caps); @@ -346,9 +347,8 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params, if (send_atp) { ucs_assert(rpriv.atp_map != 0); } - rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map); - rpriv.stat_counter = stat_counter; - rpriv.bulk.frag_mem_type = frag_mem_type; + rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map); + rpriv.stat_counter = stat_counter; priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, bulk.mpriv); ucp_proto_common_add_proto(¶ms.super, &caps, &rpriv, priv_size); @@ -419,7 +419,7 @@ ucp_proto_rndv_put_zcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, 0, ucp_proto_rndv_put_zcopy_completion, 0, - UCP_WORKER_STAT_RNDV_PUT_ZCOPY, UCS_MEMORY_TYPE_HOST); + UCP_WORKER_STAT_RNDV_PUT_ZCOPY, init_params->select_param->mem_type); } static void @@ -592,8 +592,9 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params) ucp_proto_rndv_put_common_probe( init_params, UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE), frag_size, - UCT_EP_OP_GET_ZCOPY, 0, mdesc_md_map, comp_cb, 1, - UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, frag_mem_type); + UCT_EP_OP_GET_ZCOPY, UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY, + mdesc_md_map, comp_cb, 1, UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, + frag_mem_type); } static void diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index b64bac5f25d..d6508a9a4cb 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -389,7 +389,7 @@ ucp_proto_rndv_rtr_mtype_probe(const ucp_proto_init_params_t *init_params) .super.memtype_op = UCT_EP_OP_LAST, .super.flags = UCP_PROTO_COMMON_INIT_FLAG_RESPONSE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING | - UCP_PROTO_COMMON_KEEP_MD_MAP, + UCP_PROTO_COMMON_INIT_FLAG_KEEP_MD_MAP, .super.exclude_map = 0, .remote_op_id = UCP_OP_ID_RNDV_SEND, .lane = ucp_proto_rndv_find_ctrl_lane(init_params),