From 9dc152c91fe3ed24b78f9b50f7f4dadd952bb465 Mon Sep 17 00:00:00 2001 From: "Vinciguerra, Armando" Date: Wed, 16 Oct 2024 15:34:07 -0400 Subject: [PATCH] Added code for special case of in place scan --- mpp/shmemx.h4 | 15 ++- src/collectives.c | 214 +++++++++++++++++++++++----------------- src/collectives_c.c4 | 8 +- src/shmem_collectives.h | 48 ++++----- 4 files changed, 165 insertions(+), 120 deletions(-) diff --git a/mpp/shmemx.h4 b/mpp/shmemx.h4 index 92b38519..6729ddf5 100644 --- a/mpp/shmemx.h4 +++ b/mpp/shmemx.h4 @@ -38,7 +38,7 @@ include(shmemx_c_func.h4)dnl /* SHMEMX constant(s) are included in MAX_HINTS value in shmem-def.h */ #define SHMEMX_MALLOC_NO_BARRIER (1l<<2) -/* C++ overloaded declarations */ +/* C overloaded declarations */ #ifdef __cplusplus } /* extern "C" */ @@ -119,6 +119,19 @@ SHMEM_BIND_C11_RMA(`SHMEM_C11_GEN_IBGET', `, \') \ uint64_t*: shmemx_signal_add \ )(__VA_ARGS__) +define(`SHMEM_C11_GEN_EXSCAN', ` $2*: shmemx_$1_sum_exscan')dnl +#define shmemx_sum_exscan(...) \ + _Generic(SHMEM_C11_TYPE_EVAL_PTR(SHMEM_C11_ARG1(__VA_ARGS__)), \ +SHMEM_BIND_C11_RMA(`SHMEM_C11_GEN_EXSCAN', `, \') \ + )(__VA_ARGS__) + +define(`SHMEM_C11_GEN_INSCAN', ` $2*: shmemx_$1_sum_inscan')dnl +#define shmemx_sum_inscan(...) \ + _Generic(SHMEM_C11_TYPE_EVAL_PTR(SHMEM_C11_ARG1(__VA_ARGS__)), \ +SHMEM_BIND_C11_RMA(`SHMEM_C11_GEN_INSCAN', `, \') \ + )(__VA_ARGS__) + + #endif /* C11 */ #endif /* SHMEMX_H */ diff --git a/src/collectives.c b/src/collectives.c index 4a6a1177..5461aa54 100644 --- a/src/collectives.c +++ b/src/collectives.c @@ -998,99 +998,115 @@ shmem_internal_scan_linear(void *target, const void *source, size_t count, size_ { /* scantype is 0 for inscan and 1 for exscan */ - - long zero = 0, one = 1; + long zero = 0, one = 1; long completion = 0; + int free_source = 0; if (count == 0) return; - - int pe, i; + + int pe, i; + + /* In-place scan: copy source data to a temporary buffer so we can use + * the symmetric buffer to accumulate scan data. */ + if (target == source) { + void *tmp = malloc(count * type_size); + + if (NULL == tmp) + RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size); + + shmem_internal_copy_self(tmp, target, count * type_size); + free_source = 1; + source = tmp; + + shmem_internal_sync(PE_start, PE_stride, PE_size, pSync + 2); + } if (PE_start == shmem_internal_my_pe) { - - /* initialize target buffer. The put + + /* initialize target buffer. The put will flush any atomic cache value that may currently exist. */ - if(scantype) - { - /* Exclude own value for EXSCAN */ - //Create an array of size (count * type_size) of zeroes - uint8_t *zeroes = (uint8_t *) calloc(count, type_size); - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, + if(scantype) + { + /* Exclude own value for EXSCAN */ + //Create an array of size (count * type_size) of zeroes + uint8_t *zeroes = (uint8_t *) calloc(count, type_size); + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, shmem_internal_my_pe, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_quiet(SHMEM_CTX_DEFAULT); - free(zeroes); - } - - + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_quiet(SHMEM_CTX_DEFAULT); + free(zeroes); + } + + /* Send contribution to all */ for (pe = PE_start + PE_stride*scantype, i = scantype ; i < PE_size ; i++, pe += PE_stride) { - - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, - pe, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); - + + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + } - - for (pe = PE_start + PE_stride, i = 1 ; + + for (pe = PE_start + PE_stride, i = 1 ; i < PE_size ; i++, pe += PE_stride) { - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); - } - - /* Wait for others to acknowledge initialization */ + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); + } + + /* Wait for others to acknowledge initialization */ SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1); - - /* reset pSync */ + + /* reset pSync */ shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); - - - /* Let everyone know sending can start */ - for (pe = PE_start + PE_stride, i = 1 ; + + + /* Let everyone know sending can start */ + for (pe = PE_start + PE_stride, i = 1 ; i < PE_size ; i++, pe += PE_stride) { - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); - } - - + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); + } } else { - - /* wait for clear to intialization */ + + /* wait for clear to intialization */ SHMEM_WAIT(pSync, 0); /* reset pSync */ shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); - /* Send contribution to all pes larger than itself */ - for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; + /* Send contribution to all pes larger than itself */ + for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; i < PE_size; i++, pe += PE_stride) { - shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, + shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, pe, op, datatype, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); - + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + } - - shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), + + shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG); - - SHMEM_WAIT(pSync, 0); - - /* reset pSync */ + + SHMEM_WAIT(pSync, 0); + + /* reset pSync */ shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); } + + if (free_source) + free((void *)source); } @@ -1103,48 +1119,61 @@ shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t { /* scantype is 0 for inscan and 1 for exscan */ - - long zero = 0, one = 1; + long zero = 0, one = 1; long completion = 0; + int free_source = 0; + + /* In-place scan: copy source data to a temporary buffer so we can use + * the symmetric buffer to accumulate scan data. */ + if (target == source) { + void *tmp = malloc(count * type_size); + + if (NULL == tmp) + RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size); + + shmem_internal_copy_self(tmp, target, count * type_size); + free_source = 1; + source = tmp; + + shmem_internal_sync(PE_start, PE_stride, PE_size, pSync + 2); + } if (count == 0) return; - - int pe, i; + + int pe, i; if (PE_start == shmem_internal_my_pe) { - - /* initialize target buffer. The put + /* initialize target buffer. The put will flush any atomic cache value that may currently exist. */ - if(scantype) - { - /* Exclude own value for EXSCAN */ - //Create an array of size (count * type_size) of zeroes - uint8_t *zeroes = (uint8_t *) calloc(count, type_size); - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, + if(scantype) + { + /* Exclude own value for EXSCAN */ + //Create an array of size (count * type_size) of zeroes + uint8_t *zeroes = (uint8_t *) calloc(count, type_size); + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, shmem_internal_my_pe, &completion); - free(zeroes); - } - - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_quiet(SHMEM_CTX_DEFAULT); - + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_quiet(SHMEM_CTX_DEFAULT); + free(zeroes); + } + /* Send contribution to all */ for (pe = PE_start + PE_stride*scantype, i = scantype ; i < PE_size ; i++, pe += PE_stride) { - - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, - pe, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); + + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); } - - /* Let next pe know that it's safe to send to us */ - if(shmem_internal_my_pe + PE_stride < PE_size) - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); + + /* Let next pe know that it's safe to send to us */ + if(shmem_internal_my_pe + PE_stride < PE_size) + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); /* Wait for others to acknowledge sending data */ SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1); @@ -1161,24 +1190,27 @@ shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); - /* Send contribution to all pes larger than itself */ - for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; + /* Send contribution to all pes larger than itself */ + for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; i < PE_size; i++, pe += PE_stride) { - shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, + shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, pe, op, datatype, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); } - - /* Let next pe know that it's safe to send to us */ - if(shmem_internal_my_pe + PE_stride < PE_size) - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); - + + /* Let next pe know that it's safe to send to us */ + if(shmem_internal_my_pe + PE_stride < PE_size) + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); + shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG); } + + if (free_source) + free((void *)source); } /***************************************** diff --git a/src/collectives_c.c4 b/src/collectives_c.c4 index eb6d20d2..2287fff7 100644 --- a/src/collectives_c.c4 +++ b/src/collectives_c.c4 @@ -83,14 +83,14 @@ SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_PROF_DEF_REDUCE', `min', `SHM_INTERNAL_MIN') SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_PROF_DEF_REDUCE', `max', `SHM_INTERNAL_MAX') define(`SHMEM_PROF_DEF_EXSCAN', -`#pragma weak shmem_$1_$4_exscan = pshmem_$1_$4_exscan -#define shmem_$1_$4_exscan pshmem_$1_$4_exscan')dnl +`#pragma weak shmemx_$1_$4_exscan = pshmemx_$1_$4_exscan +#define shmemx_$1_$4_exscan pshmemx_$1_$4_exscan')dnl dnl SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_PROF_DEF_EXSCAN', `sum', `SHM_INTERNAL_SUM') define(`SHMEM_PROF_DEF_INSCAN', -`#pragma weak shmem_$1_$4_inscan = pshmem_$1_$4_inscan -#define shmem_$1_$4_inscan pshmem_$1_$4_inscan')dnl +`#pragma weak shmemx_$1_$4_inscan = pshmemx_$1_$4_inscan +#define shmemx_$1_$4_inscan pshmemx_$1_$4_inscan')dnl dnl SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_PROF_DEF_INSCAN', `sum', `SHM_INTERNAL_SUM') diff --git a/src/shmem_collectives.h b/src/shmem_collectives.h index 42ae8d3e..3bbe3610 100644 --- a/src/shmem_collectives.h +++ b/src/shmem_collectives.h @@ -242,12 +242,12 @@ void shmem_internal_scan_linear(void *target, const void *source, size_t count, int PE_start, int PE_stride, int PE_size, void *pWrk, long *pSync, shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype); - + void shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t type_size, int PE_start, int PE_stride, int PE_size, void *pWrk, long *pSync, shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype); - + static inline void shmem_internal_exscan(void *target, const void *source, size_t count, @@ -260,19 +260,19 @@ shmem_internal_exscan(void *target, const void *source, size_t count, switch (shmem_internal_scan_type) { case AUTO: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 1); + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); break; - case LINEAR: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 1); + case LINEAR: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); break; - case RING: - shmem_internal_scan_ring(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 1); + case RING: + shmem_internal_scan_ring(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); break; default: RAISE_ERROR_MSG("Illegal exscan type (%d)\n", @@ -293,19 +293,19 @@ shmem_internal_inscan(void *target, const void *source, size_t count, switch (shmem_internal_scan_type) { case AUTO: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 0); + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); break; - case LINEAR: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 0); + case LINEAR: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); break; - case RING: - shmem_internal_scan_ring(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 0); + case RING: + shmem_internal_scan_ring(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); break; default: RAISE_ERROR_MSG("Illegal exscan type (%d)\n",