Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: extend mem_map structure #711

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/components/tl/ucp/alltoall/alltoall_onesided.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,17 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)

void ucc_tl_ucp_alltoall_onesided_progress(ucc_coll_task_t *ctask)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
long * pSync = TASK_ARGS(task).global_work_buffer;
ucc_tl_ucp_task_t *task = ucc_derived_of(ctask, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
long *pSync = TASK_ARGS(task).global_work_buffer;
ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);

if ((*pSync < gsize) ||
(task->onesided.put_completed < task->onesided.put_posted)) {
ucp_worker_progress(UCC_TL_UCP_TEAM_CTX(team)->worker.ucp_worker);
for (int i = 0; i < ctx->n_rinfo_segs; i++) {
ucp_worker_progress(ctx->remote_info[i].ucp_worker);
}
return;
}

Expand Down
3 changes: 3 additions & 0 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ UCC_CLASS_DECLARE(ucc_tl_ucp_lib_t, const ucc_base_lib_params_t *,
const ucc_base_config_t *);

typedef struct ucc_tl_ucp_remote_info {
ucp_context_h ucp_context;
ucp_worker_h ucp_worker;
ucp_ep_h *eps;
void * va_base;
size_t len;
void * mem_h;
Expand Down
112 changes: 90 additions & 22 deletions src/components/tl/ucp/tl_ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
goto go; \
}

#define ATTR_CHECK(attr, msg, go, ctx) \
if (attr == NULL) { \
tl_error(ctx->super.super.lib, msg); \
ucc_status = UCC_ERR_INVALID_PARAM; \
goto go; \
}

unsigned ucc_tl_ucp_service_worker_progress(void *progress_arg)
{
ucc_tl_ucp_context_t *ctx = (ucc_tl_ucp_context_t *)progress_arg;
Expand Down Expand Up @@ -441,10 +448,10 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
{
uint32_t size = oob.n_oob_eps;
uint64_t nsegs = map.n_segments;
ucc_status_t ucc_status = UCC_OK;
ucp_mem_map_params_t mmap_params;
ucp_mem_h mh;
ucs_status_t status;
ucc_status_t ucc_status;
int i;

if (size < 2) {
Expand Down Expand Up @@ -476,30 +483,91 @@ ucc_status_t ucc_tl_ucp_ctx_remote_populate(ucc_tl_ucp_context_t * ctx,
}

for (i = 0; i < nsegs; i++) {
mmap_params.field_mask =
UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mmap_params.address = map.segments[i].address;
mmap_params.length = map.segments[i].len;
mh = NULL;

ctx->remote_info[i].va_base = map.segments[i].address;
ctx->remote_info[i].len = map.segments[i].len;
ctx->remote_info[i].ucp_context = NULL;
ctx->remote_info[i].ucp_worker = NULL;

for (int j = 0; j < map.segments[i].attribute_set.n_attributes; j++) {
ucc_attribute_t *tmp_attribute =
&map.segments[i].attribute_set.attribute[j];
if (tmp_attribute->type == UCC_ATTR_CTX) {
ctx->remote_info[i].ucp_context =
(ucp_context_h)tmp_attribute->attr;
} else if (tmp_attribute->type == UCC_ATTR_QP) {
ctx->remote_info[i].ucp_worker =
(ucp_worker_h)tmp_attribute->attr;
} else if (tmp_attribute->type == UCC_ATTR_MEMH) {
mmap_params.field_mask |=
UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
mmap_params.exported_memh_buffer = tmp_attribute->attr;

ATTR_CHECK(ctx->remote_info[i].ucp_context,
"Shared attribute context either missing or ordered "
"after attribute",
fail_mem_map, ctx);

status = ucp_mem_map(ctx->remote_info[i].ucp_context,
&mmap_params, &mh);
if (status == UCS_ERR_UNREACHABLE) {
tl_error(ctx->super.super.lib, "exported memh unsupported");
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map; //FIXME
} else if (status < UCS_OK) {
tl_error(ctx->super.super.lib, "error on ucp_mem_map");
goto fail_mem_map;
}

status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
if (UCS_OK != status) {
tl_error(ctx->super.super.lib,
"ucp_mem_map failed with error code: %d", status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
status = ucp_rkey_pack(ctx->remote_info[i].ucp_context, mh,
&ctx->remote_info[i].packed_key,
&ctx->remote_info[i].packed_key_len);
if (UCS_OK != status) {
tl_error(ctx->super.super.lib,
"failed to pack UCP key with error code: %d",
status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
}
} else if (tmp_attribute->type == UCC_ATTR_MKEY) {
ctx->remote_info[i].packed_key = tmp_attribute->attr;
ctx->remote_info[i].packed_key_len = tmp_attribute->attr_len;
}
}
ctx->remote_info[i].mem_h = (void *)mh;
status = ucp_rkey_pack(ctx->worker.ucp_context, mh,
&ctx->remote_info[i].packed_key,
&ctx->remote_info[i].packed_key_len);
if (UCS_OK != status) {
tl_error(ctx->super.super.lib,
"failed to pack UCP key with error code: %d", status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;

if (!ctx->remote_info[i].ucp_context) {
ctx->remote_info[i].ucp_context = ctx->worker.ucp_context;
}
if (!ctx->remote_info[i].ucp_worker) {
ctx->remote_info[i].ucp_worker = ctx->worker.ucp_worker;
}

if (ctx->remote_info[i].packed_key == NULL) {
mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
UCP_MEM_MAP_PARAM_FIELD_LENGTH;
mmap_params.address = map.segments[i].address;
mmap_params.length = map.segments[i].len;
/* we do everything we used to do */
status =
ucp_mem_map(ctx->remote_info[i].ucp_context, &mmap_params, &mh);
if (UCS_OK != status) {
tl_error(ctx->super.super.lib,
"ucp_mem_map failed with error code: %d", status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
}
status = ucp_rkey_pack(ctx->worker.ucp_context, mh,
&ctx->remote_info[i].packed_key,
&ctx->remote_info[i].packed_key_len);
if (UCS_OK != status) {
tl_error(ctx->super.super.lib,
"failed to pack UCP key with error code: %d", status);
ucc_status = ucs_status_to_ucc_status(status);
goto fail_mem_map;
}
}
ctx->remote_info[i].va_base = map.segments[i].address;
ctx->remote_info[i].len = map.segments[i].len;
ctx->remote_info[i].mem_h = mh;
}
ctx->n_rinfo_segs = nsegs;

Expand Down
44 changes: 37 additions & 7 deletions src/components/tl/ucp/tl_ucp_sendrecv.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep,
void *keys;
void *offset;
ptrdiff_t base_offset;
void *addrs;
uint64_t *addr_lens;

*segment = -1;
core_rank = ucc_ep_map_eval(UCC_TL_TEAM_MAP(team), peer);
Expand All @@ -249,7 +251,9 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep,
base_offset = (ptrdiff_t)(TL_UCP_EP_ADDR_ONESIDED_INFO(offset, ctx));
rvas = (uint64_t *)base_offset;
key_sizes = PTR_OFFSET(base_offset, (section_offset * 2));
keys = PTR_OFFSET(base_offset, (section_offset * 3));
addr_lens = PTR_OFFSET(base_offset, (section_offset * 3));
keys = PTR_OFFSET(base_offset, (section_offset * 4));
addrs = PTR_OFFSET(base_offset, (section_offset * 5));//FIXME

for (int i = 0; i < ctx->n_rinfo_segs; i++) {
if ((uint64_t)va >= (uint64_t)team->va_base[i] &&
Expand All @@ -264,6 +268,32 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep,
"attempt to perform one-sided operation on non-registered memory");
return UCC_ERR_NOT_FOUND;
}
if (ctx->worker.ucp_context == ctx->remote_info[*segment].ucp_context) {
ucc_status_t ucc_status = ucc_tl_ucp_get_ep(team, peer, ep);
if (ucc_unlikely(UCC_OK != ucc_status)) {
return ucc_status;
}
} else {
if (!ctx->remote_info[*segment].eps[peer]) {
// make it
ucs_status_t ucs_status;
ucp_ep_params_t ep_params = {
.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS,
.address = (ucp_address_t *)PTR_OFFSET(
addrs, peer * addr_lens[peer]), //FIXME
};
ucs_status = ucp_ep_create(ctx->remote_info[*segment].ucp_worker,
&ep_params, ep);
if (UCS_OK != ucs_status) {
tl_error(ctx->super.super.lib, "ucp returned connect error: %s",
ucs_status_string(ucs_status));
return ucs_status_to_ucc_status(ucs_status);
}
ctx->remote_info[*segment].eps[peer] = *ep;
} else {
*ep = ctx->remote_info[*segment].eps[peer];
}
}
if (ucc_unlikely(NULL == UCC_TL_UCP_REMOTE_RKEY(ctx, peer, *segment))) {
ucs_status_t ucs_status =
ucp_ep_rkey_unpack(*ep, PTR_OFFSET(keys, key_offset),
Expand Down Expand Up @@ -328,12 +358,12 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target,
ucs_status_ptr_t ucp_status;
ucc_status_t status;
ucp_ep_h ep;

/*
status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}

*/
status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank,
&rva, &rkey, &segment);
if (ucc_unlikely(UCC_OK != status)) {
Expand Down Expand Up @@ -371,12 +401,12 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target,
ucs_status_ptr_t ucp_status;
ucc_status_t status;
ucp_ep_h ep;

/*
status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}

*/
status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank,
&rva, &rkey, &segment);
if (ucc_unlikely(UCC_OK != status)) {
Expand Down Expand Up @@ -414,12 +444,12 @@ static inline ucc_status_t ucc_tl_ucp_atomic_inc(void * target,
ucs_status_ptr_t ucp_status;
ucc_status_t status;
ucp_ep_h ep;

/*
status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep);
if (ucc_unlikely(UCC_OK != status)) {
return status;
}

*/
status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank,
&rva, &rkey, &segment);
if (ucc_unlikely(UCC_OK != status)) {
Expand Down
24 changes: 22 additions & 2 deletions src/ucc/api/ucc.h
Original file line number Diff line number Diff line change
Expand Up @@ -868,13 +868,33 @@ typedef struct ucc_oob_coll {
typedef ucc_oob_coll_t ucc_context_oob_coll_t;
typedef ucc_oob_coll_t ucc_team_oob_coll_t;

typedef enum {
UCC_ATTR_MKEY,
UCC_ATTR_MEMH,
UCC_ATTR_CTX,
UCC_ATTR_QP,
} ucc_attribute_type;

typedef struct ucc_attribute {
ucc_attribute_type type;
void *attr;
size_t attr_len;
} ucc_attribute_t;

typedef struct ucc_attribute_set {
ucc_attribute_t *attribute;
uint64_t n_attributes;
} ucc_attribute_set_t;

/**
*
* @ingroup UCC_CONTEXT_DT
*/
typedef struct ucc_mem_map {
void * address; /*!< the address of a buffer to be attached to a UCC context */
size_t len; /*!< the length of the buffer */
void *address; /*!< The address of a buffer to be attached to
a UCC context */
size_t len; /*!< The length of the buffer */
ucc_attribute_set_t attribute_set;
} ucc_mem_map_t;

/**
Expand Down
Loading