Skip to content

Commit

Permalink
UCP: Enable ppln protos with cuda buffers by default
Browse files Browse the repository at this point in the history
  • Loading branch information
brminich committed Aug 28, 2024
1 parent 2c067f5 commit 744f644
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 85 deletions.
14 changes: 9 additions & 5 deletions src/ucp/core/ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,15 @@ static ucs_config_field_t ucp_context_config_table[] = {
"and the resulting performance.",
ucs_offsetof(ucp_context_config_t, estimated_num_ppn), UCS_CONFIG_TYPE_ULUNITS},

{"RNDV_FRAG_MEM_TYPE", "host",
"Memory type of fragments used for RNDV pipeline protocol.\n"
"Allowed memory types is one of: host, cuda, rocm, ze-host, ze-device",
ucs_offsetof(ucp_context_config_t, rndv_frag_mem_type),
UCS_CONFIG_TYPE_ENUM(ucs_memory_type_names)},
{"RNDV_FRAG_MEM_TYPE", NULL, "",
ucs_offsetof(ucp_context_config_t, rndv_frag_mem_types),
UCS_CONFIG_TYPE_BITMAP(ucs_memory_type_names)},

{"RNDV_FRAG_MEM_TYPES", "host,cuda",
"Memory types of fragments used for RNDV pipeline protocol.\n"
"Allowed memory types are: host, cuda, rocm, ze-host, ze-device",
ucs_offsetof(ucp_context_config_t, rndv_frag_mem_types),
UCS_CONFIG_TYPE_BITMAP(ucs_memory_type_names)},

{"RNDV_PIPELINE_SEND_THRESH", "inf",
"RNDV size threshold to enable sender side pipeline for mem type",
Expand Down
4 changes: 2 additions & 2 deletions src/ucp/core/ucp_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ typedef struct ucp_context_config {
size_t rndv_frag_size[UCS_MEMORY_TYPE_LAST];
/** Number of RNDV pipeline fragments per allocation */
size_t rndv_num_frags[UCS_MEMORY_TYPE_LAST];
/** Memory type of fragments used for RNDV pipeline protocol */
ucs_memory_type_t rndv_frag_mem_type;
/** Memory types of fragments used for RNDV pipeline protocol */
uint64_t rndv_frag_mem_types;
/** RNDV pipeline send threshold */
size_t rndv_pipeline_send_thresh;
/** Enabling 2-stage pipeline rndv protocol */
Expand Down
1 change: 1 addition & 0 deletions src/ucp/rndv/proto_rndv.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ typedef struct {
*/
typedef struct {
ucp_proto_rndv_ack_priv_t super;
ucs_memory_type_t frag_mem_type;

/* Multi-lane common part. Must be the last field, see
@ref ucp_proto_multi_priv_t */
Expand Down
25 changes: 20 additions & 5 deletions src/ucp/rndv/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
#include <ucs/datastruct/queue.h>


static UCS_F_ALWAYS_INLINE int
ucp_rndv_frag_memtype(ucp_context_t *context)
{
if (context->config.ext.rndv_frag_mem_types == 0) {
return UCS_MEMORY_TYPE_HOST;
}

/* Just one fragment memory type can be specified for proto v1, so take the
* first one from the map. Anyway for proto v1 UCX_RNDV_FRAG_MEM_TYPE is
* supposed to be used.
*/

return ucs_ffs64(context->config.ext.rndv_frag_mem_types);
}

static UCS_F_ALWAYS_INLINE int
ucp_rndv_memtype_direct_support(ucp_context_h context, size_t reg_length,
uint64_t reg_mem_types)
Expand Down Expand Up @@ -70,7 +85,7 @@ static int ucp_rndv_is_recv_pipeline_needed(ucp_request_t *rndv_req,
const ucp_ep_config_t *ep_config = ucp_ep_config(rndv_req->send.ep);
ucp_context_h context = rndv_req->send.ep->worker->context;
int found = 0;
ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type;
ucs_memory_type_t frag_mem_type = ucp_rndv_frag_memtype(context);
ucp_md_index_t md_index;
uct_md_attr_v2_t *md_attr;
uint64_t mem_types;
Expand Down Expand Up @@ -1170,7 +1185,7 @@ static void ucp_rndv_send_frag_get_mem_type(ucp_request_t *sreq, size_t length,
uct_completion_callback_t comp_cb)
{
ucp_worker_h worker = sreq->send.ep->worker;
ucs_memory_type_t frag_mem_type = worker->context->config.ext.rndv_frag_mem_type;
ucs_memory_type_t frag_mem_type = ucp_rndv_frag_memtype(worker->context);
ucp_request_t *freq;
ucp_mem_desc_t *mdesc;

Expand Down Expand Up @@ -1254,7 +1269,7 @@ ucp_rndv_recv_start_get_pipeline(ucp_worker_h worker, ucp_request_t *rndv_req,
size_t frag_size;

/* use ucp_rkey_packed_mem_type(rkey_buffer) with non-host fragments */
frag_mem_type = context->config.ext.rndv_frag_mem_type;
frag_mem_type = ucp_rndv_frag_memtype(context);
frag_size = context->config.ext.rndv_frag_size[frag_mem_type];

min_zcopy = config->rndv.get_zcopy.min;
Expand Down Expand Up @@ -1337,7 +1352,7 @@ static void ucp_rndv_send_frag_rtr(ucp_worker_h worker, ucp_request_t *rndv_req,
ucp_trace_req(rreq, "using rndv pipeline protocol rndv_req %p", rndv_req);

offset = 0;
frag_mem_type = worker->context->config.ext.rndv_frag_mem_type;
frag_mem_type = ucp_rndv_frag_memtype(worker->context);
max_frag_size = worker->context->config.ext.rndv_frag_size[frag_mem_type];
num_frags = ucs_div_round_up(rndv_rts_hdr->size, max_frag_size);

Expand Down Expand Up @@ -2069,7 +2084,7 @@ static ucs_status_t ucp_rndv_send_start_put_pipeline(ucp_request_t *sreq,
ucp_worker_h worker = sreq->send.ep->worker;
ucp_context_h context = worker->context;
size_t rndv_base_offset = rndv_rtr_hdr->offset;
ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type;
ucs_memory_type_t frag_mem_type = ucp_rndv_frag_memtype(context);
size_t rndv_size = ucs_min(rndv_rtr_hdr->size,
sreq->send.length);
const uct_md_attr_v2_t *md_attr;
Expand Down
38 changes: 23 additions & 15 deletions src/ucp/rndv/rndv_get.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params,
uint64_t rndv_modes, size_t max_length,
uct_ep_operation_t memtype_op, unsigned flags,
ucp_md_map_t initial_reg_md_map,
int support_ppln)
int support_ppln,
ucs_memory_type_t frag_mem_type)
{
ucp_context_t *context = init_params->worker->context;
ucp_proto_multi_init_params_t params = {
Expand Down Expand Up @@ -76,6 +77,7 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params,
return;
}

rpriv.frag_mem_type = frag_mem_type;
priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, mpriv);
ucp_proto_common_add_proto(&params.super, &caps, &rpriv, priv_size);
}
Expand Down Expand Up @@ -125,7 +127,7 @@ ucp_proto_rndv_get_zcopy_probe(const ucp_proto_init_params_t *init_params)
UCT_EP_OP_LAST,
UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY |
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING,
0, 0);
0, 0, UCS_MEMORY_TYPE_HOST);
}

static void
Expand Down Expand Up @@ -284,11 +286,11 @@ static ucs_status_t
ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req)
{
ucp_request_t *req = ucs_container_of(uct_req, ucp_request_t, send.uct);
const ucp_proto_rndv_bulk_priv_t *rpriv;
const ucp_proto_rndv_bulk_priv_t *rpriv = req->send.proto_config->priv;
ucs_status_t status;

if (!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED)) {
status = ucp_proto_rndv_mtype_request_init(req);
status = ucp_proto_rndv_mtype_request_init(req, rpriv->frag_mem_type);
if (status != UCS_OK) {
ucp_proto_request_abort(req, status);
return UCS_OK;
Expand All @@ -300,8 +302,6 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req)
req->flags |= UCP_REQUEST_FLAG_PROTO_INITIALIZED;
}

/* coverity[tainted_data_downcast] */
rpriv = req->send.proto_config->priv;
return ucp_proto_multi_progress(req, &rpriv->mpriv,
ucp_proto_rndv_get_mtype_send_func,
ucp_request_invoke_uct_completion_success,
Expand All @@ -311,27 +311,35 @@ ucp_proto_rndv_get_mtype_fetch_progress(uct_pending_req_t *uct_req)
static void
ucp_proto_rndv_get_mtype_probe(const ucp_proto_init_params_t *init_params)
{
ucp_context_t *context = init_params->worker->context;
ucp_md_map_t mdesc_md_map;
ucs_status_t status;
size_t frag_size;
ucs_memory_type_t frag_mem_type;

status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, &frag_size);
if (status != UCS_OK) {
return;
}
ucs_for_each_bit(frag_mem_type, context->config.ext.rndv_frag_mem_types) {
status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type,
&mdesc_md_map, &frag_size);
if (status != UCS_OK) {
return;
}

ucp_proto_rndv_get_common_probe(init_params,
UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE),
frag_size, UCT_EP_OP_PUT_ZCOPY, 0,
mdesc_md_map, 1);
ucp_proto_rndv_get_common_probe(init_params,
UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE),
frag_size, UCT_EP_OP_PUT_ZCOPY, 0,
mdesc_md_map, 1, frag_mem_type);
}
}

static void
ucp_proto_rndv_get_mtype_query(const ucp_proto_query_params_t *params,
ucp_proto_query_attr_t *attr)
{
const ucp_proto_rndv_bulk_priv_t *rpriv = params->priv;

ucp_proto_rndv_bulk_query(params, attr);
ucp_proto_rndv_mtype_query_desc(params, attr, UCP_PROTO_RNDV_GET_DESC);
ucp_proto_rndv_mtype_query_desc(params, rpriv->frag_mem_type, attr,
UCP_PROTO_RNDV_GET_DESC);
}

static ucs_status_t ucp_proto_rndv_get_mtype_reset(ucp_request_t *req)
Expand Down
29 changes: 16 additions & 13 deletions src/ucp/rndv/rndv_mtype.inl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@


static ucp_ep_h ucp_proto_rndv_mtype_ep(ucp_worker_t *worker,
ucs_memory_type_t frag_mem_type,
ucs_memory_type_t buf_mem_type)
{
ucs_memory_type_t frag_mem_type =
worker->context->config.ext.rndv_frag_mem_type;

if (worker->mem_type_ep[buf_mem_type] != NULL) {
return worker->mem_type_ep[buf_mem_type];
}
Expand All @@ -27,15 +25,15 @@ static ucp_ep_h ucp_proto_rndv_mtype_ep(ucp_worker_t *worker,

static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_proto_rndv_mtype_init(const ucp_proto_init_params_t *init_params,
ucs_memory_type_t frag_mem_type,
ucp_md_map_t *mdesc_md_map_p, size_t *frag_size_p)
{
ucp_worker_h worker = init_params->worker;
ucp_context_h context = worker->context;
ucs_memory_type_t mem_type = init_params->select_param->mem_type;
ucs_memory_type_t frag_mem_type = context->config.ext.rndv_frag_mem_type;
ucp_worker_h worker = init_params->worker;
ucp_context_h context = worker->context;
ucs_memory_type_t mem_type = init_params->select_param->mem_type;

if ((init_params->select_param->dt_class != UCP_DATATYPE_CONTIG) ||
(ucp_proto_rndv_mtype_ep(worker, mem_type) == NULL) ||
(ucp_proto_rndv_mtype_ep(worker, frag_mem_type, mem_type) == NULL) ||
!ucp_proto_init_check_op(init_params, UCP_PROTO_RNDV_OP_ID_MASK)) {
return UCS_ERR_UNSUPPORTED;
}
Expand All @@ -47,11 +45,10 @@ ucp_proto_rndv_mtype_init(const ucp_proto_init_params_t *init_params,
}

static UCS_F_ALWAYS_INLINE ucs_status_t
ucp_proto_rndv_mtype_request_init(ucp_request_t *req)
ucp_proto_rndv_mtype_request_init(ucp_request_t *req,
ucs_memory_type_t frag_mem_type)
{
ucp_worker_h worker = req->send.ep->worker;
ucs_memory_type_t frag_mem_type =
worker->context->config.ext.rndv_frag_mem_type;
ucp_worker_h worker = req->send.ep->worker;

req->send.rndv.mdesc = ucp_rndv_mpool_get(worker, frag_mem_type,
UCS_SYS_DEVICE_ID_UNKNOWN);
Expand Down Expand Up @@ -82,6 +79,7 @@ static ucp_ep_h ucp_proto_rndv_req_mtype_ep(ucp_request_t *req)
ucp_ep_h mem_type_ep;

mem_type_ep = ucp_proto_rndv_mtype_ep(req->send.ep->worker,
req->send.rndv.mdesc->memh->mem_type,
req->send.state.dt_iter.mem_info.type);
ucs_assert(mem_type_ep != NULL);

Expand Down Expand Up @@ -142,7 +140,7 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy(
ucp_trace_req(req, "buffer %p copy-%s %p %s-%s using memtype-ep %p lane[%d]",
buffer, mode, req->send.state.dt_iter.type.contig.buffer,
ucs_memory_type_names[req->send.state.dt_iter.mem_info.type],
ucs_memory_type_names[context->config.ext.rndv_frag_mem_type],
ucs_memory_type_names[req->send.rndv.mdesc->memh->mem_type],
mtype_ep, lane);

ucp_proto_completion_init(&req->send.state.uct_comp, comp_func);
Expand All @@ -169,13 +167,15 @@ static UCS_F_ALWAYS_INLINE ucs_status_t ucp_proto_rndv_mtype_copy(

static UCS_F_ALWAYS_INLINE void
ucp_proto_rndv_mtype_query_desc(const ucp_proto_query_params_t *params,
ucs_memory_type_t frag_mem_type,
ucp_proto_query_attr_t *attr,
const char *xfer_desc)
{
UCS_STRING_BUFFER_FIXED(strb, attr->desc, sizeof(attr->desc));
ucp_context_h context = params->worker->context;
ucs_memory_type_t mem_type = params->select_param->mem_type;
ucp_ep_h mtype_ep = ucp_proto_rndv_mtype_ep(params->worker,
frag_mem_type,
mem_type);
ucp_lane_index_t lane;
ucp_rsc_index_t rsc_index;
Expand All @@ -196,6 +196,9 @@ ucp_proto_rndv_mtype_query_desc(const ucp_proto_query_params_t *params,
if (ucp_proto_select_op_id(params->select_param) == UCP_OP_ID_RNDV_RECV) {
ucs_string_buffer_appendf(&strb, ", %s", tl_name);
}

ucs_string_buffer_appendf(&strb, ", frag %s",
ucs_memory_type_names[frag_mem_type]);
}

#endif
39 changes: 27 additions & 12 deletions src/ucp/rndv/rndv_put.c
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params,
uct_ep_operation_t memtype_op, unsigned flags,
ucp_md_map_t initial_reg_md_map,
uct_completion_callback_t comp_cb,
int support_ppln, uint8_t stat_counter)
int support_ppln, uint8_t stat_counter,
ucs_memory_type_t frag_mem_type)
{
const size_t atp_size = sizeof(ucp_rndv_ack_hdr_t);
ucp_context_t *context = init_params->worker->context;
Expand Down Expand Up @@ -345,8 +346,9 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params,
if (send_atp) {
ucs_assert(rpriv.atp_map != 0);
}
rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map);
rpriv.stat_counter = stat_counter;
rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map);
rpriv.stat_counter = stat_counter;
rpriv.bulk.frag_mem_type = frag_mem_type;

priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, bulk.mpriv);
ucp_proto_common_add_proto(&params.super, &caps, &rpriv, priv_size);
Expand Down Expand Up @@ -417,7 +419,7 @@ ucp_proto_rndv_put_zcopy_probe(const ucp_proto_init_params_t *init_params)
UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY |
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING,
0, ucp_proto_rndv_put_zcopy_completion, 0,
UCP_WORKER_STAT_RNDV_PUT_ZCOPY);
UCP_WORKER_STAT_RNDV_PUT_ZCOPY, UCS_MEMORY_TYPE_HOST);
}

static void
Expand Down Expand Up @@ -500,11 +502,12 @@ static ucs_status_t
ucp_proto_rndv_put_mtype_copy_progress(uct_pending_req_t *uct_req)
{
ucp_request_t *req = ucs_container_of(uct_req, ucp_request_t, send.uct);
const ucp_proto_rndv_put_priv_t *rpriv = req->send.proto_config->priv;
ucs_status_t status;

ucs_assert(!(req->flags & UCP_REQUEST_FLAG_PROTO_INITIALIZED));

status = ucp_proto_rndv_mtype_request_init(req);
status = ucp_proto_rndv_mtype_request_init(req, rpriv->bulk.frag_mem_type);
if (status != UCS_OK) {
ucp_proto_request_abort(req, status);
return UCS_OK;
Expand Down Expand Up @@ -563,8 +566,19 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params)
ucp_md_map_t mdesc_md_map;
ucs_status_t status;
size_t frag_size;
ucs_memory_type_t frag_mem_type;

status = ucp_proto_rndv_mtype_init(init_params, &mdesc_md_map, &frag_size);
if (init_params->rkey_config_key == NULL) {
/* FIXME: maybe can initialize proto with all available types if no
* rkey in RTR.
*/
frag_mem_type = UCS_MEMORY_TYPE_HOST;
} else {
frag_mem_type = init_params->rkey_config_key->mem_type;
}

status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type,
&mdesc_md_map, &frag_size);
if (status != UCS_OK) {
return;
}
Expand All @@ -575,21 +589,22 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params)
comp_cb = ucp_proto_rndv_put_mtype_completion;
}

ucp_proto_rndv_put_common_probe(init_params,
UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE),
frag_size, UCT_EP_OP_GET_ZCOPY, 0,
mdesc_md_map, comp_cb, 1,
UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY);
ucp_proto_rndv_put_common_probe(
init_params, UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE), frag_size,
UCT_EP_OP_GET_ZCOPY, 0, mdesc_md_map, comp_cb, 1,
UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, frag_mem_type);
}

static void
ucp_proto_rndv_put_mtype_query(const ucp_proto_query_params_t *params,
ucp_proto_query_attr_t *attr)
{
const ucp_proto_rndv_put_priv_t *rpriv = params->priv;
const char *put_desc;

put_desc = ucp_proto_rndv_put_common_query(params, attr);
ucp_proto_rndv_mtype_query_desc(params, attr, put_desc);
ucp_proto_rndv_mtype_query_desc(params, rpriv->bulk.frag_mem_type, attr,
put_desc);
}

ucp_proto_t ucp_rndv_put_mtype_proto = {
Expand Down
Loading

0 comments on commit 744f644

Please sign in to comment.