diff --git a/src/ucp/core/ucp_context.c b/src/ucp/core/ucp_context.c index d9a1fff5030..838b3bf0a74 100644 --- a/src/ucp/core/ucp_context.c +++ b/src/ucp/core/ucp_context.c @@ -340,11 +340,15 @@ static ucs_config_field_t ucp_context_config_table[] = { "and the resulting performance.", ucs_offsetof(ucp_context_config_t, estimated_num_ppn), UCS_CONFIG_TYPE_ULUNITS}, - {"RNDV_FRAG_MEM_TYPE", "host", - "Memory type of fragments used for RNDV pipeline protocol.\n" - "Allowed memory types is one of: host, cuda, rocm, ze-host, ze-device", - ucs_offsetof(ucp_context_config_t, rndv_frag_mem_type), - UCS_CONFIG_TYPE_ENUM(ucs_memory_type_names)}, + {"RNDV_FRAG_MEM_TYPE", NULL, "", + ucs_offsetof(ucp_context_config_t, rndv_frag_mem_types), + UCS_CONFIG_TYPE_BITMAP(ucs_memory_type_names)}, + + {"RNDV_FRAG_MEM_TYPES", "host,cuda", + "Memory types of fragments used for RNDV pipeline protocol.\n" + "Allowed memory types are: host, cuda, rocm, ze-host, ze-device", + ucs_offsetof(ucp_context_config_t, rndv_frag_mem_types), + UCS_CONFIG_TYPE_BITMAP(ucs_memory_type_names)}, {"RNDV_PIPELINE_SEND_THRESH", "inf", "RNDV size threshold to enable sender side pipeline for mem type", diff --git a/src/ucp/core/ucp_context.h b/src/ucp/core/ucp_context.h index 40294ba4697..0599c66760c 100644 --- a/src/ucp/core/ucp_context.h +++ b/src/ucp/core/ucp_context.h @@ -78,8 +78,8 @@ typedef struct ucp_context_config { size_t rndv_frag_size[UCS_MEMORY_TYPE_LAST]; /** Number of RNDV pipeline fragments per allocation */ size_t rndv_num_frags[UCS_MEMORY_TYPE_LAST]; - /** Memory type of fragments used for RNDV pipeline protocol */ - ucs_memory_type_t rndv_frag_mem_type; + /** Memory types of fragments used for RNDV pipeline protocol */ + uint64_t rndv_frag_mem_types; /** RNDV pipeline send threshold */ size_t rndv_pipeline_send_thresh; /** Enabling 2-stage pipeline rndv protocol */ diff --git a/src/ucp/rndv/proto_rndv.h b/src/ucp/rndv/proto_rndv.h index 221bc2c3fc4..88ec9fbc64e 100644 --- a/src/ucp/rndv/proto_rndv.h +++ b/src/ucp/rndv/proto_rndv.h @@ -68,6 +68,7 @@ typedef struct { */ typedef struct { ucp_proto_rndv_ack_priv_t super; + ucs_memory_type_t frag_mem_type; /* Multi-lane common part. Must be the last field, see @ref ucp_proto_multi_priv_t */ diff --git a/src/ucp/rndv/rndv.c b/src/ucp/rndv/rndv.c index 903bac92bfc..7550f8d7a42 100644 --- a/src/ucp/rndv/rndv.c +++ b/src/ucp/rndv/rndv.c @@ -21,6 +21,21 @@ #include +static UCS_F_ALWAYS_INLINE int +ucp_rndv_frag_memtype(ucp_context_t *context) +{ + if (context->config.ext.rndv_frag_mem_types == 0) { + return UCS_MEMORY_TYPE_HOST; + } + + /* Just one fragment memory type can be specified for proto v1, so take the + * first one from the map. Anyway for proto v1 UCX_RNDV_FRAG_MEM_TYPE is + * supposed to be used. + */ + + return ucs_ffs64(context->config.ext.rndv_frag_mem_types); +} + static UCS_F_ALWAYS_INLINE int ucp_rndv_memtype_direct_support(ucp_context_h context, size_t reg_length, uint64_t reg_mem_types) @@ -70,7 +85,7 @@ static int ucp_rndv_is_recv_pipeline_needed(ucp_request_t *rndv_req, const ucp_ep_config_t *ep_config = ucp_ep_config(rndv_req->send.ep); ucp_context_h context = rndv_req->send.ep->worker->context; int found = 0; - ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type; + ucs_memory_type_t frag_mem_type = ucp_rndv_frag_memtype(context); ucp_md_index_t md_index; uct_md_attr_v2_t *md_attr; uint64_t mem_types; @@ -1170,7 +1185,7 @@ static void ucp_rndv_send_frag_get_mem_type(ucp_request_t *sreq, size_t length, uct_completion_callback_t comp_cb) { ucp_worker_h worker = sreq->send.ep->worker; - ucs_memory_type_t frag_mem_type = worker->context->config.ext.rndv_frag_mem_type; + ucs_memory_type_t frag_mem_type = ucp_rndv_frag_memtype(worker->context); ucp_request_t *freq; ucp_mem_desc_t *mdesc; @@ -1254,7 +1269,7 @@ ucp_rndv_recv_start_get_pipeline(ucp_worker_h worker, ucp_request_t *rndv_req, size_t frag_size; /* use ucp_rkey_packed_mem_type(rkey_buffer) with non-host fragments */ - frag_mem_type = context->config.ext.rndv_frag_mem_type; + frag_mem_type = ucp_rndv_frag_memtype(context); frag_size = context->config.ext.rndv_frag_size[frag_mem_type]; min_zcopy = config->rndv.get_zcopy.min; @@ -1337,7 +1352,7 @@ static void ucp_rndv_send_frag_rtr(ucp_worker_h worker, ucp_request_t *rndv_req, ucp_trace_req(rreq, "using rndv pipeline protocol rndv_req %p", rndv_req); offset = 0; - frag_mem_type = worker->context->config.ext.rndv_frag_mem_type; + frag_mem_type = ucp_rndv_frag_memtype(worker->context); max_frag_size = worker->context->config.ext.rndv_frag_size[frag_mem_type]; num_frags = ucs_div_round_up(rndv_rts_hdr->size, max_frag_size); @@ -2069,7 +2084,7 @@ static ucs_status_t ucp_rndv_send_start_put_pipeline(ucp_request_t *sreq, ucp_worker_h worker = sreq->send.ep->worker; ucp_context_h context = worker->context; size_t rndv_base_offset = rndv_rtr_hdr->offset; - ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type; + ucs_memory_type_t frag_mem_type = ucp_rndv_frag_memtype(context); size_t rndv_size = ucs_min(rndv_rtr_hdr->size, sreq->send.length); const uct_md_attr_v2_t *md_attr; diff --git a/src/ucp/rndv/rndv_get.c b/src/ucp/rndv/rndv_get.c index 14d37814a08..a77d920fc4c 100644 --- a/src/ucp/rndv/rndv_get.c +++ b/src/ucp/rndv/rndv_get.c @@ -25,7 +25,8 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params, uint64_t rndv_modes, size_t max_length, uct_ep_operation_t memtype_op, unsigned flags, ucp_md_map_t initial_reg_md_map, - int support_ppln) + int support_ppln, + ucs_memory_type_t frag_mem_type) { ucp_context_t *context = init_params->worker->context; ucp_proto_multi_init_params_t params = { @@ -76,6 +77,7 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params, return; } + rpriv.frag_mem_type = frag_mem_type; priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, mpriv); ucp_proto_common_add_proto(¶ms.super, &caps, &rpriv, priv_size); } @@ -125,7 +127,7 @@ ucp_proto_rndv_get_zcopy_probe(const ucp_proto_init_params_t *init_params) UCT_EP_OP_LAST, UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, - 0, 0); + 0, 0, UCS_MEMORY_TYPE_HOST); } static void @@ -284,11 +286,11 @@ static ucs_status_t ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req) { ucp_request_t *req = ucs_container_of(uct_req, ucp_request_t, send.uct); - const ucp_proto_rndv_bulk_priv_t *rpriv; + const ucp_proto_rndv_bulk_priv_t *rpriv = req->send.proto_config->priv; ucs_status_t status; if (!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)) { - status = ucp_proto_rndv_mtype_request_init(req); + status = ucp_proto_rndv_mtype_request_init(req, rpriv->frag_mem_type); if (status != UCS_OK) { ucp_proto_request_abort(req, status); return UCS_OK; @@ -300,8 +302,6 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req) req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED; } - /* coverity[tainted_data_downcast] */ - rpriv = req->send.proto_config->priv; return ucp_proto_multi_progress(req, &rpriv->mpriv, ucp_proto_rndv_get_mtype_send_func, ucp_request_invoke_uct_completion_success, @@ -311,27 +311,35 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req) static void ucp_proto_rndv_get_mtype_probe(const ucp_proto_init_params_t *init_params) { + ucp_context_t *context = init_params->worker->context; ucp_md_map_t mdesc_md_map; ucs_status_t status; size_t frag_size; + ucs_memory_type_t frag_mem_type; - status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, &frag_size); - if (status != UCS_OK) { - return; - } + ucs_for_each_bit(frag_mem_type, context->config.ext.rndv_frag_mem_types) { + status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type, + &mdesc_md_map, &frag_size); + if (status != UCS_OK) { + return; + } - ucp_proto_rndv_get_common_probe(init_params, - UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE), - frag_size, UCT_EP_OP_PUT_ZCOPY, 0, - mdesc_md_map, 1); + ucp_proto_rndv_get_common_probe(init_params, + UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE), + frag_size, UCT_EP_OP_PUT_ZCOPY, 0, + mdesc_md_map, 1, frag_mem_type); + } } static void ucp_proto_rndv_get_mtype_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { + const ucp_proto_rndv_bulk_priv_t *rpriv = params->priv; + ucp_proto_rndv_bulk_query(params, attr); - ucp_proto_rndv_mtype_query_desc(params, attr, UCP_PROTO_RNDV_GET_DESC); + ucp_proto_rndv_mtype_query_desc(params, rpriv->frag_mem_type, attr, + UCP_PROTO_RNDV_GET_DESC); } static ucs_status_t ucp_proto_rndv_get_mtype_reset(ucp_request_t *req) diff --git a/src/ucp/rndv/rndv_mtype.inl b/src/ucp/rndv/rndv_mtype.inl index 55e021a83e9..b28e657678f 100644 --- a/src/ucp/rndv/rndv_mtype.inl +++ b/src/ucp/rndv/rndv_mtype.inl @@ -13,11 +13,9 @@ static ucp_ep_h ucp_proto_rndv_mtype_ep(ucp_worker_t *worker, + ucs_memory_type_t frag_mem_type, ucs_memory_type_t buf_mem_type) { - ucs_memory_type_t frag_mem_type = - worker->context->config.ext.rndv_frag_mem_type; - if (worker->mem_type_ep[buf_mem_type] != NULL) { return worker->mem_type_ep[buf_mem_type]; } @@ -27,15 +25,15 @@ static ucp_ep_h ucp_proto_rndv_mtype_ep(ucp_worker_t *worker, static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_init(const ucp_proto_init_params_t *init_params, + ucs_memory_type_t frag_mem_type, ucp_md_map_t *mdesc_md_map_p, size_t *frag_size_p) { - ucp_worker_h worker = init_params->worker; - ucp_context_h context = worker->context; - ucs_memory_type_t mem_type = init_params->select_param->mem_type; - ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type; + ucp_worker_h worker = init_params->worker; + ucp_context_h context = worker->context; + ucs_memory_type_t mem_type = init_params->select_param->mem_type; if ((init_params->select_param->dt_class != UCP_DATATYPE_CONTIG) || - (ucp_proto_rndv_mtype_ep(worker, mem_type) == NULL) || + (ucp_proto_rndv_mtype_ep(worker, frag_mem_type, mem_type) == NULL) || !ucp_proto_init_check_op(init_params, UCP_PROTO_RNDV_OP_ID_MASK)) { return UCS_ERR_UNSUPPORTED; } @@ -47,11 +45,10 @@ ucp_proto_rndv_mtype_init(const ucp_proto_init_params_t *init_params, } static UCS_F_ALWAYS_INLINE ucs_status_t -ucp_proto_rndv_mtype_request_init(ucp_request_t *req) +ucp_proto_rndv_mtype_request_init(ucp_request_t *req, + ucs_memory_type_t frag_mem_type) { - ucp_worker_h worker = req->send.ep->worker; - ucs_memory_type_t frag_mem_type = - worker->context->config.ext.rndv_frag_mem_type; + ucp_worker_h worker = req->send.ep->worker; req->send.rndv.mdesc = ucp_rndv_mpool_get(worker, frag_mem_type, UCS_SYS_DEVICE_ID_UNKNOWN); @@ -82,6 +79,7 @@ static ucp_ep_h ucp_proto_rndv_req_mtype_ep(ucp_request_t *req) ucp_ep_h mem_type_ep; mem_type_ep = ucp_proto_rndv_mtype_ep(req->send.ep->worker, + req->send.rndv.mdesc->memh->mem_type, req->send.state.dt_iter.mem_info.type); ucs_assert(mem_type_ep != NULL); @@ -142,7 +140,7 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy( ucp_trace_req(req, "buffer %p copy-%s %p %s-%s using memtype-ep %p lane[%d]", buffer, mode, req->send.state.dt_iter.type.contig.buffer, ucs_memory_type_names[req->send.state.dt_iter.mem_info.type], - ucs_memory_type_names[context->config.ext.rndv_frag_mem_type], + ucs_memory_type_names[req->send.rndv.mdesc->memh->mem_type], mtype_ep, lane); ucp_proto_completion_init(&req->send.state.uct_comp, comp_func); @@ -169,6 +167,7 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy( static UCS_F_ALWAYS_INLINE void ucp_proto_rndv_mtype_query_desc(const ucp_proto_query_params_t *params, + ucs_memory_type_t frag_mem_type, ucp_proto_query_attr_t *attr, const char *xfer_desc) { @@ -176,6 +175,7 @@ ucp_proto_rndv_mtype_query_desc(const ucp_proto_query_params_t *params, ucp_context_h context = params->worker->context; ucs_memory_type_t mem_type = params->select_param->mem_type; ucp_ep_h mtype_ep = ucp_proto_rndv_mtype_ep(params->worker, + frag_mem_type, mem_type); ucp_lane_index_t lane; ucp_rsc_index_t rsc_index; @@ -196,6 +196,9 @@ ucp_proto_rndv_mtype_query_desc(const ucp_proto_query_params_t *params, if (ucp_proto_select_op_id(params->select_param) == UCP_OP_ID_RNDV_RECV) { ucs_string_buffer_appendf(&strb, ", %s", tl_name); } + + ucs_string_buffer_appendf(&strb, ", frag %s", + ucs_memory_type_names[frag_mem_type]); } #endif diff --git a/src/ucp/rndv/rndv_put.c b/src/ucp/rndv/rndv_put.c index 6a4d824c55c..27eb137673b 100644 --- a/src/ucp/rndv/rndv_put.c +++ b/src/ucp/rndv/rndv_put.c @@ -227,7 +227,8 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params, uct_ep_operation_t memtype_op, unsigned flags, ucp_md_map_t initial_reg_md_map, uct_completion_callback_t comp_cb, - int support_ppln, uint8_t stat_counter) + int support_ppln, uint8_t stat_counter, + ucs_memory_type_t frag_mem_type) { const size_t atp_size = sizeof(ucp_rndv_ack_hdr_t); ucp_context_t *context = init_params->worker->context; @@ -345,8 +346,9 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params, if (send_atp) { ucs_assert(rpriv.atp_map != 0); } - rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map); - rpriv.stat_counter = stat_counter; + rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map); + rpriv.stat_counter = stat_counter; + rpriv.bulk.frag_mem_type = frag_mem_type; priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, bulk.mpriv); ucp_proto_common_add_proto(¶ms.super, &caps, &rpriv, priv_size); @@ -417,7 +419,7 @@ ucp_proto_rndv_put_zcopy_probe(const ucp_proto_init_params_t *init_params) UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY | UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING, 0, ucp_proto_rndv_put_zcopy_completion, 0, - UCP_WORKER_STAT_RNDV_PUT_ZCOPY); + UCP_WORKER_STAT_RNDV_PUT_ZCOPY, UCS_MEMORY_TYPE_HOST); } static void @@ -500,11 +502,12 @@ static ucs_status_t ucp_proto_rndv_put_mtype_copy_progress(uct_pending_req_t *uct_req) { ucp_request_t *req = ucs_container_of(uct_req, ucp_request_t, send.uct); + const ucp_proto_rndv_put_priv_t *rpriv = req->send.proto_config->priv; ucs_status_t status; ucs_assert(!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)); - status = ucp_proto_rndv_mtype_request_init(req); + status = ucp_proto_rndv_mtype_request_init(req, rpriv->bulk.frag_mem_type); if (status != UCS_OK) { ucp_proto_request_abort(req, status); return UCS_OK; @@ -563,8 +566,19 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params) ucp_md_map_t mdesc_md_map; ucs_status_t status; size_t frag_size; + ucs_memory_type_t frag_mem_type; - status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, &frag_size); + if (init_params->rkey_config_key == NULL) { + /* FIXME: maybe can initialize proto with all available types if no + * rkey in RTR. + */ + frag_mem_type = UCS_MEMORY_TYPE_HOST; + } else { + frag_mem_type = init_params->rkey_config_key->mem_type; + } + + status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type, + &mdesc_md_map, &frag_size); if (status != UCS_OK) { return; } @@ -575,21 +589,22 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params) comp_cb = ucp_proto_rndv_put_mtype_completion; } - ucp_proto_rndv_put_common_probe(init_params, - UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE), - frag_size, UCT_EP_OP_GET_ZCOPY, 0, - mdesc_md_map, comp_cb, 1, - UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY); + ucp_proto_rndv_put_common_probe( + init_params, UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE), frag_size, + UCT_EP_OP_GET_ZCOPY, 0, mdesc_md_map, comp_cb, 1, + UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, frag_mem_type); } static void ucp_proto_rndv_put_mtype_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { + const ucp_proto_rndv_put_priv_t *rpriv = params->priv; const char *put_desc; put_desc = ucp_proto_rndv_put_common_query(params, attr); - ucp_proto_rndv_mtype_query_desc(params, attr, put_desc); + ucp_proto_rndv_mtype_query_desc(params, rpriv->bulk.frag_mem_type, attr, + put_desc); } ucp_proto_t ucp_rndv_put_mtype_proto = { diff --git a/src/ucp/rndv/rndv_rkey_ptr.c b/src/ucp/rndv/rndv_rkey_ptr.c index 882de12629c..4e3d68d6f20 100644 --- a/src/ucp/rndv/rndv_rkey_ptr.c +++ b/src/ucp/rndv/rndv_rkey_ptr.c @@ -240,8 +240,9 @@ ucp_proto_rndv_rkey_ptr_mtype_probe(const ucp_proto_init_params_t *init_params) return; } - status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, - ¶ms.super.max_length); + /* 2-stage ppln protocols work with host staging buffers only */ + status = ucp_proto_rndv_mtype_init(init_params, UCS_MEMORY_TYPE_HOST, + &mdesc_md_map, ¶ms.super.max_length); if (status != UCS_OK) { return; } @@ -348,7 +349,7 @@ ucp_proto_rndv_rkey_ptr_mtype_query(const ucp_proto_query_params_t *params, const char *desc = UCP_PROTO_RNDV_RKEY_PTR_DESC; ucp_rndv_rkey_ptr_query_common(params, attr); - ucp_proto_rndv_mtype_query_desc(params, attr, desc); + ucp_proto_rndv_mtype_query_desc(params, UCS_MEMORY_TYPE_HOST, attr, desc); } ucp_proto_t ucp_rndv_rkey_ptr_mtype_proto = { diff --git a/src/ucp/rndv/rndv_rtr.c b/src/ucp/rndv/rndv_rtr.c index 06f6d7e6bd0..ecb6e5178f4 100644 --- a/src/ucp/rndv/rndv_rtr.c +++ b/src/ucp/rndv/rndv_rtr.c @@ -32,6 +32,11 @@ typedef struct { ucp_proto_rndv_rtr_data_received_cb_t data_received; } ucp_proto_rndv_rtr_priv_t; +typedef struct { + ucp_proto_rndv_rtr_priv_t super; + ucs_memory_type_t frag_mem_type; +} ucp_proto_rndv_rtr_mtype_priv_t; + static UCS_F_ALWAYS_INLINE void ucp_proto_rtr_common_request_init(ucp_request_t *req) { @@ -340,10 +345,11 @@ ucp_proto_rndv_rtr_mtype_data_received(ucp_request_t *req, int in_buffer) static ucs_status_t ucp_proto_rndv_rtr_mtype_progress(uct_pending_req_t *self) { ucp_request_t *req = ucs_container_of(self, ucp_request_t, send.uct); + const ucp_proto_rndv_rtr_mtype_priv_t *rpriv = req->send.proto_config->priv; ucs_status_t status; if (!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)) { - status = ucp_proto_rndv_mtype_request_init(req); + status = ucp_proto_rndv_mtype_request_init(req, rpriv->frag_mem_type); if (status != UCS_OK) { ucp_proto_request_abort(req, status); return UCS_OK; @@ -385,11 +391,11 @@ ucp_proto_rndv_rtr_mtype_probe(const ucp_proto_init_params_t *init_params) .remote_op_id = UCP_OP_ID_RNDV_SEND, .lane = ucp_proto_rndv_find_ctrl_lane(init_params), .perf_bias = 0.0, - .mem_info.type = context->config.ext.rndv_frag_mem_type, .mem_info.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN, .ctrl_msg_name = UCP_PROTO_RNDV_RTR_NAME, }; - ucp_proto_rndv_rtr_priv_t rpriv; + ucs_memory_type_t frag_mem_type; + ucp_proto_rndv_rtr_mtype_priv_t rpriv; ucp_md_map_t dummy_md_map; ucp_md_index_t md_index; ucs_status_t status; @@ -399,49 +405,56 @@ ucp_proto_rndv_rtr_mtype_probe(const ucp_proto_init_params_t *init_params) return; } - status = ucp_proto_rndv_mtype_init(init_params, &dummy_md_map, - ¶ms.super.max_length); - if (status != UCS_OK) { - return; - } + ucs_for_each_bit(frag_mem_type, context->config.ext.rndv_frag_mem_types) { + status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type, + &dummy_md_map, + ¶ms.super.max_length); + if (status != UCS_OK) { + return; + } - status = ucp_proto_init_buffer_copy_time( - init_params->worker, "rtr/mtype unpack", params.mem_info.type, - init_params->select_param->mem_type, UCT_EP_OP_PUT_ZCOPY, - ¶ms.unpack_time, ¶ms.unpack_perf_node); - if (status != UCS_OK) { - return; - } + params.mem_info.type = frag_mem_type; - status = ucp_mm_get_alloc_md_index(context, &md_index, - params.mem_info.type); - if ((status != UCS_OK) || (md_index == UCP_NULL_RESOURCE)) { - params.md_map = 0; - } else { - params.md_map = UCS_BIT(md_index); - } + status = ucp_proto_init_buffer_copy_time( + init_params->worker, "rtr/mtype unpack", frag_mem_type, + init_params->select_param->mem_type, UCT_EP_OP_PUT_ZCOPY, + ¶ms.unpack_time, ¶ms.unpack_perf_node); + if (status != UCS_OK) { + return; + } + + status = ucp_mm_get_alloc_md_index(context, &md_index, frag_mem_type); + if ((status != UCS_OK) || (md_index == UCP_NULL_RESOURCE)) { + params.md_map = 0; + } else { + params.md_map = UCS_BIT(md_index); + } - rpriv.pack_cb = ucp_proto_rndv_rtr_mtype_pack; - rpriv.data_received = ucp_proto_rndv_rtr_mtype_data_received; + rpriv.super.pack_cb = ucp_proto_rndv_rtr_mtype_pack; + rpriv.super.data_received = ucp_proto_rndv_rtr_mtype_data_received; + rpriv.frag_mem_type = frag_mem_type; - ucp_proto_rndv_ctrl_probe(¶ms, &rpriv, sizeof(rpriv)); - ucp_proto_perf_node_deref(¶ms.unpack_perf_node); + ucp_proto_rndv_ctrl_probe(¶ms, &rpriv, sizeof(rpriv)); + ucp_proto_perf_node_deref(¶ms.unpack_perf_node); + } } static void ucp_proto_rndv_rtr_mtype_query(const ucp_proto_query_params_t *params, ucp_proto_query_attr_t *attr) { - const ucp_proto_rndv_ctrl_priv_t *rpriv = params->priv; + const ucp_proto_rndv_rtr_mtype_priv_t *rpriv = params->priv; ucp_proto_query_attr_t remote_attr; - ucp_proto_config_query(params->worker, &rpriv->remote_proto_config, + ucp_proto_config_query(params->worker, + &rpriv->super.super.remote_proto_config, params->msg_length, &remote_attr); attr->is_estimation = 1; attr->max_msg_length = remote_attr.max_msg_length; - attr->lane_map = UCS_BIT(rpriv->lane); - ucp_proto_rndv_mtype_query_desc(params, attr, remote_attr.desc); + attr->lane_map = UCS_BIT(rpriv->super.super.lane); + ucp_proto_rndv_mtype_query_desc(params, rpriv->frag_mem_type, attr, + remote_attr.desc); ucs_strncpy_safe(attr->config, remote_attr.config, sizeof(attr->config)); }