Skip to content

Commit

Permalink
UCP: Fix ppln with cuda + minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brminich committed Sep 18, 2024
1 parent c1f065e commit cef74eb
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 42 deletions.
23 changes: 12 additions & 11 deletions src/ucp/proto/proto_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ ucp_lane_index_t
ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params,
uct_ep_operation_t memtype_op, unsigned flags,
ptrdiff_t max_iov_offs, size_t min_iov,
ucp_lane_type_t lane_type, uint64_t tl_cap_flags,
ucp_lane_type_t lane_type,
ucs_memory_type_t mem_type, uint64_t tl_cap_flags,
ucp_lane_index_t max_lanes,
ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes)
{
Expand Down Expand Up @@ -494,25 +495,24 @@ ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params,
}

/* Check memory registration capabilities for zero-copy case */
if (flags & UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY) {
if (flags & (UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY |
UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY)) {
if (md_attr->flags & UCT_MD_FLAG_NEED_MEMH) {
/* Memory domain must support registration on the relevant
* memory type */
if (!(context->reg_md_map[select_param->mem_type] &
UCS_BIT(md_index))) {
if (!(context->reg_md_map[mem_type] & UCS_BIT(md_index))) {
ucs_trace("%s: md %s cannot register %s memory", lane_desc,
context->tl_mds[md_index].rsc.md_name,
ucs_memory_type_names[select_param->mem_type]);
ucs_memory_type_names[mem_type]);
continue;
}
} else if (!(md_attr->access_mem_types &
UCS_BIT(select_param->mem_type))) {
} else if (!(md_attr->access_mem_types & UCS_BIT(mem_type))) {
/*
* Memory domain which does not require a registration for zero
* copy operation must be able to access the relevant memory type
*/
ucs_trace("%s: no access to mem type %s", lane_desc,
ucs_memory_type_names[select_param->mem_type]);
ucs_memory_type_names[mem_type]);
continue;
}
}
Expand Down Expand Up @@ -598,8 +598,9 @@ ucp_proto_common_reg_md_map(const ucp_proto_common_init_params_t *params,

ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag(
const ucp_proto_common_init_params_t *params, ucp_lane_type_t lane_type,
uint64_t tl_cap_flags, ucp_lane_index_t max_lanes,
ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes)
ucs_memory_type_t mem_type, uint64_t tl_cap_flags,
ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map,
ucp_lane_index_t *lanes)
{
ucp_lane_index_t lane_index, lane, num_lanes, num_valid_lanes;
const uct_iface_attr_t *iface_attr;
Expand All @@ -608,7 +609,7 @@ ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag(
num_lanes = ucp_proto_common_find_lanes(&params->super, params->memtype_op,
params->flags, params->max_iov_offs,
params->min_iov, lane_type,
tl_cap_flags, max_lanes,
mem_type, tl_cap_flags, max_lanes,
exclude_map, lanes);

num_valid_lanes = 0;
Expand Down
12 changes: 7 additions & 5 deletions src/ucp/proto/proto_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ typedef enum {

/* Supports starting the request when its datatype iterator offset is > 0 */
UCP_PROTO_COMMON_INIT_FLAG_RESUME = UCS_BIT(10),
UCP_PROTO_COMMON_KEEP_MD_MAP = UCS_BIT(11)
UCP_PROTO_COMMON_INIT_FLAG_KEEP_MD_MAP = UCS_BIT(11),
UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY = UCS_BIT(12)
} ucp_proto_common_init_flags_t;


Expand Down Expand Up @@ -248,16 +249,17 @@ ucp_proto_common_get_lane_perf(const ucp_proto_common_init_params_t *params,
/* @return number of lanes found */
ucp_lane_index_t ucp_proto_common_find_lanes_with_min_frag(
const ucp_proto_common_init_params_t *params, ucp_lane_type_t lane_type,
uint64_t tl_cap_flags, ucp_lane_index_t max_lanes,
ucp_lane_map_t exclude_map, ucp_lane_index_t *lanes);
ucs_memory_type_t mem_type, uint64_t tl_cap_flags,
ucp_lane_index_t max_lanes, ucp_lane_map_t exclude_map,
ucp_lane_index_t *lanes);


ucp_lane_index_t
ucp_proto_common_find_lanes(const ucp_proto_init_params_t *params,
uct_ep_operation_t memtype_op, unsigned flags,
ptrdiff_t max_iov_offs, size_t min_iov,
ucp_lane_type_t lane_type, uint64_t tl_cap_flags,
ucp_lane_index_t max_lanes,
ucp_lane_type_t lane_type, ucs_memory_type_t mem_type,
uint64_t tl_cap_flags, ucp_lane_index_t max_lanes,
ucp_lane_map_t exclude_map,
ucp_lane_index_t *lanes);

Expand Down
9 changes: 5 additions & 4 deletions src/ucp/proto/proto_multi.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params,
ucp_proto_caps_t *caps,
ucs_memory_type_t mem_type,
ucp_proto_multi_priv_t *mpriv)
{
ucp_context_h context = params->super.super.worker->context;
Expand Down Expand Up @@ -55,8 +56,8 @@ ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params,

/* Find first lane */
num_lanes = ucp_proto_common_find_lanes_with_min_frag(
&params->super, params->first.lane_type, params->first.tl_cap_flags,
1, 0, lanes);
&params->super, params->first.lane_type, mem_type,
params->first.tl_cap_flags, 1, 0, lanes);
if (num_lanes == 0) {
ucs_trace("no lanes for %s",
ucp_proto_id_field(params->super.super.proto_id, name));
Expand All @@ -65,7 +66,7 @@ ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params,

/* Find rest of the lanes */
num_lanes += ucp_proto_common_find_lanes_with_min_frag(
&params->super, params->middle.lane_type,
&params->super, params->middle.lane_type, mem_type,
params->middle.tl_cap_flags, UCP_PROTO_MAX_LANES - 1,
UCS_BIT(lanes[0]), lanes + 1);

Expand Down Expand Up @@ -264,7 +265,7 @@ void ucp_proto_multi_probe(const ucp_proto_multi_init_params_t *params)
ucp_proto_caps_t caps;
ucs_status_t status;

status = ucp_proto_multi_init(params, &caps, &mpriv);
status = ucp_proto_multi_init(params, &caps, UCS_MEMORY_TYPE_HOST, &mpriv);
if (status != UCS_OK) {
return;
}
Expand Down
1 change: 1 addition & 0 deletions src/ucp/proto/proto_multi.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ typedef ucs_status_t (*ucp_proto_multi_lane_send_func_t)(ucp_request_t *req,

ucs_status_t ucp_proto_multi_init(const ucp_proto_multi_init_params_t *params,
ucp_proto_caps_t *caps,
ucs_memory_type_t mem_type,
ucp_proto_multi_priv_t *mpriv);


Expand Down
2 changes: 0 additions & 2 deletions src/ucp/proto/proto_select.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ struct ucp_proto_select_param {
uint8_t sys_dev; /* Reply buffer system device */
} UCS_S_PACKED reply;

uint8_t mem_type_flags;

/* Align struct size to uint64_t */
uint8_t padding[2];

Expand Down
3 changes: 2 additions & 1 deletion src/ucp/proto/proto_single.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ ucs_status_t ucp_proto_single_init(const ucp_proto_single_init_params_t *params,
ucs_status_t status;

num_lanes = ucp_proto_common_find_lanes_with_min_frag(
&params->super, params->lane_type, params->tl_cap_flags, 1,
&params->super, params->lane_type,
params->super.super.select_param->mem_type, params->tl_cap_flags, 1,
params->super.exclude_map, &lane);
if (num_lanes == 0) {
ucs_trace("no lanes for %s",
Expand Down
15 changes: 7 additions & 8 deletions src/ucp/rndv/proto_rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ ucp_proto_rndv_ctrl_init_priv(const ucp_proto_rndv_ctrl_init_params_t *params,
/* Use only memory domains for which the unpacking of the remote key was
* successful */
if ((init_params->rkey_config_key != NULL) &&
!(params->super.flags & UCP_PROTO_COMMON_KEEP_MD_MAP)) {
!(params->super.flags & UCP_PROTO_COMMON_INIT_FLAG_KEEP_MD_MAP)) {
rpriv->md_map &= ~init_params->rkey_config_key->unreachable_md_map;
}

Expand Down Expand Up @@ -519,12 +519,10 @@ ucp_proto_rndv_find_ctrl_lane(const ucp_proto_init_params_t *params)
{
ucp_lane_index_t lane, num_lanes;

num_lanes = ucp_proto_common_find_lanes(params, UCT_EP_OP_LAST,
UCP_PROTO_COMMON_INIT_FLAG_HDR_ONLY,
UCP_PROTO_COMMON_OFFSET_INVALID, 1,
UCP_LANE_TYPE_AM,
UCT_IFACE_FLAG_AM_BCOPY, 1, 0,
&lane);
num_lanes = ucp_proto_common_find_lanes(
params, UCT_EP_OP_LAST, UCP_PROTO_COMMON_INIT_FLAG_HDR_ONLY,
UCP_PROTO_COMMON_OFFSET_INVALID, 1, UCP_LANE_TYPE_AM,
UCS_MEMORY_TYPE_HOST, UCT_IFACE_FLAG_AM_BCOPY, 1, 0, &lane);
if (num_lanes == 0) {
ucs_debug("no active message lane for %s",
ucp_proto_id_field(params->proto_id, name));
Expand Down Expand Up @@ -709,7 +707,8 @@ ucp_proto_rndv_bulk_init(const ucp_proto_multi_init_params_t *init_params,
ucp_proto_caps_t multi_caps;
ucs_status_t status;

status = ucp_proto_multi_init(init_params, &multi_caps, mpriv);
status = ucp_proto_multi_init(init_params, &multi_caps,
rpriv->frag_mem_type, mpriv);
if (status != UCS_OK) {
return status;
}
Expand Down
9 changes: 5 additions & 4 deletions src/ucp/rndv/rndv_get.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ ucp_proto_rndv_get_common_probe(const ucp_proto_init_params_t *init_params,
return;
}

rpriv.frag_mem_type = frag_mem_type;
status = ucp_proto_rndv_bulk_init(&params, UCP_PROTO_RNDV_GET_DESC,
UCP_PROTO_RNDV_ATS_NAME, &rpriv, &caps);
if (status != UCS_OK) {
return;
}

rpriv.frag_mem_type = frag_mem_type;
priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, mpriv);
ucp_proto_common_add_proto(&params.super, &caps, &rpriv, priv_size);
}
Expand Down Expand Up @@ -127,7 +127,7 @@ ucp_proto_rndv_get_zcopy_probe(const ucp_proto_init_params_t *init_params)
UCT_EP_OP_LAST,
UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY |
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING,
0, 0, UCS_MEMORY_TYPE_HOST);
0, 0, init_params->select_param->mem_type);
}

static void
Expand Down Expand Up @@ -322,12 +322,13 @@ ucp_proto_rndv_get_mtype_probe(const ucp_proto_init_params_t *init_params)
status = ucp_proto_rndv_mtype_init(init_params, frag_mem_type,
&mdesc_md_map, &frag_size);
if (status != UCS_OK) {
return;
continue;
}

ucp_proto_rndv_get_common_probe(init_params,
UCS_BIT(UCP_RNDV_MODE_GET_PIPELINE),
frag_size, UCT_EP_OP_PUT_ZCOPY, 0,
frag_size, UCT_EP_OP_PUT_ZCOPY,
UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY,
mdesc_md_map, 1, frag_mem_type);
}
}
Expand Down
13 changes: 7 additions & 6 deletions src/ucp/rndv/rndv_put.c
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params,
return;
}

rpriv.bulk.frag_mem_type = frag_mem_type;
status = ucp_proto_rndv_bulk_init(&params, UCP_PROTO_RNDV_PUT_DESC,
UCP_PROTO_RNDV_ATP_NAME, &rpriv.bulk,
&caps);
Expand Down Expand Up @@ -346,9 +347,8 @@ ucp_proto_rndv_put_common_probe(const ucp_proto_init_params_t *init_params,
if (send_atp) {
ucs_assert(rpriv.atp_map != 0);
}
rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map);
rpriv.stat_counter = stat_counter;
rpriv.bulk.frag_mem_type = frag_mem_type;
rpriv.atp_num_lanes = ucs_popcount(rpriv.atp_map);
rpriv.stat_counter = stat_counter;

priv_size = UCP_PROTO_MULTI_EXTENDED_PRIV_SIZE(&rpriv, bulk.mpriv);
ucp_proto_common_add_proto(&params.super, &caps, &rpriv, priv_size);
Expand Down Expand Up @@ -419,7 +419,7 @@ ucp_proto_rndv_put_zcopy_probe(const ucp_proto_init_params_t *init_params)
UCP_PROTO_COMMON_INIT_FLAG_SEND_ZCOPY |
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING,
0, ucp_proto_rndv_put_zcopy_completion, 0,
UCP_WORKER_STAT_RNDV_PUT_ZCOPY, UCS_MEMORY_TYPE_HOST);
UCP_WORKER_STAT_RNDV_PUT_ZCOPY, init_params->select_param->mem_type);
}

static void
Expand Down Expand Up @@ -592,8 +592,9 @@ ucp_proto_rndv_put_mtype_probe(const ucp_proto_init_params_t *init_params)

ucp_proto_rndv_put_common_probe(
init_params, UCS_BIT(UCP_RNDV_MODE_PUT_PIPELINE), frag_size,
UCT_EP_OP_GET_ZCOPY, 0, mdesc_md_map, comp_cb, 1,
UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY, frag_mem_type);
UCT_EP_OP_GET_ZCOPY, UCP_PROTO_COMMON_INIT_FLAG_MTYPE_ZCOPY,
mdesc_md_map, comp_cb, 1, UCP_WORKER_STAT_RNDV_PUT_MTYPE_ZCOPY,
frag_mem_type);
}

static void
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/rndv/rndv_rtr.c
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ ucp_proto_rndv_rtr_mtype_probe(const ucp_proto_init_params_t *init_params)
.super.memtype_op = UCT_EP_OP_LAST,
.super.flags = UCP_PROTO_COMMON_INIT_FLAG_RESPONSE |
UCP_PROTO_COMMON_INIT_FLAG_ERR_HANDLING |
UCP_PROTO_COMMON_KEEP_MD_MAP,
UCP_PROTO_COMMON_INIT_FLAG_KEEP_MD_MAP,
.super.exclude_map = 0,
.remote_op_id = UCP_OP_ID_RNDV_SEND,
.lane = ucp_proto_rndv_find_ctrl_lane(init_params),
Expand Down

0 comments on commit cef74eb

Please sign in to comment.