Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/components/cl/hier/cl_hier.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Meta Platforms, Inc. and affiliates. 2022.
*
* See file LICENSE for terms.
Expand All @@ -13,6 +13,7 @@
#include "coll_score/ucc_coll_score.h"
#include "utils/ucc_mpool.h"
#include "schedule/ucc_schedule_pipelined.h"
#include "core/ucc_service_coll.h"

#ifdef HAVE_PROFILING_CL_HIER
#include "utils/profile/ucc_profile_on.h"
Expand Down Expand Up @@ -98,13 +99,20 @@ typedef struct ucc_hier_sbgp {
int n_tls;
} ucc_hier_sbgp_t;

typedef struct ucc_cl_hier_team_create_req {
ucc_team_multiple_req_t *create_req;
ucc_service_coll_req_t *global_status_req;
ucc_status_t local_status;
ucc_status_t global_status;
} ucc_cl_hier_team_create_req_t;

typedef struct ucc_cl_hier_team {
ucc_cl_team_t super;
ucc_team_multiple_req_t *team_create_req;
unsigned n_tl_teams;
ucc_coll_score_t *score;
ucc_hier_sbgp_t sbgps[UCC_HIER_SBGP_LAST];
ucc_hier_sbgp_type_t top_sbgp;
ucc_cl_team_t super;
ucc_cl_hier_team_create_req_t *team_req;
unsigned n_tl_teams;
ucc_coll_score_t *score;
ucc_hier_sbgp_t sbgps[UCC_HIER_SBGP_LAST];
ucc_hier_sbgp_type_t top_sbgp;
} ucc_cl_hier_team_t;
UCC_CLASS_DECLARE(ucc_cl_hier_team_t, ucc_base_context_t *,
const ucc_base_team_params_t *);
Expand Down
96 changes: 70 additions & 26 deletions src/components/cl/hier/cl_hier_team.c
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
/**
* Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "cl_hier.h"
#include "utils/ucc_malloc.h"
#include "core/ucc_team.h"
#include "core/ucc_service_coll.h"
#include "cl_hier_coll.h"

#define SBGP_SET(_team, _sbgp, _enable) \
Expand Down Expand Up @@ -41,6 +40,8 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
ucc_config_names_array_t *tls;
ucc_subset_t subset;
struct ucc_team_team_desc *d;
ucc_team_multiple_req_t *team_create_req;

if (!params->team->topo) {
cl_info(cl_context->lib,
"can't create hier team without topology data");
Expand All @@ -53,6 +54,12 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
}

UCC_CLASS_CALL_SUPER_INIT(ucc_cl_team_t, &ctx->super, params);
self->team_req = (ucc_cl_hier_team_create_req_t*)
ucc_malloc(sizeof(ucc_cl_hier_team_create_req_t));
if (!self->team_req) {
return UCC_ERR_NO_MEMORY;
}

memset(self->sbgps, 0, sizeof(self->sbgps));
ucc_cl_hier_enable_sbgps(self);
n_sbgp_teams = 0;
Expand Down Expand Up @@ -88,7 +95,7 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
}
}

status = ucc_team_multiple_req_alloc(&self->team_create_req, n_sbgp_teams);
status = ucc_team_multiple_req_alloc(&team_create_req, n_sbgp_teams);
if (UCC_OK != status) {
cl_error(cl_context->lib, "failed to allocate team req multiple");
goto err;
Expand All @@ -102,7 +109,7 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
hs = &self->sbgps[i];
if (hs->state == UCC_HIER_SBGP_ENABLED) {
for (t = 0; t < hs->n_tls; t++) {
d = &self->team_create_req->descs[j];
d = &team_create_req->descs[j];
d->param.params.mask = UCC_TEAM_PARAM_FIELD_EP_RANGE |
UCC_TEAM_PARAM_FIELD_EP |
UCC_TEAM_PARAM_FIELD_TEAM_SIZE |
Expand Down Expand Up @@ -134,15 +141,18 @@ UCC_CLASS_INIT_FUNC(ucc_cl_hier_team_t, ucc_base_context_t *cl_context,
}
}

status = ucc_tl_team_create_multiple(self->team_create_req);

status = ucc_tl_team_create_multiple(team_create_req);
if (status < 0) {
cl_error(cl_context->lib, "failed to post tl team create (%d)", status);
goto err;
}
self->team_req->create_req = team_create_req;
self->team_req->global_status_req = NULL;
cl_info(cl_context->lib, "posted cl team: %p", self);
return UCC_OK;
err:
ucc_team_multiple_req_free(self->team_create_req);
ucc_team_multiple_req_free(team_create_req);
return status;
}

Expand All @@ -162,15 +172,15 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team)
int i, j;
ucc_hier_sbgp_t *hs;

if (NULL == team->team_create_req) {
status = ucc_team_multiple_req_alloc(&team->team_create_req,
if (NULL == team->team_req->create_req) {
status = ucc_team_multiple_req_alloc(&team->team_req->create_req,
team->n_tl_teams);
if (UCC_OK != status) {
cl_error(ctx->super.super.lib,
"failed to allocate team req multiple");
return status;
}
team->team_create_req->n_teams = 0;
team->team_req->create_req->n_teams = 0;
for (i = 0; i < UCC_HIER_SBGP_LAST; i++) {
hs = &team->sbgps[i];
if (hs->state == UCC_HIER_SBGP_ENABLED) {
Expand All @@ -180,26 +190,27 @@ ucc_status_t ucc_cl_hier_team_destroy(ucc_base_team_t *cl_team)
for (j = 0; j < hs->n_tls; j++) {
if (hs->tl_teams[j]) {
ucc_tl_context_put(hs->tl_ctxs[j]);
team->team_create_req
->descs[team->team_create_req->n_teams++]
team->team_req->create_req
->descs[team->team_req->create_req->n_teams++]
.team = hs->tl_teams[j];
}
}
}
}
}
status = ucc_tl_team_destroy_multiple(team->team_create_req);
status = ucc_tl_team_destroy_multiple(team->team_req->create_req);
if (UCC_INPROGRESS == status) {
return status;
}
for (i = 0; i < team->team_create_req->n_teams; i++) {
if (team->team_create_req->descs[i].status != UCC_OK) {
for (i = 0; i < team->team_req->create_req->n_teams; i++) {
if (team->team_req->create_req->descs[i].status != UCC_OK) {
cl_error(ctx->super.super.lib, "tl team destroy failed (%d)",
status);
status = team->team_create_req->descs[i].status;
status = team->team_req->create_req->descs[i].status;
}
}
ucc_team_multiple_req_free(team->team_create_req);
ucc_team_multiple_req_free(team->team_req->create_req);
ucc_free(team->team_req);
UCC_CLASS_DELETE_FUNC_NAME(ucc_cl_hier_team_t)(cl_team);
return status;
}
Expand All @@ -208,24 +219,32 @@ ucc_status_t ucc_cl_hier_team_create_test(ucc_base_team_t *cl_team)
{
ucc_cl_hier_team_t *team = ucc_derived_of(cl_team, ucc_cl_hier_team_t);
ucc_cl_hier_context_t *ctx = UCC_CL_HIER_TEAM_CTX(team);
ucc_status_t status;
int i;
ucc_coll_score_t *score, *score_merge;
ucc_status_t status;
ucc_coll_score_t *score, *score_merge;
struct ucc_team_team_desc *d;
ucc_hier_sbgp_t *hs;
ucc_subset_t subset;
int i;

status = ucc_tl_team_create_multiple(team->team_create_req);
if (status != UCC_OK) {
if (team->team_req->global_status_req) {
/* all team create stages are done, checking global status */
goto check_global_status;
}

status = ucc_tl_team_create_multiple(team->team_req->create_req);
if (status == UCC_INPROGRESS) {
return status;
} else if (status != UCC_OK) {
goto check_global_status;
}

team->n_tl_teams = 0;

/* TL teams are created: get scores and merge them to produce
* score map for each sbgp
*/
for (i = 0; i < team->team_create_req->n_teams; i++) {
d = &team->team_create_req->descs[i];
for (i = 0; i < team->team_req->create_req->n_teams; i++) {
d = &team->team_req->create_req->descs[i];
ucc_hier_sbgp_type_t st = (ucc_hier_sbgp_type_t)d->args[0];
int tl = (int)d->args[1];

Expand Down Expand Up @@ -288,8 +307,8 @@ ucc_status_t ucc_cl_hier_team_create_test(ucc_base_team_t *cl_team)
}
}
}
ucc_team_multiple_req_free(team->team_create_req);
team->team_create_req = NULL;
ucc_team_multiple_req_free(team->team_req->create_req);
team->team_req->create_req = NULL;

if (SBGP_EXISTS(team, NODE_LEADERS)) {
team->top_sbgp = UCC_HIER_SBGP_NODE_LEADERS;
Expand All @@ -298,7 +317,32 @@ ucc_status_t ucc_cl_hier_team_create_test(ucc_base_team_t *cl_team)
team->top_sbgp = UCC_HIER_SBGP_NODE;
}

return status;
check_global_status:
if (!team->team_req->global_status_req) {
subset.map.type = UCC_EP_MAP_FULL;
subset.map.ep_num = team->super.super.params.size;
subset.myrank = team->super.super.params.rank;
team->team_req->local_status = status;
status = ucc_service_allreduce(team->super.super.params.team,
&team->team_req->local_status,
&team->team_req->global_status,
UCC_DT_INT32, 1, UCC_OP_MIN, subset,
&team->team_req->global_status_req);
if (status != UCC_OK) {
cl_error(ctx->super.super.lib, "failed to start service allreduce");
return status;
}
}

status = ucc_service_coll_test(team->team_req->global_status_req);
if (status == UCC_INPROGRESS) {
return status;
}
ucc_service_coll_finalize(team->team_req->global_status_req);
if (status != UCC_OK) {
return status;
}
return team->team_req->global_status;
}

ucc_status_t ucc_cl_hier_team_get_scores(ucc_base_team_t *cl_team,
Expand Down
11 changes: 7 additions & 4 deletions src/core/ucc_team.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) Meta Platforms, Inc. and affiliates. 2022.
*
* See file LICENSE for terms.
Expand Down Expand Up @@ -282,7 +282,8 @@ static ucc_status_t ucc_team_create_cls(ucc_context_t *context,
status = cl_iface->team.create_test(b_team);
if (status < 0) {
team->n_cl_teams--;
ucc_info("failed to create CL %s team", cl_iface->super.name);
ucc_info("failed to create CL %s team, team_id %d",
cl_iface->super.name, team->id);
cl_iface->team.destroy(b_team);
} else if (status == UCC_INPROGRESS) {
return status;
Expand All @@ -294,12 +295,14 @@ static ucc_status_t ucc_team_create_cls(ucc_context_t *context,
status = cl_iface->team.create_post(&context->cl_ctx[i]->super,
&team->bp, &b_team);
if (status != UCC_OK) {
ucc_info("failed to create CL %s team", cl_iface->super.name);
ucc_info("failed to create CL %s team, team_id %d",
cl_iface->super.name, team->id);
continue;
}
status = cl_iface->team.create_test(b_team);
if (status < 0) {
ucc_info("failed to create CL %s team", cl_iface->super.name);
ucc_info("failed to create CL %s team, team_id %d",
cl_iface->super.name, team->id);
cl_iface->team.destroy(b_team);
continue;
}
Expand Down