diff --git a/src/ucp/am/eager_multi.c b/src/ucp/am/eager_multi.c index ee40787a8ed..f2279313009 100644 --- a/src/ucp/am/eager_multi.c +++ b/src/ucp/am/eager_multi.c @@ -37,6 +37,7 @@ ucp_am_eager_multi_bcopy_proto_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING | UCP_PROTO_COMMON_INIT_FLAG_RESUME, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .max_lanes = context->config.ext.max_eager_lanes, .initial_reg_md_map = 0, .first.lane_type = UCP_LANE_TYPE_AM, @@ -197,6 +198,7 @@ ucp_am_eager_multi_zcopy_proto_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .max_lanes = context->config.ext.max_eager_lanes, .initial_reg_md_map = 0, .opt_align_offs = UCP_PROTO_COMMON_OFFSET_INVALID, diff --git a/src/ucp/am/eager_single.c b/src/ucp/am/eager_single.c index 68bec0725ab..1eff688ade6 100644 --- a/src/ucp/am/eager_single.c +++ b/src/ucp/am/eager_single.c @@ -111,6 +111,7 @@ ucp_am_eager_short_probe_common(const ucp_proto_init_params_t *init_params, UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_AM, .tl_cap_flags = UCT_IFACE_FLAG_AM_SHORT }; @@ -240,6 +241,7 @@ static void ucp_am_eager_single_bcopy_probe_common( UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_AM, .tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY }; @@ -330,6 +332,7 @@ static void ucp_am_eager_single_zcopy_probe_common( UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .lane_type = UCP_LANE_TYPE_AM, 
.tl_cap_flags = UCT_IFACE_FLAG_AM_ZCOPY }; diff --git a/src/ucp/core/ucp_context.c b/src/ucp/core/ucp_context.c index d5f5b173f3f..ba0f4fba7f2 100644 --- a/src/ucp/core/ucp_context.c +++ b/src/ucp/core/ucp_context.c @@ -340,11 +340,15 @@ static ucs_config_field_t ucp_context_config_table[] = { "and the resulting performance.", ucs_offsetof(ucp_context_config_t, estimated_num_ppn), UCS_CONFIG_TYPE_ULUNITS}, - {"RNDV_FRAG_MEM_TYPE", "host", - "Memory type of fragments used for RNDV pipeline protocol.\n" - "Allowed memory types is one of: host, cuda, rocm, ze-host, ze-device", - ucs_offsetof(ucp_context_config_t, rndv_frag_mem_type), - UCS_CONFIG_TYPE_ENUM(ucs_memory_type_names)}, + {"RNDV_FRAG_MEM_TYPE", NULL, "", + ucs_offsetof(ucp_context_config_t, rndv_frag_mem_types), + UCS_CONFIG_TYPE_BITMAP(ucs_memory_type_names)}, + + {"RNDV_FRAG_MEM_TYPES", "host,cuda", + "Memory types of fragments used for RNDV pipeline protocol.\n" + "Allowed memory types are: host, cuda, rocm, ze-host, ze-device", + ucs_offsetof(ucp_context_config_t, rndv_frag_mem_types), + UCS_CONFIG_TYPE_BITMAP(ucs_memory_type_names)}, {"RNDV_PIPELINE_SEND_THRESH", "inf", "RNDV size threshold to enable sender side pipeline for mem type", diff --git a/src/ucp/core/ucp_context.h b/src/ucp/core/ucp_context.h index feb70b99fdc..3fd39d03a78 100644 --- a/src/ucp/core/ucp_context.h +++ b/src/ucp/core/ucp_context.h @@ -78,8 +78,8 @@ typedef struct ucp_context_config { size_t rndv_frag_size[UCS_MEMORY_TYPE_LAST]; /** Number of RNDV pipeline fragments per allocation */ size_t rndv_num_frags[UCS_MEMORY_TYPE_LAST]; - /** Memory type of fragments used for RNDV pipeline protocol */ - ucs_memory_type_t rndv_frag_mem_type; + /** Memory types of fragments used for RNDV pipeline protocol */ + uint64_t rndv_frag_mem_types; /** RNDV pipeline send threshold */ size_t rndv_pipeline_send_thresh; /** Enabling 2-stage pipeline rndv protocol */ diff --git a/src/ucp/proto/proto_common.c b/src/ucp/proto/proto_common.c index 
e8d611d8491..8dd69612cf0 100644 --- a/src/ucp/proto/proto_common.c +++ b/src/ucp/proto/proto_common.c @@ -433,8 +433,9 @@ ucp_lane_index_t ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params, uct_ep_operation_t memtype_op, unsigned flags, ptrdiff_t max_iov_offs, size_t min_iov, - ucp_lane_type_t lane_type, uint64_t tl_cap_flags, - ucp_lane_index_t max_lanes, + ucp_lane_type_t lane_type, + ucs_memory_type_t reg_mem_type, + uint64_t tl_cap_flags, ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes) { UCS_STRING_BUFFER_ONSTACK(sel_param_strb, UCP_PROTO_SELECT_PARAM_STR_MAX); @@ -457,7 +458,7 @@ ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params, } ucp_proto_select_info_str(params->worker, params->rkey_cfg_index, - params->select_param, ucp_operation_names, + select_param, ucp_operation_names, &sel_param_strb); num_lanes = 0; @@ -518,25 +519,29 @@ ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params, } /* Check memory registration capabilities for zero-copy case */ - if (flags & UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY) { + if (reg_mem_type != UCS_MEMORY_TYPE_UNKNOWN) { + ucs_assertv((reg_mem_type == select_param->mem_type) || + !(flags & UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY), + "flags=0x%x reg_mem_type=%s select_param->mem_type=%s", + flags, ucs_memory_type_names[reg_mem_type], + ucs_memory_type_names[select_param->mem_type]); + if (md_attr->flags & UCT_MD_FLAG_NEED_MEMH) { /* Memory domain must support registration on the relevant * memory type */ - if (!(context->reg_md_map[select_param->mem_type] & - UCS_BIT(md_index))) { + if (!(context->reg_md_map[reg_mem_type] & UCS_BIT(md_index))) { ucs_trace("%s: md %s cannot register %s memory", lane_desc, context->tl_mds[md_index].rsc.md_name, - ucs_memory_type_names[select_param->mem_type]); + ucs_memory_type_names[reg_mem_type]); continue; } - } else if (!(md_attr->access_mem_types & - UCS_BIT(select_param->mem_type))) { + } else if 
(!(md_attr->access_mem_types & UCS_BIT(reg_mem_type))) { /* * Memory domain which does not require a registration for zero * copy operation must be able to access the relevant memory type */ ucs_trace("%s: no access to mem type %s", lane_desc, - ucs_memory_type_names[select_param->mem_type]); + ucs_memory_type_names[reg_mem_type]); continue; } } @@ -629,11 +634,11 @@ ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag( const uct_iface_attr_t *iface_attr; size_t tl_min_frag, tl_max_frag; - num_lanes = ucp_proto_common_find_lanes(&params->super, params->memtype_op, - params->flags, params->max_iov_offs, - params->min_iov, lane_type, - tl_cap_flags, max_lanes, - exclude_map, lanes); + num_lanes = ucp_proto_common_find_lanes( + &params->super, params->memtype_op, params->flags, + params->max_iov_offs, params->min_iov, lane_type, + params->reg_mem_type, tl_cap_flags, max_lanes, exclude_map, + lanes); num_valid_lanes = 0; for (lane_index = 0; lane_index < num_lanes; ++lane_index) { diff --git a/src/ucp/proto/proto_common.h b/src/ucp/proto/proto_common.h index 83bb44d6eef..5e72c6c3af0 100644 --- a/src/ucp/proto/proto_common.h +++ b/src/ucp/proto/proto_common.h @@ -61,6 +61,7 @@ typedef enum { /* Supports starting the request when its datatype iterator offset is > 0 */ UCP_PROTO_COMMON_INIT_FLAG_RESUME = UCS_BIT(10), + UCP_PROTO_COMMON_KEEP_MD_MAP = UCS_BIT(11) } ucp_proto_common_init_flags_t; @@ -120,6 +121,13 @@ typedef struct { /* Map of unsuitable lanes */ ucp_lane_map_t exclude_map; + + /* Memory type of the buffer used for data transfer on the transport level. + * If UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY flag is set, it is expected to + * be the user buffer memory type. Alternatively, it refers to the type of + * memory used for bounce buffers (either in the UCP or UCT layer) where + * data needs to be copied as part of the protocol. 
*/ + ucs_memory_type_t reg_mem_type; } ucp_proto_common_init_params_t; @@ -259,8 +267,9 @@ ucp_lane_index_t ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params, uct_ep_operation_t memtype_op, unsigned flags, ptrdiff_t max_iov_offs, size_t min_iov, - ucp_lane_type_t lane_type, uint64_t tl_cap_flags, - ucp_lane_index_t max_lanes, + ucp_lane_type_t lane_type, + ucs_memory_type_t reg_mem_type, + uint64_t tl_cap_flags, ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes); diff --git a/src/ucp/proto/proto_init.c b/src/ucp/proto/proto_init.c index 7dfa291bdb5..077d26b04c8 100644 --- a/src/ucp/proto/proto_init.c +++ b/src/ucp/proto/proto_init.c @@ -376,6 +376,7 @@ ucp_proto_init_add_buffer_perf(const ucp_proto_common_init_params_t *params, ucp_md_map_t reg_md_map, ucp_proto_perf_t *perf) { const ucp_proto_select_param_t *select_param = params->super.select_param; + ucs_memory_type_t buffer_mem_type; ucs_memory_type_t recv_mem_type; uint32_t op_attr_mask; ucs_status_t status; @@ -390,8 +391,17 @@ ucp_proto_init_add_buffer_perf(const ucp_proto_common_init_params_t *params, } } else if (!(params->flags & UCP_PROTO_COMMON_INIT_FLAG_RKEY_PTR)) { ucs_assert(reg_md_map == 0); + + /* TODO: This mem_type initialization is specific to put and get mtype + * protocols. Consider moving it to the corresponding probe functions. 
+ */ + if (params->reg_mem_type != UCS_MEMORY_TYPE_UNKNOWN) { + buffer_mem_type = params->reg_mem_type; + } else { + buffer_mem_type = UCS_MEMORY_TYPE_HOST; + } status = ucp_proto_init_add_buffer_copy_time( - params->super.worker, "local copy", UCS_MEMORY_TYPE_HOST, + params->super.worker, "local copy", buffer_mem_type, select_param->mem_type, params->memtype_op, range_start, range_end, 1, perf); if (status != UCS_OK) { @@ -423,11 +433,8 @@ ucp_proto_init_add_buffer_perf(const ucp_proto_common_init_params_t *params, params->super.worker, "remote copy", UCS_MEMORY_TYPE_HOST, recv_mem_type, UCT_EP_OP_PUT_SHORT, range_start, range_end, 0, perf); - if (status != UCS_OK) { - return status; - } - return UCS_OK; + return status; } static int diff --git a/src/ucp/rma/amo_offload.c b/src/ucp/rma/amo_offload.c index b35c36d8c6c..76a478b2a7c 100644 --- a/src/ucp/rma/amo_offload.c +++ b/src/ucp/rma/amo_offload.c @@ -171,6 +171,7 @@ static void ucp_proto_amo_probe(const ucp_proto_init_params_t *init_params, UCP_PROTO_COMMON_INIT_FLAG_RECV_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_AMO, .tl_cap_flags = 0 }; diff --git a/src/ucp/rma/amo_sw.c b/src/ucp/rma/amo_sw.c index 21ca2da90a6..f8859502b2b 100644 --- a/src/ucp/rma/amo_sw.c +++ b/src/ucp/rma/amo_sw.c @@ -424,6 +424,7 @@ static void ucp_proto_amo_sw_probe(const ucp_proto_init_params_t *init_params, .super.flags = flags | UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG | UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_AM, .tl_cap_flags = 0 }; diff --git a/src/ucp/rma/get_am.c b/src/ucp/rma/get_am.c index 03585204b5f..4d6a6960b9f 100644 --- a/src/ucp/rma/get_am.c +++ b/src/ucp/rma/get_am.c @@ -95,6 +95,7 @@ ucp_proto_get_am_bcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | 
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_AM, .tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY }; diff --git a/src/ucp/rma/get_offload.c b/src/ucp/rma/get_offload.c index 5f01c879087..2b824bd40e0 100644 --- a/src/ucp/rma/get_offload.c +++ b/src/ucp/rma/get_offload.c @@ -96,6 +96,7 @@ ucp_proto_get_offload_bcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_REMOTE_ACCESS | UCP_PROTO_COMMON_INIT_FLAG_RESPONSE, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .max_lanes = UCP_PROTO_RMA_MAX_BCOPY_LANES, .initial_reg_md_map = 0, .first.tl_cap_flags = UCT_IFACE_FLAG_GET_BCOPY, @@ -202,6 +203,7 @@ ucp_proto_get_offload_zcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_RESPONSE | UCP_PROTO_COMMON_INIT_FLAG_MIN_FRAG, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .max_lanes = context->config.ext.max_rma_lanes, .initial_reg_md_map = 0, .first.tl_cap_flags = UCT_IFACE_FLAG_GET_ZCOPY, diff --git a/src/ucp/rma/put_am.c b/src/ucp/rma/put_am.c index 1150978951d..6e73626298e 100644 --- a/src/ucp/rma/put_am.c +++ b/src/ucp/rma/put_am.c @@ -97,6 +97,7 @@ ucp_proto_put_am_bcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING | UCP_PROTO_COMMON_INIT_FLAG_RESUME, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .max_lanes = 1, .initial_reg_md_map = 0, .first.tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY, diff --git a/src/ucp/rma/put_offload.c b/src/ucp/rma/put_offload.c index e108d8bbb27..63bf4a45c0e 100644 --- a/src/ucp/rma/put_offload.c +++ b/src/ucp/rma/put_offload.c @@ -70,6 +70,7 @@ ucp_proto_put_offload_short_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, 
.lane_type = UCP_LANE_TYPE_RMA, .tl_cap_flags = UCT_IFACE_FLAG_PUT_SHORT }; @@ -166,6 +167,7 @@ ucp_proto_put_offload_bcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_REMOTE_ACCESS | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .max_lanes = UCP_PROTO_RMA_MAX_BCOPY_LANES, .initial_reg_md_map = 0, .first.tl_cap_flags = UCT_IFACE_FLAG_PUT_BCOPY, @@ -254,6 +256,7 @@ ucp_proto_put_offload_zcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_REMOTE_ACCESS | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .max_lanes = context->config.ext.max_rma_lanes, .initial_reg_md_map = 0, .first.tl_cap_flags = UCT_IFACE_FLAG_PUT_ZCOPY, diff --git a/src/ucp/rndv/proto_rndv.c b/src/ucp/rndv/proto_rndv.c index 9ac51d37284..81e39416de6 100644 --- a/src/ucp/rndv/proto_rndv.c +++ b/src/ucp/rndv/proto_rndv.c @@ -209,6 +209,8 @@ ucp_proto_rndv_ctrl_perf(const ucp_proto_common_init_params_t *init_params, return status; } + ucs_assert(lane < UCP_MAX_LANES); + perf_attr.field_mask = UCT_PERF_ATTR_FIELD_OPERATION | UCT_PERF_ATTR_FIELD_SEND_PRE_OVERHEAD | UCT_PERF_ATTR_FIELD_SEND_POST_OVERHEAD | @@ -295,7 +297,8 @@ ucp_proto_rndv_ctrl_init_priv(const ucp_proto_rndv_ctrl_init_params_t *params, /* Use only memory domains for which the unpacking of the remote key was * successful */ - if (init_params->rkey_config_key != NULL) { + if ((init_params->rkey_config_key != NULL) && + !(params->super.flags & UCP_PROTO_COMMON_KEEP_MD_MAP)) { rpriv->md_map &= ~init_params->rkey_config_key->unreachable_md_map; } @@ -504,6 +507,7 @@ ucp_proto_rndv_find_ctrl_lane(const ucp_proto_init_params_t *params) UCP_PROTO_COMMON_INIT_FLAG_HDR_ONLY, UCP_PROTO_COMMON_OFFSET_INVALID, 1, UCP_LANE_TYPE_AM, + UCS_MEMORY_TYPE_UNKNOWN, UCT_IFACE_FLAG_AM_BCOPY, 1, 0, &lane); if (num_lanes == 0) { @@ -538,6 +542,7 @@ void 
ucp_proto_rndv_rts_probe(const ucp_proto_init_params_t *init_params) .super.flags = UCP_PROTO_COMMON_INIT_FLAG_RESPONSE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .remote_op_id = UCP_OP_ID_RNDV_RECV, .lane = ucp_proto_rndv_find_ctrl_lane(init_params), .perf_bias = context->config.ext.rndv_perf_diff / 100.0, @@ -642,6 +647,8 @@ ucp_proto_rndv_bulk_init(const ucp_proto_multi_init_params_t *init_params, goto out_destroy_bulk_perf; } + rpriv->frag_mem_type = init_params->super.reg_mem_type; + if (rpriv->super.lane == UCP_NULL_LANE) { /* Add perf without ACK in case of pipeline */ *perf_p = bulk_perf; diff --git a/src/ucp/rndv/proto_rndv.h b/src/ucp/rndv/proto_rndv.h index 54b4391fb11..14e0897c9cb 100644 --- a/src/ucp/rndv/proto_rndv.h +++ b/src/ucp/rndv/proto_rndv.h @@ -69,6 +69,11 @@ typedef struct { typedef struct { ucp_proto_rndv_ack_priv_t super; + /* Memory type of fragment buffers which are used by get/mtype and put/mtype + * protocols. + * TODO: Create a separate struct for mtype protocols and move it there. */ + ucs_memory_type_t frag_mem_type; + /* Multi-lane common part. Must be the last field, see @ref ucp_proto_multi_priv_t */ ucp_proto_multi_priv_t mpriv; diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index 903bac92bfc..c90298a1203 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -21,6 +21,27 @@ #include +static UCS_F_ALWAYS_INLINE int +ucp_rndv_frag_mem_type(ucp_context_t *context) +{ + ucs_memory_type_t frag_mem_type; + + if (context->config.ext.rndv_frag_mem_types == 0) { + return UCS_MEMORY_TYPE_HOST; + } + + /* Just one fragment memory type can be specified for proto v1, so take the + * first one from the map. Anyway for proto v1 UCX_RNDV_FRAG_MEM_TYPE is + * supposed to be used. 
+ */ + frag_mem_type = ucs_ffs64(context->config.ext.rndv_frag_mem_types); + + ucs_assertv(frag_mem_type < UCS_MEMORY_TYPE_UNKNOWN, "frag_mem_type = %u", + frag_mem_type); + + return frag_mem_type; +} + static UCS_F_ALWAYS_INLINE int ucp_rndv_memtype_direct_support(ucp_context_h context, size_t reg_length, uint64_t reg_mem_types) @@ -70,7 +91,7 @@ static int ucp_rndv_is_recv_pipeline_needed(ucp_request_t *rndv_req, const ucp_ep_config_t *ep_config = ucp_ep_config(rndv_req->send.ep); ucp_context_h context = rndv_req->send.ep->worker->context; int found = 0; - ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type; + ucs_memory_type_t frag_mem_type = ucp_rndv_frag_mem_type(context); ucp_md_index_t md_index; uct_md_attr_v2_t *md_attr; uint64_t mem_types; @@ -1170,7 +1191,7 @@ static void ucp_rndv_send_frag_get_mem_type(ucp_request_t *sreq, size_t length, uct_completion_callback_t comp_cb) { ucp_worker_h worker = sreq->send.ep->worker; - ucs_memory_type_t frag_mem_type = worker->context->config.ext.rndv_frag_mem_type; + ucs_memory_type_t frag_mem_type = ucp_rndv_frag_mem_type(worker->context); ucp_request_t *freq; ucp_mem_desc_t *mdesc; @@ -1254,7 +1275,7 @@ ucp_rndv_recv_start_get_pipeline(ucp_worker_h worker, ucp_request_t *rndv_req, size_t frag_size; /* use ucp_rkey_packed_mem_type(rkey_buffer) with non-host fragments */ - frag_mem_type = context->config.ext.rndv_frag_mem_type; + frag_mem_type = ucp_rndv_frag_mem_type(context); frag_size = context->config.ext.rndv_frag_size[frag_mem_type]; min_zcopy = config->rndv.get_zcopy.min; @@ -1337,7 +1358,7 @@ static void ucp_rndv_send_frag_rtr(ucp_worker_h worker, ucp_request_t *rndv_req, ucp_trace_req(rreq, "using rndv pipeline protocol rndv_req %p", rndv_req); offset = 0; - frag_mem_type = worker->context->config.ext.rndv_frag_mem_type; + frag_mem_type = ucp_rndv_frag_mem_type(worker->context); max_frag_size = worker->context->config.ext.rndv_frag_size[frag_mem_type]; num_frags = 
ucs_div_round_up(rndv_rts_hdr->size, max_frag_size); @@ -2069,7 +2090,7 @@ static ucs_status_t ucp_rndv_send_start_put_pipeline(ucp_request_t *sreq, ucp_worker_h worker = sreq->send.ep->worker; ucp_context_h context = worker->context; size_t rndv_base_offset = rndv_rtr_hdr->offset; - ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type; + ucs_memory_type_t frag_mem_type = ucp_rndv_frag_mem_type(context); size_t rndv_size = ucs_min(rndv_rtr_hdr->size, sreq->send.length); const uct_md_attr_v2_t *md_attr; diff --git a/src/ucp/rndv/rndv_am.c b/src/ucp/rndv/rndv_am.c index 09b7f2c0acf..40b5ca3894c 100644 --- a/src/ucp/rndv/rndv_am.c +++ b/src/ucp/rndv/rndv_am.c @@ -114,6 +114,8 @@ static void ucp_rndv_am_bcopy_probe(const ucp_proto_init_params_t *init_params) .super.flags = UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING | UCP_PROTO_COMMON_INIT_FLAG_RESUME, + .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .first.tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY, .middle.tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY }; @@ -186,6 +188,8 @@ static void ucp_rndv_am_zcopy_probe(const ucp_proto_init_params_t *init_params) .super.flags = UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, + .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .first.tl_cap_flags = UCT_IFACE_FLAG_AM_ZCOPY, .middle.tl_cap_flags = UCT_IFACE_FLAG_AM_ZCOPY }; diff --git a/src/ucp/rndv/rndv_get.c b/src/ucp/rndv/rndv_get.c index 9a6cc5d40b4..b01a4bac16b 100644 --- a/src/ucp/rndv/rndv_get.c +++ b/src/ucp/rndv/rndv_get.c @@ -20,7 +20,8 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params, uint64_t rndv_modes, size_t max_length, uct_ep_operation_t memtype_op, unsigned flags, ucp_md_map_t initial_reg_md_map, - int support_ppln) + int support_ppln, + ucs_memory_type_t reg_mem_type) { ucp_context_t *context = 
init_params->worker->context; ucp_proto_multi_init_params_t params = { @@ -45,6 +46,7 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params, UCP_PROTO_COMMON_INIT_FLAG_RESPONSE | UCP_PROTO_COMMON_INIT_FLAG_MIN_FRAG, .super.exclude_map = 0, + .super.reg_mem_type = reg_mem_type, .max_lanes = context->config.ext.max_rndv_lanes, .initial_reg_md_map = initial_reg_md_map, .first.tl_cap_flags = UCT_IFACE_FLAG_GET_ZCOPY, @@ -121,7 +123,7 @@ ucp_proto_rndv_get_zcopy_probe(const ucp_proto_init_params_t *init_params) UCT_EP_OP_LAST, UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, - 0, 0); + 0, 0, init_params->select_param->mem_type); } static void @@ -269,11 +271,9 @@ ucp_proto_rndv_get_mtype_fetch_completion(uct_completion_t *uct_comp) ucp_request_t *req = ucs_container_of(uct_comp, ucp_request_t, send.state.uct_comp); - ucp_proto_rndv_mtype_copy(req, req->send.rndv.mdesc->ptr, - ucp_proto_rndv_mtype_get_req_memh(req), - uct_ep_put_zcopy, - ucp_proto_rndv_get_mtype_unpack_completion, - "out to"); + ucp_proto_rndv_mdesc_mtype_copy(req, uct_ep_put_zcopy, + ucp_proto_rndv_get_mtype_unpack_completion, + "out to"); } static ucs_status_t @@ -283,8 +283,11 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req) const ucp_proto_rndv_bulk_priv_t *rpriv; ucs_status_t status; + /* coverity[tainted_data_downcast] */ + rpriv = req->send.proto_config->priv; + if (!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)) { - status = ucp_proto_rndv_mtype_request_init(req); + status = ucp_proto_rndv_mtype_request_init(req, rpriv->frag_mem_type); if (status != UCS_OK) { ucp_proto_request_abort(req, status); return UCS_OK; @@ -296,8 +299,6 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req) req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; } - /* coverity[tainted_data_downcast] */ - rpriv = req->send.proto_config->priv; return ucp_proto_multi_progress(req, &rpriv->mpriv, ucp_proto_rndv_get_mtype_send_func, 
ucp_request_invoke_uct_completion_success, @@ -307,27 +308,35 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req) static void ucp_proto_rndv_get_mtype_probe(const ucp_proto_init_params_t *init_params) { + ucp_context_t *context = init_params->worker->context; ucp_md_map_t mdesc_md_map; ucs_status_t status; size_t frag_size; + ucs_memory_type_t frag_mem_type; - status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, &frag_size); - if (status != UCS_OK) { - return; - } + ucs_for_each_bit(frag_mem_type, context->config.ext.rndv_frag_mem_types) { + status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type, + &mdesc_md_map, &frag_size); + if (status != UCS_OK) { + continue; + } - ucp_proto_rndv_get_common_probe(init_params, - UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE), - frag_size, UCT_EP_OP_PUT_ZCOPY, 0, - mdesc_md_map, 1); + ucp_proto_rndv_get_common_probe(init_params, + UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE), + frag_size, UCT_EP_OP_PUT_ZCOPY, 0, + mdesc_md_map, 1, frag_mem_type); + } } static void ucp_proto_rndv_get_mtype_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { + const ucp_proto_rndv_bulk_priv_t *rpriv = params->priv; + ucp_proto_rndv_bulk_query(params, attr); - ucp_proto_rndv_mtype_query_desc(params, attr, UCP_PROTO_RNDV_GET_DESC); + ucp_proto_rndv_mtype_query_desc(params, rpriv->frag_mem_type, attr, + UCP_PROTO_RNDV_GET_DESC); } static ucs_status_t ucp_proto_rndv_get_mtype_reset(ucp_request_t *req) diff --git a/src/ucp/rndv/rndv_mtype.inl b/src/ucp/rndv/rndv_mtype.inl index 55e021a83e9..1a582eb0600 100644 --- a/src/ucp/rndv/rndv_mtype.inl +++ b/src/ucp/rndv/rndv_mtype.inl @@ -13,11 +13,9 @@ static ucp_ep_h ucp_proto_rndv_mtype_ep(ucp_worker_t *worker, + ucs_memory_type_t frag_mem_type, ucs_memory_type_t buf_mem_type) { - ucs_memory_type_t frag_mem_type = - worker->context->config.ext.rndv_frag_mem_type; - if (worker->mem_type_ep[buf_mem_type] != NULL) { return worker->mem_type_ep[buf_mem_type]; } @@ -27,15 
+25,15 @@ static ucp_ep_h ucp_proto_rndv_mtype_ep(ucp_worker_t *worker, static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_init(const ucp_proto_init_params_t *init_params, + ucs_memory_type_t frag_mem_type, ucp_md_map_t *mdesc_md_map_p, size_t *frag_size_p) { - ucp_worker_h worker = init_params->worker; - ucp_context_h context = worker->context; - ucs_memory_type_t mem_type = init_params->select_param->mem_type; - ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type; + ucp_worker_h worker = init_params->worker; + ucp_context_h context = worker->context; + ucs_memory_type_t mem_type = init_params->select_param->mem_type; if ((init_params->select_param->dt_class != UCP_DATATYPE_CONTIG) || - (ucp_proto_rndv_mtype_ep(worker, mem_type) == NULL) || + (ucp_proto_rndv_mtype_ep(worker, frag_mem_type, mem_type) == NULL) || !ucp_proto_init_check_op(init_params, UCP_PROTO_RNDV_OP_ID_MASK)) { return UCS_ERR_UNSUPPORTED; } @@ -47,11 +45,10 @@ ucp_proto_rndv_mtype_init(const ucp_proto_init_params_t *init_params, } static UCS_F_ALWAYS_INLINE ucs_status_t -ucp_proto_rndv_mtype_request_init(ucp_request_t *req) +ucp_proto_rndv_mtype_request_init(ucp_request_t *req, + ucs_memory_type_t frag_mem_type) { - ucp_worker_h worker = req->send.ep->worker; - ucs_memory_type_t frag_mem_type = - worker->context->config.ext.rndv_frag_mem_type; + ucp_worker_h worker = req->send.ep->worker; req->send.rndv.mdesc = ucp_rndv_mpool_get(worker, frag_mem_type, UCS_SYS_DEVICE_ID_UNKNOWN); @@ -82,6 +79,7 @@ static ucp_ep_h ucp_proto_rndv_req_mtype_ep(ucp_request_t *req) ucp_ep_h mem_type_ep; mem_type_ep = ucp_proto_rndv_mtype_ep(req->send.ep->worker, + req->send.rndv.mdesc->memh->mem_type, req->send.state.dt_iter.mem_info.type); ucs_assert(mem_type_ep != NULL); @@ -130,10 +128,12 @@ ucp_proto_rndv_mtype_next_iov(ucp_request_t *req, static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy( ucp_request_t *req, void *buffer, uct_mem_h memh, - uct_ep_put_zcopy_func_t 
copy_func, uct_completion_callback_t comp_func, - const char *mode) + ucs_memory_type_t frag_mem_type, uct_ep_put_zcopy_func_t copy_func, + uct_completion_callback_t comp_func, const char *mode) { - ucp_ep_h mtype_ep = ucp_proto_rndv_req_mtype_ep(req); + ucp_ep_h mtype_ep = ucp_proto_rndv_mtype_ep( + req->send.ep->worker, frag_mem_type, + req->send.state.dt_iter.mem_info.type); ucp_lane_index_t lane = ucp_ep_config(mtype_ep)->key.rma_bw_lanes[0]; ucp_context_t UCS_V_UNUSED *context = req->send.ep->worker->context; ucs_status_t status; @@ -142,7 +142,7 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy( ucp_trace_req(req, "buffer %p copy-%s %p %s-%s using memtype-ep %p lane[%d]", buffer, mode, req->send.state.dt_iter.type.contig.buffer, ucs_memory_type_names[req->send.state.dt_iter.mem_info.type], - ucs_memory_type_names[context->config.ext.rndv_frag_mem_type], + ucs_memory_type_names[frag_mem_type], mtype_ep, lane); ucp_proto_completion_init(&req->send.state.uct_comp, comp_func); @@ -167,20 +167,39 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy( return status; } +static UCS_F_ALWAYS_INLINE ucs_status_t +ucp_proto_rndv_mdesc_mtype_copy(ucp_request_t *req, + uct_ep_put_zcopy_func_t copy_func, + uct_completion_callback_t comp_func, + const char *mode) +{ + ucp_mem_desc_t *mdesc = req->send.rndv.mdesc; + + return ucp_proto_rndv_mtype_copy( + req, mdesc->ptr, ucp_proto_rndv_mtype_get_req_memh(req), + mdesc->memh->mem_type, copy_func, comp_func, mode); +} + static UCS_F_ALWAYS_INLINE void ucp_proto_rndv_mtype_query_desc(const ucp_proto_query_params_t *params, + ucs_memory_type_t frag_mem_type, ucp_proto_query_attr_t *attr, const char *xfer_desc) { UCS_STRING_BUFFER_FIXED(strb, attr->desc, sizeof(attr->desc)); ucp_context_h context = params->worker->context; ucs_memory_type_t mem_type = params->select_param->mem_type; - ucp_ep_h mtype_ep = ucp_proto_rndv_mtype_ep(params->worker, - mem_type); + ucp_ep_h mtype_ep; ucp_lane_index_t 
lane; ucp_rsc_index_t rsc_index; const char *tl_name; + /* Make coverity happy */ + ucs_assertv(frag_mem_type < UCS_MEMORY_TYPE_UNKNOWN, "frag_mem_type = %u", + frag_mem_type); + + mtype_ep = ucp_proto_rndv_mtype_ep(params->worker, frag_mem_type, + mem_type); ucs_assert(mtype_ep != NULL); lane = ucp_ep_config(mtype_ep)->key.rma_bw_lanes[0]; @@ -196,6 +215,9 @@ ucp_proto_rndv_mtype_query_desc(const ucp_proto_query_params_t *params, if (ucp_proto_select_op_id(params->select_param) == UCP_OP_ID_RNDV_RECV) { ucs_string_buffer_appendf(&strb, ", %s", tl_name); } + + ucs_string_buffer_appendf(&strb, ", frag %s", + ucs_memory_type_names[frag_mem_type]); } #endif diff --git a/src/ucp/rndv/rndv_put.c b/src/ucp/rndv/rndv_put.c index d81f3522a0e..66c2c566376 100644 --- a/src/ucp/rndv/rndv_put.c +++ b/src/ucp/rndv/rndv_put.c @@ -227,7 +227,8 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params, uct_ep_operation_t memtype_op, unsigned flags, ucp_md_map_t initial_reg_md_map, uct_completion_callback_t comp_cb, - int support_ppln, uint8_t stat_counter) + int support_ppln, uint8_t stat_counter, + ucs_memory_type_t reg_mem_type) { const size_t atp_size = sizeof(ucp_rndv_ack_hdr_t); ucp_context_t *context = init_params->worker->context; @@ -251,6 +252,7 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params, UCP_PROTO_COMMON_INIT_FLAG_REMOTE_ACCESS | UCP_PROTO_COMMON_INIT_FLAG_MIN_FRAG, .super.exclude_map = 0, + .super.reg_mem_type = reg_mem_type, .max_lanes = context->config.ext.max_rndv_lanes, .initial_reg_md_map = initial_reg_md_map, .first.tl_cap_flags = UCT_IFACE_FLAG_PUT_ZCOPY, @@ -418,7 +420,8 @@ ucp_proto_rndv_put_zcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, 0, ucp_proto_rndv_put_zcopy_completion, 0, - UCP_WORKER_STAT_RNDV_PUT_ZCOPY); + UCP_WORKER_STAT_RNDV_PUT_ZCOPY, + init_params->select_param->mem_type); } static void @@ -500,12 +503,15 
@@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_put_mtype_send_func( static ucs_status_t ucp_proto_rndv_put_mtype_copy_progress(uct_pending_req_t *uct_req) { - ucp_request_t *req = ucs_container_of(uct_req, ucp_request_t, send.uct); + ucp_request_t *req = ucs_container_of(uct_req, + ucp_request_t, + send.uct); + const ucp_proto_rndv_put_priv_t *rpriv = req->send.proto_config->priv; ucs_status_t status; ucs_assert(!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)); - status = ucp_proto_rndv_mtype_request_init(req); + status = ucp_proto_rndv_mtype_request_init(req, rpriv->bulk.frag_mem_type); if (status != UCS_OK) { ucp_proto_request_abort(req, status); return UCS_OK; @@ -513,11 +519,9 @@ ucp_proto_rndv_put_mtype_copy_progress(uct_pending_req_t *uct_req) ucp_proto_rndv_put_common_request_init(req); req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; - ucp_proto_rndv_mtype_copy(req, req->send.rndv.mdesc->ptr, - ucp_proto_rndv_mtype_get_req_memh(req), - uct_ep_get_zcopy, - ucp_proto_rndv_put_mtype_pack_completion, - "in from"); + ucp_proto_rndv_mdesc_mtype_copy(req, uct_ep_get_zcopy, + ucp_proto_rndv_put_mtype_pack_completion, + "in from"); return UCS_OK; } @@ -565,8 +569,20 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params) ucs_status_t status; size_t frag_size; unsigned flags; + ucs_memory_type_t frag_mem_type; - status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, &frag_size); + if (init_params->rkey_config_key == NULL) { + return; + } + + /* Can initialize only the same fragment type as received in RTR + * because pipeline protocols assume that both peers use the same + * fragment sizes (and they are different for different memory types by + * default). 
*/ + frag_mem_type = init_params->rkey_config_key->mem_type; + + status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type, + &mdesc_md_map, &frag_size); if (status != UCS_OK) { return; } @@ -580,21 +596,22 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params) comp_cb = ucp_proto_rndv_put_mtype_completion; } - ucp_proto_rndv_put_common_probe(init_params, - UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE), - frag_size, UCT_EP_OP_GET_ZCOPY, flags, - mdesc_md_map, comp_cb, 1, - UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY); + ucp_proto_rndv_put_common_probe( + init_params, UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE), frag_size, + UCT_EP_OP_GET_ZCOPY, flags, mdesc_md_map, comp_cb, 1, + UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, frag_mem_type); } static void ucp_proto_rndv_put_mtype_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { + const ucp_proto_rndv_put_priv_t *rpriv = params->priv; const char *put_desc; put_desc = ucp_proto_rndv_put_common_query(params, attr); - ucp_proto_rndv_mtype_query_desc(params, attr, put_desc); + ucp_proto_rndv_mtype_query_desc(params, rpriv->bulk.frag_mem_type, attr, + put_desc); } ucp_proto_t ucp_rndv_put_mtype_proto = { diff --git a/src/ucp/rndv/rndv_rkey_ptr.c b/src/ucp/rndv/rndv_rkey_ptr.c index f88b526e288..0408db1cfc6 100644 --- a/src/ucp/rndv/rndv_rkey_ptr.c +++ b/src/ucp/rndv/rndv_rkey_ptr.c @@ -93,6 +93,7 @@ ucp_proto_rndv_rkey_ptr_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_REMOTE_ACCESS | UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_RKEY_PTR, .tl_cap_flags = 0, }; @@ -253,6 +254,7 @@ ucp_proto_rndv_rkey_ptr_mtype_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_REMOTE_ACCESS, .super.exclude_map = (rkey_ptr_lane == UCP_NULL_LANE) ? 
0 : UCS_BIT(rkey_ptr_lane), + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_LAST, .tl_cap_flags = 0 }; @@ -268,8 +270,9 @@ ucp_proto_rndv_rkey_ptr_mtype_probe(const ucp_proto_init_params_t *init_params) return; } - status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, - ¶ms.super.max_length); + /* 2-stage ppln protocols work with host staging buffers only */ + status = ucp_proto_rndv_mtype_init(init_params, UCS_MEMORY_TYPE_HOST, + &mdesc_md_map, ¶ms.super.max_length); if (status != UCS_OK) { return; } @@ -342,7 +345,7 @@ ucp_proto_rndv_rkey_ptr_mtype_copy_progress(uct_pending_req_t *uct_req) req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; ucp_proto_rndv_mtype_copy(req, ppln_data->local_ptr, ppln_data->uct_memh, - uct_ep_get_zcopy, + UCS_MEMORY_TYPE_HOST, uct_ep_get_zcopy, ucp_proto_rndv_rkey_ptr_mtype_copy_completion, "in from"); @@ -368,7 +371,7 @@ ucp_proto_rndv_rkey_ptr_mtype_query(const ucp_proto_query_params_t *params, const char *desc = UCP_PROTO_RNDV_RKEY_PTR_DESC; ucp_rndv_rkey_ptr_query_common(params, attr); - ucp_proto_rndv_mtype_query_desc(params, attr, desc); + ucp_proto_rndv_mtype_query_desc(params, UCS_MEMORY_TYPE_HOST, attr, desc); } ucp_proto_t ucp_rndv_rkey_ptr_mtype_proto = { diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index d1802e2c262..f92202fb5b1 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -32,6 +32,11 @@ typedef struct { ucp_proto_rndv_rtr_data_received_cb_t data_received; } ucp_proto_rndv_rtr_priv_t; +typedef struct { + ucp_proto_rndv_rtr_priv_t super; + ucs_memory_type_t frag_mem_type; +} ucp_proto_rndv_rtr_mtype_priv_t; + static UCS_F_ALWAYS_INLINE void ucp_proto_rtr_common_request_init(ucp_request_t *req) { @@ -172,6 +177,7 @@ static void ucp_proto_rndv_rtr_probe(const ucp_proto_init_params_t *init_params) .super.flags = UCP_PROTO_COMMON_INIT_FLAG_RESPONSE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = 
UCS_MEMORY_TYPE_UNKNOWN, .remote_op_id = UCP_OP_ID_RNDV_SEND, .lane = ucp_proto_rndv_find_ctrl_lane(init_params), .perf_bias = 0.0, @@ -334,21 +340,20 @@ ucp_proto_rndv_rtr_mtype_data_received(ucp_request_t *req, int in_buffer) } else { /* Data was not placed in user buffer, which means it was placed to the remote address we published - the rendezvous fragment */ - ucp_proto_rndv_mtype_copy(req, req->send.rndv.mdesc->ptr, - ucp_proto_rndv_mtype_get_req_memh(req), - uct_ep_put_zcopy, - ucp_proto_rndv_rtr_mtype_copy_completion, - "out to"); + ucp_proto_rndv_mdesc_mtype_copy( + req, uct_ep_put_zcopy, ucp_proto_rndv_rtr_mtype_copy_completion, + "out to"); } } static ucs_status_t ucp_proto_rndv_rtr_mtype_progress(uct_pending_req_t *self) { ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); + const ucp_proto_rndv_rtr_mtype_priv_t *rpriv = req->send.proto_config->priv; ucs_status_t status; if (!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)) { - status = ucp_proto_rndv_mtype_request_init(req); + status = ucp_proto_rndv_mtype_request_init(req, rpriv->frag_mem_type); if (status != UCS_OK) { ucp_proto_request_abort(req, status); return UCS_OK; @@ -385,16 +390,18 @@ ucp_proto_rndv_rtr_mtype_probe(const ucp_proto_init_params_t *init_params) .super.send_op = UCT_EP_OP_AM_BCOPY, .super.memtype_op = UCT_EP_OP_LAST, .super.flags = UCP_PROTO_COMMON_INIT_FLAG_RESPONSE | - UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, + UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING | + UCP_PROTO_COMMON_KEEP_MD_MAP, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .remote_op_id = UCP_OP_ID_RNDV_SEND, .lane = ucp_proto_rndv_find_ctrl_lane(init_params), .perf_bias = 0.0, - .mem_info.type = context->config.ext.rndv_frag_mem_type, .mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN, .ctrl_msg_name = UCP_PROTO_RNDV_RTR_NAME, }; - ucp_proto_rndv_rtr_priv_t rpriv; + ucs_memory_type_t frag_mem_type; + ucp_proto_rndv_rtr_mtype_priv_t rpriv; ucp_md_map_t dummy_md_map; ucp_md_index_t 
md_index; ucs_status_t status; @@ -404,56 +411,67 @@ ucp_proto_rndv_rtr_mtype_probe(const ucp_proto_init_params_t *init_params) return; } - status = ucp_proto_rndv_mtype_init(init_params, &dummy_md_map, - ¶ms.super.max_length); - if (status != UCS_OK) { - return; - } + ucs_for_each_bit(frag_mem_type, context->config.ext.rndv_frag_mem_types) { + status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type, + &dummy_md_map, + ¶ms.super.max_length); + if (status != UCS_OK) { + continue; + } - status = ucp_proto_perf_create("rtr/mtype unpack", ¶ms.unpack_perf); - if (status != UCS_OK) { - return; - } + params.mem_info.type = frag_mem_type; + + status = ucp_mm_get_alloc_md_index(context, &md_index, frag_mem_type); + if ((status == UCS_OK) && (md_index != UCP_NULL_RESOURCE)) { + params.md_map = UCS_BIT(md_index); + } else if (frag_mem_type == UCS_MEMORY_TYPE_HOST) { + params.md_map = 0; + } else { + /* To use non-host staging buffers it should be possible to + * allocate them with MD */ + continue; + } - status = ucp_proto_init_add_buffer_copy_time( - init_params->worker, "unpack copy", params.mem_info.type, - init_params->select_param->mem_type, UCT_EP_OP_PUT_ZCOPY, - params.super.min_length, params.super.max_length, 1, - params.unpack_perf); - if (status != UCS_OK) { - goto out_unpack_perf_destroy; - } + status = ucp_proto_perf_create("rtr/mtype unpack", ¶ms.unpack_perf); + if (status != UCS_OK) { + return; + } - status = ucp_mm_get_alloc_md_index(context, &md_index, - params.mem_info.type); - if ((status != UCS_OK) || (md_index == UCP_NULL_RESOURCE)) { - params.md_map = 0; - } else { - params.md_map = UCS_BIT(md_index); - } + status = ucp_proto_init_add_buffer_copy_time( + init_params->worker, "unpack copy", frag_mem_type, + init_params->select_param->mem_type, UCT_EP_OP_PUT_ZCOPY, + params.super.min_length, params.super.max_length, 1, + params.unpack_perf); + if (status != UCS_OK) { + goto out_unpack_perf_destroy; + } - rpriv.pack_cb = ucp_proto_rndv_rtr_mtype_pack; 
- rpriv.data_received = ucp_proto_rndv_rtr_mtype_data_received; + rpriv.super.pack_cb = ucp_proto_rndv_rtr_mtype_pack; + rpriv.super.data_received = ucp_proto_rndv_rtr_mtype_data_received; + rpriv.frag_mem_type = frag_mem_type; - ucp_proto_rndv_ctrl_probe(¶ms, &rpriv, sizeof(rpriv)); + ucp_proto_rndv_ctrl_probe(¶ms, &rpriv, sizeof(rpriv)); out_unpack_perf_destroy: - ucp_proto_perf_destroy(params.unpack_perf); + ucp_proto_perf_destroy(params.unpack_perf); + } } static void ucp_proto_rndv_rtr_mtype_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { - const ucp_proto_rndv_ctrl_priv_t *rpriv = params->priv; + const ucp_proto_rndv_rtr_mtype_priv_t *rpriv = params->priv; ucp_proto_query_attr_t remote_attr; - ucp_proto_config_query(params->worker, &rpriv->remote_proto_config, + ucp_proto_config_query(params->worker, + &rpriv->super.super.remote_proto_config, params->msg_length, &remote_attr); attr->is_estimation = 1; attr->max_msg_length = remote_attr.max_msg_length; - attr->lane_map = UCS_BIT(rpriv->lane); - ucp_proto_rndv_mtype_query_desc(params, attr, remote_attr.desc); + attr->lane_map = UCS_BIT(rpriv->super.super.lane); + ucp_proto_rndv_mtype_query_desc(params, rpriv->frag_mem_type, attr, + remote_attr.desc); ucs_strncpy_safe(attr->config, remote_attr.config, sizeof(attr->config)); } diff --git a/src/ucp/stream/stream_multi.c b/src/ucp/stream/stream_multi.c index 511ae9f5a9b..c61cb7a59b9 100644 --- a/src/ucp/stream/stream_multi.c +++ b/src/ucp/stream/stream_multi.c @@ -92,6 +92,7 @@ ucp_stream_multi_bcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING | UCP_PROTO_COMMON_INIT_FLAG_RESUME, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .max_lanes = 1, .initial_reg_md_map = 0, .opt_align_offs = UCP_PROTO_COMMON_OFFSET_INVALID, @@ -165,6 +166,7 @@ ucp_stream_multi_zcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | 
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .max_lanes = 1, .initial_reg_md_map = 0, .opt_align_offs = UCP_PROTO_COMMON_OFFSET_INVALID, diff --git a/src/ucp/tag/eager_multi.c b/src/ucp/tag/eager_multi.c index 184526fc572..9ed442a359a 100644 --- a/src/ucp/tag/eager_multi.c +++ b/src/ucp/tag/eager_multi.c @@ -70,6 +70,7 @@ static void ucp_proto_eager_bcopy_multi_common_probe( UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING | UCP_PROTO_COMMON_INIT_FLAG_RESUME, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .opt_align_offs = UCP_PROTO_COMMON_OFFSET_INVALID, .first.tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY, .middle.tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY @@ -241,6 +242,7 @@ ucp_proto_eager_zcopy_multi_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .opt_align_offs = UCP_PROTO_COMMON_OFFSET_INVALID, .first.tl_cap_flags = UCT_IFACE_FLAG_AM_ZCOPY, .middle.tl_cap_flags = UCT_IFACE_FLAG_AM_ZCOPY diff --git a/src/ucp/tag/eager_single.c b/src/ucp/tag/eager_single.c index 79d508a1c1f..2ff2d2c6480 100644 --- a/src/ucp/tag/eager_single.c +++ b/src/ucp/tag/eager_single.c @@ -67,6 +67,7 @@ ucp_proto_eager_short_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_AM, .tl_cap_flags = UCT_IFACE_FLAG_AM_SHORT }; @@ -139,6 +140,8 @@ ucp_proto_eager_bcopy_single_probe(const ucp_proto_init_params_t *init_params) .super.flags = UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG | UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, + .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = 
UCP_LANE_TYPE_AM, .tl_cap_flags = UCT_IFACE_FLAG_AM_BCOPY }; @@ -186,6 +189,7 @@ ucp_proto_eager_zcopy_single_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .lane_type = UCP_LANE_TYPE_AM, .tl_cap_flags = UCT_IFACE_FLAG_AM_ZCOPY }; diff --git a/src/ucp/tag/offload/eager.c b/src/ucp/tag/offload/eager.c index 893b01738a2..55430bb3e54 100644 --- a/src/ucp/tag/offload/eager.c +++ b/src/ucp/tag/offload/eager.c @@ -65,6 +65,7 @@ static void ucp_proto_eager_tag_offload_short_probe( UCP_PROTO_COMMON_INIT_FLAG_RECV_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_TAG, .tl_cap_flags = UCT_IFACE_FLAG_TAG_EAGER_SHORT }; @@ -139,6 +140,7 @@ static void ucp_proto_eager_tag_offload_bcopy_probe_common( UCP_PROTO_COMMON_INIT_FLAG_RECV_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .lane_type = UCP_LANE_TYPE_TAG, .tl_cap_flags = UCT_IFACE_FLAG_TAG_EAGER_BCOPY }; @@ -249,6 +251,7 @@ static void ucp_proto_eager_tag_offload_zcopy_probe_common( UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG | UCP_PROTO_COMMON_INIT_FLAG_CAP_SEG_SIZE, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .lane_type = UCP_LANE_TYPE_TAG, .tl_cap_flags = UCT_IFACE_FLAG_TAG_EAGER_ZCOPY }; diff --git a/src/ucp/tag/offload/rndv.c b/src/ucp/tag/offload/rndv.c index d0c214aa3c5..e752b0db61e 100644 --- a/src/ucp/tag/offload/rndv.c +++ b/src/ucp/tag/offload/rndv.c @@ -45,6 +45,7 @@ ucp_tag_rndv_offload_proto_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_RECV_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_SINGLE_FRAG, .super.exclude_map = 0, + .super.reg_mem_type = init_params->select_param->mem_type, .lane_type = UCP_LANE_TYPE_TAG, 
.tl_cap_flags = UCT_IFACE_FLAG_TAG_RNDV_ZCOPY }; @@ -179,6 +180,7 @@ ucp_tag_rndv_offload_sw_proto_probe(const ucp_proto_init_params_t *init_params) .super.memtype_op = UCT_EP_OP_LAST, .super.flags = UCP_PROTO_COMMON_INIT_FLAG_RESPONSE, .super.exclude_map = 0, + .super.reg_mem_type = UCS_MEMORY_TYPE_UNKNOWN, .remote_op_id = UCP_OP_ID_RNDV_RECV, .lane = init_params->ep_config_key->tag_lane, .perf_bias = context->config.ext.rndv_perf_diff / 100.0, diff --git a/src/uct/base/uct_mem.c b/src/uct/base/uct_mem.c index 0b9cc8ec119..13d7e276b8a 100644 --- a/src/uct/base/uct_mem.c +++ b/src/uct/base/uct_mem.c @@ -69,6 +69,7 @@ ucs_status_t uct_mem_alloc(size_t length, const uct_alloc_method_t *methods, { const char *alloc_name; const uct_alloc_method_t *method; + ucs_log_level_t log_level; ucs_memory_type_t mem_type; uct_md_attr_t md_attr; ucs_status_t status; @@ -101,6 +102,8 @@ ucs_status_t uct_mem_alloc(size_t length, const uct_alloc_method_t *methods, mem_type = (params->field_mask & UCT_MEM_ALLOC_PARAM_FIELD_MEM_TYPE) ? params->mem_type : UCS_MEMORY_TYPE_HOST; alloc_length = length; + log_level = (flags & UCT_MD_MEM_FLAG_HIDE_ERRORS) ? 
UCS_LOG_LEVEL_DEBUG : + UCS_LOG_LEVEL_ERROR; ucs_trace("allocating %s: %s memory length %zu flags 0x%x", alloc_name, ucs_memory_type_names[mem_type], alloc_length, flags); @@ -122,7 +125,7 @@ ucs_status_t uct_mem_alloc(size_t length, const uct_alloc_method_t *methods, md = params->mds.mds[md_index]; status = uct_md_query(md, &md_attr); if (status != UCS_OK) { - ucs_error("Failed to query MD"); + ucs_log(log_level, "Failed to query MD"); goto out; } @@ -152,9 +155,10 @@ ucs_status_t uct_mem_alloc(size_t length, const uct_alloc_method_t *methods, mem_type, flags, alloc_name, &memh); if (status != UCS_OK) { - ucs_error("failed to allocate %zu bytes using md %s for %s: %s", - alloc_length, md->component->name, - alloc_name, ucs_status_string(status)); + ucs_log(log_level, + "failed to allocate %zu bytes using md %s for %s: %s", + alloc_length, md->component->name, alloc_name, + ucs_status_string(status)); goto out; } diff --git a/src/uct/cuda/cuda_copy/cuda_copy_md.c b/src/uct/cuda/cuda_copy/cuda_copy_md.c index dddbd00478d..a185dde3779 100644 --- a/src/uct/cuda/cuda_copy/cuda_copy_md.c +++ b/src/uct/cuda/cuda_copy/cuda_copy_md.c @@ -214,7 +214,8 @@ UCS_PROFILE_FUNC(ucs_status_t, uct_cuda_copy_mem_dereg, static ucs_status_t uct_cuda_copy_mem_alloc_fabric(uct_cuda_copy_md_t *md, - uct_cuda_copy_alloc_handle_t *alloc_handle) + uct_cuda_copy_alloc_handle_t *alloc_handle, + unsigned flags) { #if HAVE_CUDA_FABRIC CUmemAllocationProp prop = {}; @@ -224,6 +225,13 @@ uct_cuda_copy_mem_alloc_fabric(uct_cuda_copy_md_t *md, ucs_status_t status; CUdevice cu_device; + if (!(flags & UCT_MD_MEM_FLAG_HIDE_ERRORS) && + (md->config.enable_fabric == UCS_YES)) { + log_level = UCS_LOG_LEVEL_ERROR; + } else { + log_level = UCS_LOG_LEVEL_DEBUG; + } + status = UCT_CUDADRV_FUNC(cuCtxGetDevice(&cu_device), log_level); if (status != UCS_OK) { return status; @@ -304,21 +312,27 @@ uct_cuda_copy_mem_alloc(uct_md_h uct_md, size_t *length_p, void **address_p, uct_cuda_copy_md_t *md = 
ucs_derived_of(uct_md, uct_cuda_copy_md_t); ucs_status_t status; uct_cuda_copy_alloc_handle_t *alloc_handle; + ucs_log_level_t log_level; if ((mem_type != UCS_MEMORY_TYPE_CUDA_MANAGED) && (mem_type != UCS_MEMORY_TYPE_CUDA)) { return UCS_ERR_UNSUPPORTED; } + log_level = (flags & UCT_MD_MEM_FLAG_HIDE_ERRORS) ? UCS_LOG_LEVEL_DEBUG : + UCS_LOG_LEVEL_ERROR; + if (!uct_cuda_base_is_context_active()) { - ucs_error("attempt to allocate cuda memory without active context"); + ucs_log(log_level, + "attempt to allocate cuda memory without active context"); return UCS_ERR_NO_DEVICE; } alloc_handle = ucs_malloc(sizeof(*alloc_handle), "uct_cuda_copy_alloc_handle_t"); if (NULL == alloc_handle) { - ucs_error("failed to allocate memory for uct_cuda_copy_alloc_handle_t"); + ucs_log(log_level, + "failed to allocate memory for uct_cuda_copy_alloc_handle_t"); return UCS_ERR_NO_MEMORY; } @@ -327,7 +341,7 @@ uct_cuda_copy_mem_alloc(uct_md_h uct_md, size_t *length_p, void **address_p, if (mem_type == UCS_MEMORY_TYPE_CUDA) { if (md->config.enable_fabric != UCS_NO) { - status = uct_cuda_copy_mem_alloc_fabric(md, alloc_handle); + status = uct_cuda_copy_mem_alloc_fabric(md, alloc_handle, flags); if (status == UCS_OK) { goto allocated; } else { @@ -338,22 +352,24 @@ uct_cuda_copy_mem_alloc(uct_md_h uct_md, size_t *length_p, void **address_p, } if (md->config.enable_fabric != UCS_YES) { - status = UCT_CUDADRV_FUNC_LOG_ERR( - cuMemAlloc(&alloc_handle->ptr, alloc_handle->length)); + status = UCT_CUDADRV_FUNC(cuMemAlloc(&alloc_handle->ptr, + alloc_handle->length), + log_level); if (status == UCS_OK) { goto allocated; } } - ucs_error("unable to allocate cuda memory of length %ld bytes", - alloc_handle->length); + ucs_log(log_level, "unable to allocate cuda memory of length %ld bytes", + alloc_handle->length); status = UCS_ERR_NO_MEMORY; } else if (mem_type == UCS_MEMORY_TYPE_CUDA_MANAGED) { - status = UCT_CUDADRV_FUNC_LOG_ERR( + status = UCT_CUDADRV_FUNC( cuMemAllocManaged(&alloc_handle->ptr, 
alloc_handle->length, - CU_MEM_ATTACH_GLOBAL)); + CU_MEM_ATTACH_GLOBAL), log_level); } else { - ucs_error("allocation mem_types supported: cuda, cuda-managed"); + ucs_log(log_level, + "allocation mem_types supported: cuda, cuda-managed"); status = UCS_ERR_INVALID_PARAM; } diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c index 17c6f14b2fa..5922f6e9f57 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_md.c +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_md.c @@ -61,14 +61,22 @@ static uct_cuda_ipc_dev_cache_t *uct_cuda_ipc_create_dev_cache(int dev_num) static uct_cuda_ipc_dev_cache_t * uct_cuda_ipc_get_dev_cache(uct_cuda_ipc_component_t *component, - const CUuuid *uuid) + uct_cuda_ipc_rkey_t *rkey) { khash_t(cuda_ipc_uuid_hash) *hash = &component->uuid_hash; + uct_cuda_ipc_uuid_hash_key_t key; uct_cuda_ipc_dev_cache_t *cache; khiter_t iter; int ret; - iter = kh_put(cuda_ipc_uuid_hash, hash, *uuid, &ret); + key.uuid = rkey->uuid; +#if HAVE_CUDA_FABRIC + key.type = rkey->ph.handle_type; +#else + key.type = 0; +#endif + + iter = kh_put(cuda_ipc_uuid_hash, hash, key, &ret); if (ret == UCS_KH_PUT_KEY_PRESENT) { return kh_val(hash, iter); } else if ((ret == UCS_KH_PUT_BUCKET_EMPTY) || @@ -258,7 +266,7 @@ uct_cuda_ipc_is_peer_accessible(uct_cuda_ipc_component_t *component, pthread_mutex_lock(&component->lock); - cache = uct_cuda_ipc_get_dev_cache(component, &rkey->uuid); + cache = uct_cuda_ipc_get_dev_cache(component, rkey); if (ucs_unlikely(NULL == cache)) { status = UCS_ERR_NO_RESOURCE; goto err; diff --git a/src/uct/cuda/cuda_ipc/cuda_ipc_md.h b/src/uct/cuda/cuda_ipc/cuda_ipc_md.h index 17f7c02cb6e..da2a83d6d80 100644 --- a/src/uct/cuda/cuda_ipc/cuda_ipc_md.h +++ b/src/uct/cuda/cuda_ipc/cuda_ipc_md.h @@ -13,6 +13,30 @@ #include #include + +#if HAVE_CUDA_FABRIC +typedef enum uct_cuda_ipc_key_handle { + UCT_CUDA_IPC_KEY_HANDLE_TYPE_ERROR = 0, + UCT_CUDA_IPC_KEY_HANDLE_TYPE_LEGACY, /* cudaMalloc memory */ + UCT_CUDA_IPC_KEY_HANDLE_TYPE_VMM, /* 
cuMemCreate memory */ + UCT_CUDA_IPC_KEY_HANDLE_TYPE_MEMPOOL /* cudaMallocAsync memory */ +} uct_cuda_ipc_key_handle_t; + + +typedef struct uct_cuda_ipc_md_handle { + uct_cuda_ipc_key_handle_t handle_type; + union { + CUipcMemHandle legacy; /* Legacy IPC handle */ + CUmemFabricHandle fabric_handle; /* VMM/Mallocasync export handle */ + } handle; + CUmemPoolPtrExportData ptr; + CUmemoryPool pool; +} uct_cuda_ipc_md_handle_t; +#else +typedef CUipcMemHandle uct_cuda_ipc_md_handle_t; +#endif + + /** * @brief cuda ipc MD descriptor */ @@ -21,6 +45,13 @@ typedef struct uct_cuda_ipc_md { ucs_ternary_auto_value_t enable_mnnvl; } uct_cuda_ipc_md_t; + +typedef struct uct_cuda_ipc_uuid_hash_key { + int type; + CUuuid uuid; +} uct_cuda_ipc_uuid_hash_key_t; + + typedef struct { /* GPU Device number */ int dev_num; @@ -28,21 +59,30 @@ typedef struct { uint8_t accessible[0]; } uct_cuda_ipc_dev_cache_t; -static UCS_F_ALWAYS_INLINE int uct_cuda_ipc_uuid_equals(CUuuid a, CUuuid b) + +static UCS_F_ALWAYS_INLINE int +uct_cuda_ipc_uuid_equals(uct_cuda_ipc_uuid_hash_key_t key1, + uct_cuda_ipc_uuid_hash_key_t key2) { - int64_t *a64 = (int64_t *)a.bytes; - int64_t *b64 = (int64_t *)b.bytes; - return (a64[0] == b64[0]) && (a64[1] == b64[1]); + int64_t *a64 = (int64_t *)key1.uuid.bytes; + int64_t *b64 = (int64_t *)key2.uuid.bytes; + + return (key1.type == key2.type) && (a64[0] == b64[0]) && (a64[1] == b64[1]); } -static UCS_F_ALWAYS_INLINE khint32_t uct_cuda_ipc_uuid_hash_func(CUuuid key) + +static UCS_F_ALWAYS_INLINE khint32_t +uct_cuda_ipc_uuid_hash_func(uct_cuda_ipc_uuid_hash_key_t key) { - int64_t *i64 = (int64_t *)key.bytes; - return kh_int64_hash_func(i64[0] ^ i64[1]); + int64_t *i64 = (int64_t *)key.uuid.bytes; + return kh_int64_hash_func(i64[0] ^ i64[1] ^ key.type); } -KHASH_INIT(cuda_ipc_uuid_hash, CUuuid, uct_cuda_ipc_dev_cache_t*, 1, - uct_cuda_ipc_uuid_hash_func, uct_cuda_ipc_uuid_equals); + +KHASH_INIT(cuda_ipc_uuid_hash, uct_cuda_ipc_uuid_hash_key_t, + 
uct_cuda_ipc_dev_cache_t*, 1, uct_cuda_ipc_uuid_hash_func, + uct_cuda_ipc_uuid_equals); + /** * @brief cuda ipc component extension @@ -74,29 +114,6 @@ typedef struct { } uct_cuda_ipc_memh_t; -#if HAVE_CUDA_FABRIC -typedef enum uct_cuda_ipc_key_handle { - UCT_CUDA_IPC_KEY_HANDLE_TYPE_ERROR = 0, - UCT_CUDA_IPC_KEY_HANDLE_TYPE_LEGACY, /* cudaMalloc memory */ - UCT_CUDA_IPC_KEY_HANDLE_TYPE_VMM, /* cuMemCreate memory */ - UCT_CUDA_IPC_KEY_HANDLE_TYPE_MEMPOOL /* cudaMallocAsync memory */ -} uct_cuda_ipc_key_handle_t; - - -typedef struct uct_cuda_ipc_md_handle { - uct_cuda_ipc_key_handle_t handle_type; - union { - CUipcMemHandle legacy; /* Legacy IPC handle */ - CUmemFabricHandle fabric_handle; /* VMM/Mallocasync export handle */ - } handle; - CUmemPoolPtrExportData ptr; - CUmemoryPool pool; -} uct_cuda_ipc_md_handle_t; -#else -typedef CUipcMemHandle uct_cuda_ipc_md_handle_t; -#endif - - /** * @brief cudar ipc region registered for exposure */ diff --git a/test/gtest/ucp/ucp_test.cc b/test/gtest/ucp/ucp_test.cc index 8fd0c7e1316..f2e746df777 100644 --- a/test/gtest/ucp/ucp_test.cc +++ b/test/gtest/ucp/ucp_test.cc @@ -1184,8 +1184,11 @@ bool ucp_test_base::entity::has_lane_with_caps(uint64_t caps) const bool ucp_test_base::entity::is_rndv_put_ppln_supported() const { - const auto config = ucp_ep_config(ep()); - ucs_memory_type_t mem_type = ucph()->config.ext.rndv_frag_mem_type; + const auto config = ucp_ep_config(ep()); + ucs_memory_type_t mem_type; + + mem_type = (ucs_memory_type_t)ucs_ffs64( + ucph()->config.ext.rndv_frag_mem_types); for (auto i = 0; i < config->key.num_lanes; ++i) { const auto lane = config->key.rma_bw_lanes[i];