Skip to content

Commit

Permalink
coll/accelerator: allow to select functions to register
Browse files Browse the repository at this point in the history
This PR introduces the ability to register the component only to the
select functions specified by an MCA parameter string. The idea and the
code is based on the UCC component, and some of the bits might be moved
later to coll/base to make the mechanism more gnerally available to
other components as well.

Note, that the PR introduces the define statments for all MPI collective
operations, not just the ones support by the component at the moment,
since it is a bitmask based operation, and we anticipate to add support
for more collective operations into coll/accelerator shortly

Signed-off-by: Edgar Gabriel <[email protected]>
  • Loading branch information
edgargabriel committed Dec 18, 2024
1 parent 60807d7 commit 519234a
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 12 deletions.
32 changes: 32 additions & 0 deletions ompi/mca/coll/accelerator/coll_accelerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,36 @@

BEGIN_C_DECLS

#define COLL_ACC_ALLGATHER 0x00000001
#define COLL_ACC_ALLGATHERV 0x00000002
#define COLL_ACC_ALLREDUCE 0x00000004
#define COLL_ACC_ALLTOALL 0x00000008
#define COLL_ACC_ALLTOALLV 0x00000010
#define COLL_ACC_ALLTOALLW 0x00000020
#define COLL_ACC_BARRIER 0x00000040
#define COLL_ACC_BCAST 0x00000080
#define COLL_ACC_EXSCAN 0x00000100
#define COLL_ACC_GATHER 0x00000200
#define COLL_ACC_GATHERV 0x00000400
#define COLL_ACC_REDUCE 0x00000800
#define COLL_ACC_REDUCE_SCATTER 0x00001000
#define COLL_ACC_REDUCE_SCATTER_BLOCK 0x00002000
#define COLL_ACC_REDUCE_LOCAL 0x00004000
#define COLL_ACC_SCAN 0x00008000
#define COLL_ACC_SCATTER 0x00010000
#define COLL_ACC_SCATTERV 0x00020000
#define COLL_ACC_NEIGHBOR_ALLGATHER 0x00040000
#define COLL_ACC_NEIGHBOR_ALLGATHERV 0x00080000
#define COLL_ACC_NEIGHBOR_ALLTOALL 0x00100000
#define COLL_ACC_NEIGHBOR_ALLTTOALLV 0x00200000
#define COLL_ACC_NEIGHBOR_ALLTTOALLW 0x00400000
#define COLL_ACC_LASTCOLL 0x00800000

#define COLL_ACCELERATOR_CTS_STR "allreduce,reduce_scatter_block,reduce_local,reduce,scan,exscan"
#define COLL_ACCELERATOR_CTS COLL_ACC_ALLREDUCE | COLL_ACC_REDUCE | \
COLL_ACC_REDUCE_SCATTER_BLOCK | COLL_ACC_REDUCE_LOCAL | \
COLL_ACC_EXSCAN | COLL_ACC_SCAN

/* API functions */

int mca_coll_accelerator_init_query(bool enable_progress_threads,
Expand Down Expand Up @@ -134,6 +164,8 @@ typedef struct mca_coll_accelerator_component_t {

int priority; /* Priority of this component */
int disable_accelerator_coll; /* Force disable of the accelerator collective component */
char *cts; /* String of collective operations which the component shall register itself */
uint64_t cts_requested;
} mca_coll_accelerator_component_t;

/* Globally exported variables */
Expand Down
78 changes: 77 additions & 1 deletion ompi/mca/coll/accelerator/coll_accelerator_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Copyright (c) 2015 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
* Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand All @@ -21,6 +22,7 @@
#include "mpi.h"
#include "ompi/constants.h"
#include "coll_accelerator.h"
#include "opal/util/argv.h"

/*
* Public string showing the coll ompi_accelerator component version number
Expand All @@ -31,6 +33,7 @@ const char *mca_coll_accelerator_component_version_string =
/*
* Local function
*/
static int accelerator_open(void);
static int accelerator_register(void);

/*
Expand All @@ -52,6 +55,7 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = {
OMPI_RELEASE_VERSION),

/* Component open and close functions */
.mca_open_component = accelerator_open,
.mca_register_component_params = accelerator_register,
},
.collm_data = {
Expand All @@ -75,7 +79,8 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = {
static int accelerator_register(void)
{
(void) mca_base_component_var_register(&mca_coll_accelerator_component.super.collm_version,
"priority", "Priority of the accelerator coll component; only relevant if barrier_before or barrier_after is > 0",
"priority", "Priority of the accelerator coll component; only relevant if barrier_before "
"or barrier_after is > 0",
MCA_BASE_VAR_TYPE_INT, NULL, 0, 0,
OPAL_INFO_LVL_6,
MCA_BASE_VAR_SCOPE_READONLY,
Expand All @@ -88,5 +93,76 @@ static int accelerator_register(void)
MCA_BASE_VAR_SCOPE_READONLY,
&mca_coll_accelerator_component.disable_accelerator_coll);

mca_coll_accelerator_component.cts = COLL_ACCELERATOR_CTS_STR;
(void)mca_base_component_var_register(&mca_coll_accelerator_component.super.collm_version,
"cts", "Comma separated list of collectives to be enabled",
MCA_BASE_VAR_TYPE_STRING, NULL, 0, MCA_BASE_VAR_FLAG_SETTABLE,
OPAL_INFO_LVL_6, MCA_BASE_VAR_SCOPE_ALL, &mca_coll_accelerator_component.cts);

return OMPI_SUCCESS;
}


/* The string parsing is based on the code available in the coll/ucc component */
static uint64_t mca_coll_accelerator_str_to_type(const char *str)
{
if (0 == strcasecmp(str, "allreduce")) {
return COLL_ACC_ALLREDUCE;
} else if (0 == strcasecmp(str, "reduce_scatter_block")) {
return COLL_ACC_REDUCE_SCATTER_BLOCK;
} else if (0 == strcasecmp(str, "reduce_local")) {
return COLL_ACC_REDUCE_LOCAL;
} else if (0 == strcasecmp(str, "reduce")) {
return COLL_ACC_REDUCE;
} else if (0 == strcasecmp(str, "exscan")) {
return COLL_ACC_EXSCAN;
} else if (0 == strcasecmp(str, "scan")) {
return COLL_ACC_SCAN;
}
opal_output(0, "incorrect value for cts: %s, allowed: %s",
str, COLL_ACCELERATOR_CTS_STR);
return COLL_ACC_LASTCOLL;
}

static void accelerator_init_default_cts(void)
{
mca_coll_accelerator_component_t *cm = &mca_coll_accelerator_component;
bool disable;
char** cts;
int n_cts, i;
char* str;
uint64_t *ct, c;

disable = (cm->cts[0] == '^') ? true : false;
cts = opal_argv_split(disable ? (cm->cts + 1) : cm->cts, ',');
n_cts = opal_argv_count(cts);
cm->cts_requested = disable ? COLL_ACCELERATOR_CTS : 0;
for (i = 0; i < n_cts; i++) {
if (('i' == cts[i][0]) || ('I' == cts[i][0])) {
/* non blocking collective setting */
opal_output(0, "coll/accelerator component does not support non-blocking collectives at this time."
" Ignoring collective: %s\n", cts[i]);
continue;
} else {
str = cts[i];
ct = &cm->cts_requested;
}
c = mca_coll_accelerator_str_to_type(str);
if (COLL_ACC_LASTCOLL == c) {
*ct = COLL_ACCELERATOR_CTS;
break;
}
if (disable) {
(*ct) &= ~c;
} else {
(*ct) |= c;
}
}
opal_argv_free(cts);
}

static int accelerator_open(void)
{
accelerator_init_default_cts();
return OMPI_SUCCESS;
}
25 changes: 14 additions & 11 deletions ompi/mca/coll/accelerator/coll_accelerator_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
* Copyright (c) 2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
* $COPYRIGHT$
*
Expand Down Expand Up @@ -106,18 +106,21 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
}


#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api, __API) \
do \
{ \
if ((__comm)->c_coll->coll_##__api) \
{ \
MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \
if (mca_coll_accelerator_component.cts_requested & COLL_ACC_##__API) \
{ \
MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \
} \
} \
else \
{ \
opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \
"cuda", #__api, ompi_process_info.nodename, \
"accelerator", #__api, ompi_process_info.nodename, \
mca_coll_accelerator_component.priority); \
} \
} while (0)
Expand All @@ -141,14 +144,14 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
{
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;

ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce, ALLREDUCE);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce, REDUCE);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local, REDUCE_LOCAL);
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block, REDUCE_SCATTER_BLOCK);
if (!OMPI_COMM_IS_INTER(comm)) {
/* MPI does not define scan/exscan on intercommunicators */
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan, EXSCAN);
ACCELERATOR_INSTALL_COLL_API(comm, s, scan, SCAN);
}

return OMPI_SUCCESS;
Expand Down

0 comments on commit 519234a

Please sign in to comment.