From 4ef621c9fc90f0cfa0a3416ab47c48b29c85020d Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sat, 5 Aug 2023 01:01:14 -0700 Subject: [PATCH 01/29] Transfer from Vc to xsimd --- .gitmodules | 3 + CMakeLists.txt | 34 ++- cmake/ArchDetect2.cmake | 243 ++++++++++++++++++ .../include/librapid/array/arrayContainer.hpp | 6 +- librapid/include/librapid/array/storage.hpp | 13 +- librapid/include/librapid/core/config.hpp | 23 +- .../include/librapid/core/librapidPch.hpp | 10 +- librapid/include/librapid/core/traits.hpp | 60 ++--- librapid/include/librapid/math/vectorImpl.hpp | 10 +- librapid/include/librapid/simd/vecOps.hpp | 4 + librapid/vendor/xsimd | 1 + 11 files changed, 317 insertions(+), 90 deletions(-) create mode 100644 cmake/ArchDetect2.cmake create mode 160000 librapid/vendor/xsimd diff --git a/.gitmodules b/.gitmodules index 3fb7a9dd..9497040a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -19,3 +19,6 @@ [submodule "librapid/vendor/CLBlast"] path = librapid/vendor/CLBlast url = https://github.com/CNugteren/CLBlast.git +[submodule "librapid/vendor/xsimd"] + path = librapid/vendor/xsimd + url = https://github.com/xtensor-stack/xsimd.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 2307cd45..30005daf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -415,19 +415,14 @@ endif () # Add dependencies add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/fmt") -add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/Vc") - -if (NOT MINGW) - # scnlib does not support MinGW, since it does not implement std::from_chars, which is required by the library - add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/scnlib") -else () - message(WARNING "[ LIBRAPID ] scnlib cannot be built by MinGW, so it will not be enabled") - target_compile_definitions(${module_name} PUBLIC LIBRAPID_MINGW) -endif () +# add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/Vc") +add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/xsimd") 
+add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/librapid/vendor/scnlib") target_compile_definitions(fmt PUBLIC FMT_HEADER_ONLY) -target_compile_definitions(Vc PRIVATE Vc_HACK_OSTREAM_FOR_TTY) -target_link_libraries(${module_name} PUBLIC fmt scn Vc) +# target_compile_definitions(Vc PRIVATE Vc_HACK_OSTREAM_FOR_TTY) +# target_link_libraries(${module_name} PUBLIC fmt scn Vc xsimd) +target_link_libraries(${module_name} PUBLIC fmt scn xsimd) if (${LIBRAPID_USE_MULTIPREC}) # Load MPIR @@ -484,15 +479,18 @@ if (LIBRAPID_FAST_MATH) target_compile_definitions(${module_name} PUBLIC LIBRAPID_FAST_MATH) endif () -set(LIBRAPID_ARCH_FLAGS) if (LIBRAPID_NATIVE_ARCH) message(STATUS "[ LIBRAPID ] Compiling for native architecture") - OptimizeForArchitecture() - target_compile_options(${module_name} PUBLIC ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) - target_compile_definitions(${module_name} PUBLIC LIBRAPID_NATIVE_ARCH) - set(LIBRAPID_ARCH_FLAGS ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) - message(STATUS "[ LIBRAPID ] Additional Definitions: ${Vc_DEFINITIONS}") - message(STATUS "[ LIBRAPID ] Supported flags: ${Vc_ARCHITECTURE_FLAGS}") + + include(ArchDetect2) + target_compile_options(${module_name} PUBLIC ${LIBRAPID_ARCH_FLAGS}) + +# OptimizeForArchitecture() +# target_compile_options(${module_name} PUBLIC ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) +# target_compile_definitions(${module_name} PUBLIC LIBRAPID_NATIVE_ARCH) +# set(LIBRAPID_ARCH_FLAGS ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) +# message(STATUS "[ LIBRAPID ] Additional Definitions: ${Vc_DEFINITIONS}") +# message(STATUS "[ LIBRAPID ] Supported flags: ${Vc_ARCHITECTURE_FLAGS}") endif () # Add defines for CUDA vector widths diff --git a/cmake/ArchDetect2.cmake b/cmake/ArchDetect2.cmake new file mode 100644 index 00000000..f8bba755 --- /dev/null +++ b/cmake/ArchDetect2.cmake @@ -0,0 +1,243 @@ +INCLUDE(CheckCXXSourceRuns) + +set(COMPILER_GNU false) +set(COMPILER_INTEL false) +set(COMPILER_CLANG false) +set(COMPILER_MSVC 
false) + +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(COMPILER_GNU true) +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Intel") + set(COMPILER_INTEL true) +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(COMPILER_CLANG true) +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + set(COMPILER_MSVC true) +else () + # Unknown Compiler +endif () + +set(LIBRAPID_ARCH_FLAGS) +set(LIBRAPID_ARCH_FOUND) + +# Function to test a given SIMD capability +function(check_simd_capability FLAG_GNU FLAG_MSVC NAME TEST_SOURCE VAR) + set(CMAKE_REQUIRED_FLAGS) + if (COMPILER_GNU OR COMPILER_INTEL OR COMPILER_CLANG) + set(CMAKE_REQUIRED_FLAGS "${FLAG_GNU}") + elseif (COMPILER_MSVC) # reserve for WINDOWS + set(CMAKE_REQUIRED_FLAGS "${FLAG_MSVC}") + endif () + + CHECK_CXX_SOURCE_RUNS("${TEST_SOURCE}" ${VAR}) + + if (${${VAR}}) + if (COMPILER_GNU OR COMPILER_INTEL OR COMPILER_CLANG) + # set(LIBRAPID_ARCH_FLAGS "${LIBRAPID_ARCH_FLAGS} ${FLAG_GNU}" PARENT_SCOPE) + + list(APPEND LIBRAPID_ARCH_FLAGS ${FLAG_GNU}) + set(LIBRAPID_ARCH_FLAGS ${LIBRAPID_ARCH_FLAGS} PARENT_SCOPE) + + message(STATUS "[ LIBRAPID ] ${NAME} found: ${FLAG_GNU}") + elseif (MSVC) + # set(LIBRAPID_ARCH_FLAGS "${LIBRAPID_ARCH_FLAGS} ${FLAG_MSVC}" PARENT_SCOPE) + + list(APPEND LIBRAPID_ARCH_FLAGS ${FLAG_MSVC}) + set(LIBRAPID_ARCH_FLAGS ${LIBRAPID_ARCH_FLAGS} PARENT_SCOPE) + + message(STATUS "[ LIBRAPID ] ${NAME} found: ${FLAG_MSVC}") + endif () + set(LIBRAPID_ARCH_FOUND TRUE PARENT_SCOPE) + else () + message(STATUS "[ LIBRAPID ] ${NAME} not found") + endif () +endfunction() + +# Check SSE2 (not a valid flag for MSVC) +check_simd_capability("-msse2" "" "SSE2" " +#include +int main() { + __m128i a = _mm_set_epi32 (-1, 2, -3, 4); + __m128i result = _mm_abs_epi32 (a); + return 0; +}" SIMD_SSE2) + +# Check SSE3 (not a valid flag for MSVC) +check_simd_capability("-msse3" "" "SSE3" " +#include +int main() { + __m128 a = _mm_set_ps (-1.0f, 2.0f, -3.0f, 4.0f); + __m128 b = _mm_set_ps (1.0f, 2.0f, 3.0f, 4.0f); + __m128 result = 
_mm_addsub_ps (a, b); + return 0; +}" SIMD_SSE3) + +# Check SSSE3 (not a valid flag for MSVC) +check_simd_capability("-mssse3" "" "SSSE3" " +#include +int main() { + __m128i a = _mm_set_epi8(-1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4); + __m128i result = _mm_abs_epi8(a); + return 0; +}" SIMD_SSSE3) + +# Check SSE4.1 (not a valid flag for MSVC) +check_simd_capability("-msse4.1" "" "SSE4.1" " +#include +int main() { + __m128i a = _mm_set_epi32(-1, 2, -3, 4); + __m128i result = _mm_abs_epi32(a); + return 0; +}" SIMD_SSE4_1) + +# Check SSE4.2 (not a valid flag for MSVC) +check_simd_capability("-msse4.2" "" "SSE4.2" " +#include +int main() { + __m128i a = _mm_set_epi32(-1, 2, -3, 4); + __m128i result = _mm_abs_epi32(a); + return 0; +}" SIMD_SSE4_2) + +# Check AVX +check_simd_capability("-mavx" "/arch:AVX" "AVX" " +#include +int main() { + __m256 a = _mm256_set_ps(-1.0f, 2.0f, -3.0f, 4.0f, -1.0f, 2.0f, -3.0f, 4.0f); + __m256 result = _mm256_abs_ps(a); + return 0; +}" SIMD_AVX) + +# Check AVX2 +check_simd_capability("-mavx2" "/arch:AVX2" "AVX2" " +#include +int main() { + __m256i a = _mm256_set_epi32(-1, 2, -3, 4, -1, 2, -3, 4); + __m256i result = _mm256_abs_epi32(a); + return 0; +}" SIMD_AVX2) + +# Check AVX512F +check_simd_capability("-mavx512f" "/arch:AVX512" "AVX512F" " +#include +int main() { + __m512i a = _mm512_set_epi32(-1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4, -1, 2, -3, 4); + __m512i result = _mm512_abs_epi32(a); + return 0; +}" SIMD_AVX512F) + +# Check AVX512BW +check_simd_capability("-mavx512bw" "/arch:AVX512" "AVX512BW" " +#include +int main() { + __m512i a = _mm512_set_epi64(-1, 2, -3, 4, -1, 2, -3, 4); + __m512i result = _mm512_abs_epi8(a); + return 0; +}" SIMD_AVX512BW) + +# Check AVX512CD +check_simd_capability("-mavx512cd" "/arch:AVX512" "AVX512CD" " +#include +int main() { + __m512i a = _mm512_set_epi64(-1, 2, -3, 4, -1, 2, -3, 4); + __m512i result = _mm512_conflict_epi64(a); + return 0; +}" SIMD_AVX512CD) + +# Check AVX512DQ 
+check_simd_capability("-mavx512dq" "/arch:AVX512" "AVX512DQ" " +#include +int main() { + __m512d a = _mm512_set_pd(-1.0, 2.0, -3.0, 4.0, -1.0, 2.0, -3.0, 4.0); + __m512d result = _mm512_abs_pd(a); + return 0; +}" SIMD_AVX512DQ) + +# Check AVX512ER +check_simd_capability("-mavx512er" "/arch:AVX512" "AVX512ER" " +#include +int main() { + __m512d a = _mm512_set_pd(-1.0, 2.0, -3.0, 4.0, -1.0, 2.0, -3.0, 4.0); + __m512d result = _mm512_exp_pd(a); + return 0; +}" SIMD_AVX512ER) + +# Check AVX512PF +check_simd_capability("-mavx512pf" "/arch:AVX512" "AVX512PF" " +#include +int main() { + __m512 a = _mm512_set_ps(-1.0f, 2.0f, -3.0f, 4.0f, -1.0f, 2.0f, -3.0f, 4.0f); + __m512 result = _mm512_exp_ps(a); + return 0; +}" SIMD_AVX512PF) + +# ARM +check_simd_capability("-march=armv7-a" "" "ARMv7" " +#include +int main() { + int32x4_t a = vdupq_n_s32(1); + int32x4_t b = vdupq_n_s32(2); + int32x4_t result = vaddq_s32(a, b); + return 0; +}" SIMD_ARMv7) + +check_simd_capability("-march=armv8-a" "" "ARMv8" " +#include +int main() { + int32x4_t a = vdupq_n_s32(1); + int32x4_t b = vdupq_n_s32(2); + int32x4_t result = vaddq_s32(a, b); + return 0; +}" SIMD_ARMv8) + +# ARM64 +check_simd_capability("-march=armv8.1-a" "" "ARMv8.1" " +#include +int main() { + int32x4_t a = vdupq_n_s32(1); + int32x4_t b = vdupq_n_s32(2); + int32x4_t result = vaddq_s32(a, b); + return 0; +}" SIMD_ARMv8_1) + +check_simd_capability("-march=armv8.2-a" "" "ARMv8.2" " +#include +int main() { + int32x4_t a = vdupq_n_s32(1); + int32x4_t b = vdupq_n_s32(2); + int32x4_t result = vaddq_s32(a, b); + return 0; +}" SIMD_ARMv8_2) + +check_simd_capability("-march=armv8.3-a" "" "ARMv8.3" " +#include +int main() { + int32x4_t a = vdupq_n_s32(1); + int32x4_t b = vdupq_n_s32(2); + int32x4_t result = vaddq_s32(a, b); + return 0; +}" SIMD_ARMv8_3) + +check_simd_capability("-march=armv8.4-a" "" "ARMv8.4" " +#include +int main() { + int32x4_t a = vdupq_n_s32(1); + int32x4_t b = vdupq_n_s32(2); + int32x4_t result = vaddq_s32(a, b); + 
return 0; +}" SIMD_ARMv8_4) + +check_simd_capability("-march=armv8.5-a" "" "ARMv8.5" " +#include +int main() { + int32x4_t a = vdupq_n_s32(1); + int32x4_t b = vdupq_n_s32(2); + int32x4_t result = vaddq_s32(a, b); + return 0; +}" SIMD_ARMv8_5) + +if (LIBRAPID_ARCH_FOUND) + message(STATUS "[ LIBRAPID ] Architecture Flags: ${LIBRAPID_ARCH_FLAGS}") +else() + message(STATUS "[ LIBRAPID ] Architecture Flags Not Found") +endif() diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index 26940510..358025c2 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -668,9 +668,7 @@ namespace librapid { template auto ArrayContainer::packet(size_t index) const -> Packet { - Packet res; - res.load(m_storage.begin() + index); - return res; + return xsimd::load_aligned(m_storage.begin() + index); } template @@ -681,7 +679,7 @@ namespace librapid { template void ArrayContainer::writePacket(size_t index, const Packet &value) { - value.store(m_storage.begin() + index); + value.store_aligned(m_storage.begin() + index); } template diff --git a/librapid/include/librapid/array/storage.hpp b/librapid/include/librapid/array/storage.hpp index fff1fc46..b81f6fbb 100644 --- a/librapid/include/librapid/array/storage.hpp +++ b/librapid/include/librapid/array/storage.hpp @@ -184,11 +184,14 @@ namespace librapid { /// \param newSize New size of the Storage object LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); -#if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) - alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr; -#else - Pointer m_begin = nullptr; // Pointer to the beginning of the data -#endif +//#if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) +// alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr; +//#else +// Pointer m_begin = nullptr; // Pointer to the beginning of the data +//#endif + + Pointer m_begin = 
nullptr; + SizeType m_size = 0; // Number of elements in the Storage object bool m_ownsData = true; // Whether this Storage object owns the data it points to }; diff --git a/librapid/include/librapid/core/config.hpp b/librapid/include/librapid/core/config.hpp index b1b3d178..55bef30e 100644 --- a/librapid/include/librapid/core/config.hpp +++ b/librapid/include/librapid/core/config.hpp @@ -165,62 +165,63 @@ # define LIBRAPID_AVX512 # define LIBRAPID_ARCH AVX512_2 # define LIBRAPID_ARCH_NAME "AVX512" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# define LIBRAPID_DEFAULT_MEM_ALIGN 256 #elif defined(__AVX512F__) || defined(__AVX512__) # define LIBRAPID_AVX512 # define LIBRAPID_ARCH AVX512 # define LIBRAPID_ARCH_NAME "AVX512" +# define LIBRAPID_DEFAULT_MEM_ALIGN 256 #elif defined(__AVX2__) # define LIBRAPID_AVX2 # define LIBRAPID_ARCH AVX2 # define LIBRAPID_ARCH_NAME "AVX2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 32 +# define LIBRAPID_DEFAULT_MEM_ALIGN 128 #elif defined(__AVX__) # define LIBRAPID_AVX # define LIBRAPID_ARCH AVX # define LIBRAPID_ARCH_NAME "AVX" -# define LIBRAPID_DEFAULT_MEM_ALIGN 32 +# define LIBRAPID_DEFAULT_MEM_ALIGN 128 #elif defined(__SSE4_2__) # define LIBRAPID_SSE42 # define LIBRAPID_ARCH SSE4_2 # define LIBRAPID_ARCH_NAME "SSE4.2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE4_1__) # define LIBRAPID_SSE41 # define LIBRAPID_ARCH SSE4_1 # define LIBRAPID_ARCH_NAME "SSE4.1" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSSE3__) # define LIBRAPID_SSSE3 # define LIBRAPID_ARCH SSSE3 # define LIBRAPID_ARCH_NAME "SSSE3" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE3__) # define LIBRAPID_SSE3 # define LIBRAPID_ARCH SSE3 # define LIBRAPID_ARCH_NAME "SSE3" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE2__) || defined(__x86_64__) # define LIBRAPID_SSE2 # 
define LIBRAPID_ARCH SSE2 # define LIBRAPID_ARCH_NAME "SSE2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE__) # define LIBRAPID_SSE # define LIBRAPID_ARCH SSE # define LIBRAPID_ARCH_NAME "SSE" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(_M_IX86_FP) // Defined in MS compiler. 1: SSE, 2: SSE2 # if _M_IX86_FP == 1 # define LIBRAPID_SSE # define LIBRAPID_ARCH SSE # define LIBRAPID_ARCH_NAME "SSE" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 # elif _M_IX86_FP == 2 # define LIBRAPID_SSE2 # define LIBRAPID_ARCH SSE2 # define LIBRAPID_ARCH_NAME "SSE2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 16 +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 # endif // _M_IX86_FP #else # define LIBRAPID_ARCH 0 diff --git a/librapid/include/librapid/core/librapidPch.hpp b/librapid/include/librapid/core/librapidPch.hpp index c552bcd2..d8457401 100644 --- a/librapid/include/librapid/core/librapidPch.hpp +++ b/librapid/include/librapid/core/librapidPch.hpp @@ -71,10 +71,12 @@ # pragma warning(disable : 4127) // conditional expression is constant #endif -#include -#include -#include -#include +// #include +// #include +// #include +// #include + +#include #if defined(_MSC_VER) # pragma warning(pop) diff --git a/librapid/include/librapid/core/traits.hpp b/librapid/include/librapid/core/traits.hpp index e60295b4..9d5125e1 100644 --- a/librapid/include/librapid/core/traits.hpp +++ b/librapid/include/librapid/core/traits.hpp @@ -211,9 +211,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = int8_t; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "int8_t"; static constexpr bool supportsArithmetic = true; static 
constexpr bool supportsLogical = true; @@ -242,9 +242,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = uint8_t; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "uint8_t"; static constexpr bool supportsArithmetic = true; static constexpr bool supportsLogical = true; @@ -273,9 +273,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = int16_t; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "int16_t"; static constexpr bool supportsArithmetic = true; static constexpr bool supportsLogical = true; @@ -304,9 +304,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = uint16_t; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "uint16_t"; static constexpr bool supportsArithmetic = true; static constexpr bool supportsLogical = true; @@ -335,9 +335,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = int32_t; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "int32_t"; static constexpr bool supportsArithmetic = true; static constexpr 
bool supportsLogical = true; @@ -366,9 +366,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = uint32_t; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "uint32_t"; static constexpr bool supportsArithmetic = true; static constexpr bool supportsLogical = true; @@ -459,9 +459,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = float; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "float"; static constexpr bool supportsArithmetic = true; static constexpr bool supportsLogical = true; @@ -490,9 +490,9 @@ namespace librapid { struct TypeInfo { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = double; - using Packet = Vc::Vector; + using Packet = xsimd::batch; using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size(); + static constexpr int64_t packetWidth = Packet::size; static constexpr char name[] = "double"; static constexpr bool supportsArithmetic = true; static constexpr bool supportsLogical = true; @@ -518,7 +518,7 @@ namespace librapid { }; template - struct TypeInfo> { + struct TypeInfo> { static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; using Scalar = T; using Packet = std::false_type; @@ -736,32 +736,6 @@ namespace librapid { }; #endif - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = typename VectorType::EntryType; - using Packet = std::false_type; - using Backend = 
backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "Vc::ElementReference"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = false; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = false; - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = false; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - template<> struct TypeInfo { static constexpr char name[] = "CPU"; diff --git a/librapid/include/librapid/math/vectorImpl.hpp b/librapid/include/librapid/math/vectorImpl.hpp index cdf97417..9ba60f27 100644 --- a/librapid/include/librapid/math/vectorImpl.hpp +++ b/librapid/include/librapid/math/vectorImpl.hpp @@ -756,11 +756,11 @@ namespace librapid { return val; } - template - constexpr auto scalarExtractor(const Vc_1::Detail::ElementReference &val) { - using Scalar = typename Vc_1::Detail::ElementReference::value_type; - return static_cast(val); - } + // template + // constexpr auto scalarExtractor(const Vc_1::Detail::ElementReference &val) { + // using Scalar = typename Vc_1::Detail::ElementReference::value_type; + // return static_cast(val); + // } template constexpr auto scalarVectorCaster(const T &val) { diff --git a/librapid/include/librapid/simd/vecOps.hpp b/librapid/include/librapid/simd/vecOps.hpp index d830168f..1dd8199b 100644 --- a/librapid/include/librapid/simd/vecOps.hpp +++ b/librapid/include/librapid/simd/vecOps.hpp @@ -1,6 +1,8 @@ #ifndef 
LIBRAPID_SIMD_TRIGONOMETRY #define LIBRAPID_SIMD_TRIGONOMETRY +#if 0 + namespace librapid { namespace typetraits { template @@ -215,4 +217,6 @@ namespace librapid { } } // namespace librapid +#endif + #endif // LIBRAPID_SIMD_TRIGONOMETRY \ No newline at end of file diff --git a/librapid/vendor/xsimd b/librapid/vendor/xsimd new file mode 160000 index 00000000..e6fa5aca --- /dev/null +++ b/librapid/vendor/xsimd @@ -0,0 +1 @@ +Subproject commit e6fa5aca6320d6ccaf24c123ab2af9b0f2f09cc1 From c588da59bcc11cc6b0a46465be024df96c63b248 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sat, 5 Aug 2023 02:14:05 -0700 Subject: [PATCH 02/29] Switch SIMD backend to xsimd --- .gitmodules | 5 +- CMakeLists.txt | 1 + examples/example-vector-1.cpp | 2 + .../librapid/array/linalg/arrayMultiply.hpp | 19 ++-- .../include/librapid/array/operations.hpp | 5 +- librapid/include/librapid/array/storage.hpp | 21 ++-- librapid/include/librapid/math/vectorImpl.hpp | 10 +- librapid/include/librapid/simd/vecOps.hpp | 97 ++++++++++--------- librapid/vendor/Vc | 1 - 9 files changed, 79 insertions(+), 82 deletions(-) delete mode 160000 librapid/vendor/Vc diff --git a/.gitmodules b/.gitmodules index 9497040a..cec183e8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,6 @@ [submodule "librapid/vendor/jitify"] path = librapid/vendor/jitify url = https://github.com/Pencilcaseman/jitify.git -[submodule "librapid/vendor/Vc"] - path = librapid/vendor/Vc - url = https://github.com/Pencilcaseman/Vc.git [submodule "librapid/vendor/fmt"] path = librapid/vendor/fmt url = https://github.com/fmtlib/fmt.git @@ -21,4 +18,4 @@ url = https://github.com/CNugteren/CLBlast.git [submodule "librapid/vendor/xsimd"] path = librapid/vendor/xsimd - url = https://github.com/xtensor-stack/xsimd.git + url = https://github.com/xtensor-stack/xsimd.git \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 30005daf..684fc3b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -484,6 +484,7 @@ if 
(LIBRAPID_NATIVE_ARCH) include(ArchDetect2) target_compile_options(${module_name} PUBLIC ${LIBRAPID_ARCH_FLAGS}) + target_compile_definitions(${module_name} PUBLIC LIBRAPID_NATIVE_ARCH) # OptimizeForArchitecture() # target_compile_options(${module_name} PUBLIC ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) diff --git a/examples/example-vector-1.cpp b/examples/example-vector-1.cpp index 6207b6f6..39acf258 100644 --- a/examples/example-vector-1.cpp +++ b/examples/example-vector-1.cpp @@ -5,6 +5,7 @@ namespace lrc = librapid; auto main() -> int { fmt::print("LibRapid Example -- Vector 1\n"); +#if 0 // Currently broken -- switching SIMD backend // Create a 3 dimensional vector lrc::Vec3d myVector(2, 3, 4); lrc::Vec3d myOtherVector(10, 5, 8); @@ -52,6 +53,7 @@ auto main() -> int { fmt::print("One vector: {}\n", one); fmt::print("Full vector: {}\n", full); fmt::print("Random vector: {:.3f}\n", random); +#endif return 0; } \ No newline at end of file diff --git a/librapid/include/librapid/array/linalg/arrayMultiply.hpp b/librapid/include/librapid/array/linalg/arrayMultiply.hpp index 98ea19d7..386658c9 100644 --- a/librapid/include/librapid/array/linalg/arrayMultiply.hpp +++ b/librapid/include/librapid/array/linalg/arrayMultiply.hpp @@ -675,15 +675,22 @@ namespace librapid { namespace typetraits { template - struct TypeInfo> { + struct TypeInfo< + linalg::ArrayMultiply> { detail::LibRapidType type = detail::LibRapidType::ArrayFunction; - using Type = linalg::ArrayMultiply; + using Type = linalg::ArrayMultiply; using Scalar = typename Type::Scalar; - using Backend = typename Type::Backend; + using Backend = typename Type::Backend; + static constexpr bool allowVectorisation = false; }; - } + + LIBRAPID_DEFINE_AS_TYPE(typename ShapeTypeA COMMA typename StorageTypeA COMMA + typename ShapeTypeB COMMA typename StorageTypeB COMMA + typename Alpha COMMA typename Beta, + linalg::ArrayMultiply); + } // namespace typetraits } // namespace librapid LIBRAPID_SIMPLE_IO_IMPL( diff --git 
a/librapid/include/librapid/array/operations.hpp b/librapid/include/librapid/array/operations.hpp index a859ab36..690dab6d 100644 --- a/librapid/include/librapid/array/operations.hpp +++ b/librapid/include/librapid/array/operations.hpp @@ -27,10 +27,7 @@ template \ LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto packet(const Packet &lhs, \ const Packet &rhs) const { \ - auto mask = lhs OP_ rhs; \ - Packet res(1); \ - res.setZero(!mask); \ - return res; \ + return Packet(lhs OP_ rhs); \ } \ } diff --git a/librapid/include/librapid/array/storage.hpp b/librapid/include/librapid/array/storage.hpp index b81f6fbb..0af75de7 100644 --- a/librapid/include/librapid/array/storage.hpp +++ b/librapid/include/librapid/array/storage.hpp @@ -184,11 +184,11 @@ namespace librapid { /// \param newSize New size of the Storage object LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); -//#if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) -// alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr; -//#else -// Pointer m_begin = nullptr; // Pointer to the beginning of the data -//#endif + // #if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) + // alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr; + // #else + // Pointer m_begin = nullptr; // Pointer to the beginning of the data + // #endif Pointer m_begin = nullptr; @@ -368,21 +368,16 @@ namespace librapid { // MKL has its own memory allocation function auto ptr = static_cast(mkl_malloc(size * sizeof(T), 64)); #else -# if defined(LIBRAPID_NATIVE_ARCH) // Force aligned memory -# if defined(LIBRAPID_APPLE) +# if defined(LIBRAPID_APPLE) // No memory allignment. 
It breaks everything for some reason auto ptr = static_cast(std::malloc(size * sizeof(T))); -# elif defined(LIBRAPID_MSVC) || defined(LIBRAPID_MINGW) +# elif defined(LIBRAPID_MSVC) || defined(LIBRAPID_MINGW) auto ptr = static_cast(_aligned_malloc(size * sizeof(T), global::memoryAlignment)); -# else +# else auto ptr = static_cast( std::aligned_alloc(global::memoryAlignment, size * sizeof(T))); -# endif -# else - // No memory alignment - auto ptr = static_cast(std::malloc(size * sizeof(T))); # endif #endif diff --git a/librapid/include/librapid/math/vectorImpl.hpp b/librapid/include/librapid/math/vectorImpl.hpp index 9ba60f27..e4b3e7c6 100644 --- a/librapid/include/librapid/math/vectorImpl.hpp +++ b/librapid/include/librapid/math/vectorImpl.hpp @@ -41,8 +41,8 @@ namespace librapid { static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; using Scalar = T; using Packet = typename TypeInfo::Packet; - using IndexType = typename std::decay_t()[0])>; - using IndexTypeConst = typename std::decay_t()[0])>; + using IndexType = Scalar &; + using IndexTypeConst = const Scalar &; using GetType = const Packet &; using StorageType = vectorDetail::SimdVectorStorage; @@ -279,7 +279,7 @@ namespace librapid { length); const int64_t packetIndex = index / packetWidth; const int64_t elementIndex = index % packetWidth; - return data[packetIndex][elementIndex]; + return data[packetIndex].get(elementIndex); } LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) { @@ -287,9 +287,7 @@ namespace librapid { "Index {} out of bounds for Vector of length {}", index, length); - const int64_t packetIndex = index / packetWidth; - const int64_t elementIndex = index % packetWidth; - return data[packetIndex][elementIndex]; + static_assert(false, "Not implemented"); } LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum() const -> Scalar { diff --git a/librapid/include/librapid/simd/vecOps.hpp b/librapid/include/librapid/simd/vecOps.hpp index 1dd8199b..cb3ca605 
100644 --- a/librapid/include/librapid/simd/vecOps.hpp +++ b/librapid/include/librapid/simd/vecOps.hpp @@ -1,18 +1,13 @@ #ifndef LIBRAPID_SIMD_TRIGONOMETRY #define LIBRAPID_SIMD_TRIGONOMETRY -#if 0 - namespace librapid { namespace typetraits { template struct IsSIMD : std::false_type {}; - // template - // struct IsSIMD> : std::true_type {}; - template - struct IsSIMD> : std::true_type {}; + struct IsSIMD> : std::true_type {}; } // namespace typetraits #define REQUIRE_SIMD(TYPE) typename std::enable_if_t::value, int> = 0 @@ -20,41 +15,47 @@ namespace librapid { template auto sin(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::sin(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = sin(static_cast(x[i])); } - return result; - } +// using Scalar = typename T::value_type; +// IF_FLOATING(T) { return xsimd::sin(x); } +// else { +// T result; +// for (int i = 0; i < x.size(); ++i) { result[i] = sin(static_cast(x[i])); } +// return result; +// } + + return xsimd::sin(x); } template auto cos(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::cos(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = cos(static_cast(x[i])); } - return result; - } +// using Scalar = typename T::value_type; +// IF_FLOATING(T) { return xsimd::cos(x); } +// else { +// T result; +// for (int i = 0; i < x.size(); ++i) { result[i] = cos(static_cast(x[i])); } +// return result; +// } + + return xsimd::cos(x); } template auto tan(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::sin(x) / Vc::cos(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = tan(static_cast(x[i])); } - return result; - } +// using Scalar = typename T::value_type; +// IF_FLOATING(T) { return xsimd::sin(x) / xsimd::cos(x); } +// else { +// T result; +// for (int i = 0; i < x.size(); ++i) { result[i] = tan(static_cast(x[i])); } +// 
return result; +// } + + return xsimd::tan(x); } template auto asin(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::asin(x); } + IF_FLOATING(T) { return xsimd::asin(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = asin(static_cast(x[i])); } @@ -66,8 +67,8 @@ namespace librapid { auto acos(const T &x) -> T { using Scalar = typename T::value_type; IF_FLOATING(T) { - static const auto asin1 = Vc::asin(T(1)); - return asin1 - Vc::asin(x); + static const auto asin1 = xsimd::asin(T(1)); + return asin1 - xsimd::asin(x); } else { T result; @@ -79,7 +80,7 @@ namespace librapid { template auto atan(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::atan(x); } + IF_FLOATING(T) { return xsimd::atan(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = atan(static_cast(x[i])); } @@ -90,7 +91,7 @@ namespace librapid { template auto sinh(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return (Vc::exp(x) - Vc::exp(-x)) * T(0.5); } + IF_FLOATING(T) { return (xsimd::exp(x) - xsimd::exp(-x)) * T(0.5); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = sinh(static_cast(x[i])); } @@ -101,7 +102,7 @@ namespace librapid { template auto cosh(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return (Vc::exp(x) + Vc::exp(-x)) * T(0.5); } + IF_FLOATING(T) { return (xsimd::exp(x) + xsimd::exp(-x)) * T(0.5); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = cosh(static_cast(x[i])); } @@ -112,7 +113,7 @@ namespace librapid { template auto tanh(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return (Vc::exp(2 * x) - 1) / (Vc::exp(2 * x) + 1); } + IF_FLOATING(T) { return (xsimd::exp(2 * x) - 1) / (xsimd::exp(2 * x) + 1); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = tanh(static_cast(x[i])); } @@ -123,7 +124,7 @@ namespace librapid { template 
auto exp(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::exp(x); } + IF_FLOATING(T) { return xsimd::exp(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = exp(static_cast(x[i])); } @@ -134,7 +135,7 @@ namespace librapid { template auto log(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::log(x); } + IF_FLOATING(T) { return xsimd::log(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = log(static_cast(x[i])); } @@ -145,7 +146,7 @@ namespace librapid { template auto log2(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::log2(x); } + IF_FLOATING(T) { return xsimd::log2(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = log2(static_cast(x[i])); } @@ -156,7 +157,7 @@ namespace librapid { template auto log10(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::log10(x); } + IF_FLOATING(T) { return xsimd::log10(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = log10(static_cast(x[i])); } @@ -167,7 +168,7 @@ namespace librapid { template auto sqrt(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::sqrt(x); } + IF_FLOATING(T) { return xsimd::sqrt(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = sqrt(static_cast(x[i])); } @@ -177,16 +178,18 @@ namespace librapid { template auto cbrt(const T &x) -> T { - using Scalar = typename T::value_type; - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = cbrt(static_cast(x[i])); } - return result; + // using Scalar = typename T::value_type; + // T result; + // for (int i = 0; i < x.size(); ++i) { result[i] = cbrt(static_cast(x[i])); } + // return result; + + return xsimd::cbrt(x); } template auto abs(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::abs(x); } + IF_FLOATING(T) { return 
xsimd::abs(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = abs(static_cast(x[i])); } @@ -197,7 +200,7 @@ namespace librapid { template auto floor(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::floor(x); } + IF_FLOATING(T) { return xsimd::floor(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = floor(static_cast(x[i])); } @@ -208,7 +211,7 @@ namespace librapid { template auto ceil(const T &x) -> T { using Scalar = typename T::value_type; - IF_FLOATING(T) { return Vc::ceil(x); } + IF_FLOATING(T) { return xsimd::ceil(x); } else { T result; for (int i = 0; i < x.size(); ++i) { result[i] = ceil(static_cast(x[i])); } @@ -217,6 +220,4 @@ namespace librapid { } } // namespace librapid -#endif - #endif // LIBRAPID_SIMD_TRIGONOMETRY \ No newline at end of file diff --git a/librapid/vendor/Vc b/librapid/vendor/Vc deleted file mode 160000 index ce37915d..00000000 --- a/librapid/vendor/Vc +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ce37915d1a897fe5160b69923a1c04275e388f19 From 6cf0961bc265a5280cf2e9e8794eda22afdf8bf8 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sat, 5 Aug 2023 02:30:56 -0700 Subject: [PATCH 03/29] Update dual number library (UNTESTED) --- librapid/include/librapid/autodiff/dual.hpp | 38 +++++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/librapid/include/librapid/autodiff/dual.hpp b/librapid/include/librapid/autodiff/dual.hpp index 5af75108..5d5f8fb2 100644 --- a/librapid/include/librapid/autodiff/dual.hpp +++ b/librapid/include/librapid/autodiff/dual.hpp @@ -23,6 +23,9 @@ namespace librapid { using Scalar = typename typetraits::TypeInfo::Scalar; #endif + using Packet = typename typetraits::TypeInfo::Packet; + static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; + Dual() = default; explicit Dual(T value) : value(value), derivative(T()) {} Dual(T value, T derivative) : value(value), derivative(derivative) {} @@ -52,16 
+55,37 @@ namespace librapid { template LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { - auto casted = reinterpret_cast(ptr); - auto ret = Vc::interleave(value, derivative); - ret.first.store(casted); - ret.second.store(casted + size()); + // Load the data into batches. + auto casted = reinterpret_cast(ptr); + + // Compute interleaved values. + std::array interleaved; + for (std::size_t i = 0; i < packetWidth; ++i) { + interleaved[2 * i] = value.get(i); + interleaved[2 * i + 1] = derivative.get(i); + } + + // Store the interleaved values back to memory. + std::copy(interleaved.begin(), interleaved.end(), casted); } template LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { + // auto casted = reinterpret_cast(ptr); + // Vc::deinterleave(&value, &derivative, casted, Vc::Aligned); + + // Load the data into batches. auto casted = reinterpret_cast(ptr); - Vc::deinterleave(&value, &derivative, casted, Vc::Aligned); + + // Compute interleaved values. + std::array interleaved; + std::copy(casted, casted + 2 * packetWidth, interleaved.begin()); + + // Store the interleaved values back to memory. 
+ for (std::size_t i = 0; i < packetWidth; ++i) { + value.set(i, interleaved[2 * i]); + derivative.set(i, interleaved[2 * i + 1]); + } } LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const Dual &other) { @@ -363,7 +387,7 @@ namespace librapid { struct TypeInfo> { static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; using Scalar = T; - using Packet = Dual::Packet>; + using Packet = Dual::Packet>; static constexpr int64_t packetWidth = TypeInfo::Scalar>::packetWidth; using Backend = backend::CPU; @@ -397,7 +421,7 @@ namespace librapid { struct TypeInfo> { static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; using Scalar = float; - using Packet = Dual::Packet>; + using Packet = Dual::Packet>; static constexpr int64_t packetWidth = TypeInfo::Scalar>::packetWidth; using Backend = backend::CPU; From cda9f72d7a2dfe505ea0978973fe5dafaaab25ca Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sat, 5 Aug 2023 17:02:04 -0700 Subject: [PATCH 04/29] Removed submodule librapid/vendor/xsimd --- .gitmodules | 5 +---- librapid/vendor/xsimd | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) delete mode 160000 librapid/vendor/xsimd diff --git a/.gitmodules b/.gitmodules index cec183e8..e9ac46b2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -15,7 +15,4 @@ url = https://github.com/mreineck/pocketfft [submodule "librapid/vendor/CLBlast"] path = librapid/vendor/CLBlast - url = https://github.com/CNugteren/CLBlast.git -[submodule "librapid/vendor/xsimd"] - path = librapid/vendor/xsimd - url = https://github.com/xtensor-stack/xsimd.git \ No newline at end of file + url = https://github.com/CNugteren/CLBlast.git \ No newline at end of file diff --git a/librapid/vendor/xsimd b/librapid/vendor/xsimd deleted file mode 160000 index e6fa5aca..00000000 --- a/librapid/vendor/xsimd +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e6fa5aca6320d6ccaf24c123ab2af9b0f2f09cc1 From 96d28f2ab141a977e175982910ca4f15b215d90e Mon Sep 17 00:00:00 2001 From: 
Toby Davis Date: Sat, 5 Aug 2023 22:57:33 -0700 Subject: [PATCH 05/29] Continue transfer to xsimd --- .gitmodules | 5 +- librapid/include/librapid/autodiff/dual.hpp | 92 +++++----- librapid/include/librapid/core/traits.hpp | 32 +++- librapid/include/librapid/math/complex.hpp | 34 ++-- librapid/include/librapid/simd/vecOps.hpp | 183 ++++---------------- librapid/vendor/xsimd | 1 + 6 files changed, 133 insertions(+), 214 deletions(-) create mode 160000 librapid/vendor/xsimd diff --git a/.gitmodules b/.gitmodules index e9ac46b2..565e022c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -15,4 +15,7 @@ url = https://github.com/mreineck/pocketfft [submodule "librapid/vendor/CLBlast"] path = librapid/vendor/CLBlast - url = https://github.com/CNugteren/CLBlast.git \ No newline at end of file + url = https://github.com/CNugteren/CLBlast.git +[submodule "librapid/vendor/xsimd"] + path = librapid/vendor/xsimd + url = https://github.com/LibRapid/xsimd.git diff --git a/librapid/include/librapid/autodiff/dual.hpp b/librapid/include/librapid/autodiff/dual.hpp index 5d5f8fb2..1e368d83 100644 --- a/librapid/include/librapid/autodiff/dual.hpp +++ b/librapid/include/librapid/autodiff/dual.hpp @@ -18,14 +18,15 @@ namespace librapid { T derivative; #if defined(LIBRAPID_IN_JITIFY) - using Scalar = T; + using Scalar = T; + using Packet = T; + static constexpr uint64_t packetWidth = 1; #else - using Scalar = typename typetraits::TypeInfo::Scalar; + using Scalar = typename typetraits::TypeInfo::Scalar; + using Packet = typename typetraits::TypeInfo::Packet; + static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; #endif - using Packet = typename typetraits::TypeInfo::Packet; - static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; - Dual() = default; explicit Dual(T value) : value(value), derivative(T()) {} Dual(T value, T derivative) : value(value), derivative(derivative) {} @@ -53,40 +54,40 @@ namespace librapid { static constexpr size_t size() { 
return typetraits::TypeInfo::packetWidth; } - template - LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { - // Load the data into batches. - auto casted = reinterpret_cast(ptr); - - // Compute interleaved values. - std::array interleaved; - for (std::size_t i = 0; i < packetWidth; ++i) { - interleaved[2 * i] = value.get(i); - interleaved[2 * i + 1] = derivative.get(i); - } - - // Store the interleaved values back to memory. - std::copy(interleaved.begin(), interleaved.end(), casted); - } - - template - LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { - // auto casted = reinterpret_cast(ptr); - // Vc::deinterleave(&value, &derivative, casted, Vc::Aligned); - - // Load the data into batches. - auto casted = reinterpret_cast(ptr); - - // Compute interleaved values. - std::array interleaved; - std::copy(casted, casted + 2 * packetWidth, interleaved.begin()); - - // Store the interleaved values back to memory. - for (std::size_t i = 0; i < packetWidth; ++i) { - value.set(i, interleaved[2 * i]); - derivative.set(i, interleaved[2 * i + 1]); - } - } + // template + // LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { + // // Load the data into batches. + // auto casted = reinterpret_cast(ptr); + // + // // Compute interleaved values. + // std::array interleaved; + // for (std::size_t i = 0; i < packetWidth; ++i) { + // interleaved[2 * i] = value.get(i); + // interleaved[2 * i + 1] = derivative.get(i); + // } + // + // // Store the interleaved values back to memory. + // std::copy(interleaved.begin(), interleaved.end(), casted); + // } + + // template + // LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { + // // auto casted = reinterpret_cast(ptr); + // // Vc::deinterleave(&value, &derivative, casted, Vc::Aligned); + // + // // Load the data into batches. + // auto casted = reinterpret_cast(ptr); + // + // // Compute interleaved values. 
+ // std::array interleaved; + // std::copy(casted, casted + 2 * packetWidth, interleaved.begin()); + // + // // Store the interleaved values back to memory. + // for (std::size_t i = 0; i < packetWidth; ++i) { + // value.set(i, interleaved[2 * i]); + // derivative.set(i, interleaved[2 * i + 1]); + // } + // } LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const Dual &other) { value += other.value; @@ -387,9 +388,9 @@ namespace librapid { struct TypeInfo> { static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; using Scalar = T; - using Packet = Dual::Packet>; + using Packet = std::false_type; // Dual::Packet>; static constexpr int64_t packetWidth = - TypeInfo::Scalar>::packetWidth; + 0; // TypeInfo::Scalar>::packetWidth; using Backend = backend::CPU; static constexpr char name[] = "Dual_T"; @@ -397,7 +398,7 @@ namespace librapid { static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; static constexpr bool supportsLogical = TypeInfo::supportsLogical; static constexpr bool supportsBinary = TypeInfo::supportsBinary; - static constexpr bool allowVectorisation = TypeInfo::allowVectorisation; + static constexpr bool allowVectorisation = false; // TypeInfo::allowVectorisation; # if defined(LIBRAPID_HAS_CUDA) static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; @@ -421,9 +422,9 @@ namespace librapid { struct TypeInfo> { static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; using Scalar = float; - using Packet = Dual::Packet>; + using Packet = std::false_type; // Dual::Packet>; static constexpr int64_t packetWidth = - TypeInfo::Scalar>::packetWidth; + 0; // TypeInfo::Scalar>::packetWidth; using Backend = backend::CPU; static constexpr char name[] = "Dual_float"; @@ -431,7 +432,8 @@ namespace librapid { static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; static constexpr bool supportsLogical = TypeInfo::supportsLogical; static constexpr bool supportsBinary = 
TypeInfo::supportsBinary; - static constexpr bool allowVectorisation = TypeInfo::allowVectorisation; + static constexpr bool allowVectorisation = + false; // TypeInfo::allowVectorisation; # if defined(LIBRAPID_HAS_CUDA) static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; diff --git a/librapid/include/librapid/core/traits.hpp b/librapid/include/librapid/core/traits.hpp index 9d5125e1..a0d0b674 100644 --- a/librapid/include/librapid/core/traits.hpp +++ b/librapid/include/librapid/core/traits.hpp @@ -736,17 +736,43 @@ namespace librapid { }; #endif + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = typename xsimd::batch_element_reference::Scalar; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "xsimd::batch_element_reference"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = false; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = false; + + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = false; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + template<> struct TypeInfo { static constexpr char name[] = "CPU"; - using Backend = backend::CPU; + using Backend = backend::CPU; }; #if defined(LIBRAPID_HAS_OPENCL) template<> struct TypeInfo { static constexpr char name[] = "OpenCL"; - using Backend = backend::OpenCL; + using Backend = 
backend::OpenCL; }; #endif @@ -754,7 +780,7 @@ namespace librapid { template<> struct TypeInfo { static constexpr char name[] = "CUDA"; - using Backend = backend::CUDA; + using Backend = backend::CUDA; }; #endif diff --git a/librapid/include/librapid/math/complex.hpp b/librapid/include/librapid/math/complex.hpp index 8637f996..abb74f66 100644 --- a/librapid/include/librapid/math/complex.hpp +++ b/librapid/include/librapid/math/complex.hpp @@ -480,19 +480,19 @@ namespace librapid { return *this; } - template - LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { - auto casted = reinterpret_cast(ptr); - auto ret = Vc::interleave(m_val[RE], m_val[IM]); - ret.first.store(casted); - ret.second.store(casted + size()); - } - - template - LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { - auto casted = reinterpret_cast(ptr); - Vc::deinterleave(&m_val[RE], &m_val[IM], casted, Vc::Aligned); - } + // template + // LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { + // auto casted = reinterpret_cast(ptr); + // auto ret = Vc::interleave(m_val[RE], m_val[IM]); + // ret.first.store(casted); + // ret.second.store(casted + size()); + // } + + // template + // LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { + // auto casted = reinterpret_cast(ptr); + // Vc::deinterleave(&m_val[RE], &m_val[IM], casted, Vc::Aligned); + // } /// \brief Assign to the real component /// @@ -2047,11 +2047,11 @@ namespace librapid { struct TypeInfo> { static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; using Scalar = Complex; - using Packet = - typename std::conditional_t<(TypeInfo::packetWidth > 1), - Complex::Packet>, std::false_type>; + using Packet = std::false_type; + // typename std::conditional_t<(TypeInfo::packetWidth > 1), + // Complex::Packet>, std::false_type>; static constexpr int64_t packetWidth = - TypeInfo::Scalar>::packetWidth; + 0; // TypeInfo::Scalar>::packetWidth; static constexpr char name[] = "Complex"; static constexpr bool supportsArithmetic = true; static 
constexpr bool supportsLogical = true; diff --git a/librapid/include/librapid/simd/vecOps.hpp b/librapid/include/librapid/simd/vecOps.hpp index cb3ca605..cc3d285f 100644 --- a/librapid/include/librapid/simd/vecOps.hpp +++ b/librapid/include/librapid/simd/vecOps.hpp @@ -8,215 +8,102 @@ namespace librapid { template struct IsSIMD> : std::true_type {}; + + template + struct IsSIMD> : std::true_type {}; } // namespace typetraits #define REQUIRE_SIMD(TYPE) typename std::enable_if_t::value, int> = 0 #define IF_FLOATING(TYPE) if constexpr (std::is_floating_point_v) template - auto sin(const T &x) -> T { -// using Scalar = typename T::value_type; -// IF_FLOATING(T) { return xsimd::sin(x); } -// else { -// T result; -// for (int i = 0; i < x.size(); ++i) { result[i] = sin(static_cast(x[i])); } -// return result; -// } - + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sin(const T &x) { return xsimd::sin(x); } template - auto cos(const T &x) -> T { -// using Scalar = typename T::value_type; -// IF_FLOATING(T) { return xsimd::cos(x); } -// else { -// T result; -// for (int i = 0; i < x.size(); ++i) { result[i] = cos(static_cast(x[i])); } -// return result; -// } - + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cos(const T &x) { return xsimd::cos(x); } template - auto tan(const T &x) -> T { -// using Scalar = typename T::value_type; -// IF_FLOATING(T) { return xsimd::sin(x) / xsimd::cos(x); } -// else { -// T result; -// for (int i = 0; i < x.size(); ++i) { result[i] = tan(static_cast(x[i])); } -// return result; -// } - + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tan(const T &x) { return xsimd::tan(x); } template - auto asin(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::asin(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = asin(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto asin(const T &x) { + return xsimd::asin(x); } template - auto acos(const T 
&x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { - static const auto asin1 = xsimd::asin(T(1)); - return asin1 - xsimd::asin(x); - } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = acos(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto acos(const T &x) { + return xsimd::acos(x); } template - auto atan(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::atan(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = atan(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto atan(const T &x) { + return xsimd::atan(x); } template - auto sinh(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return (xsimd::exp(x) - xsimd::exp(-x)) * T(0.5); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = sinh(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sinh(const T &x) { + return xsimd::asinh(x); } template - auto cosh(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return (xsimd::exp(x) + xsimd::exp(-x)) * T(0.5); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = cosh(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cosh(const T &x) { + return xsimd::cosh(x); } template - auto tanh(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return (xsimd::exp(2 * x) - 1) / (xsimd::exp(2 * x) + 1); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = tanh(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tanh(const T &x) { + return xsimd::tanh(x); } template - auto exp(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::exp(x); } - else { - T result; - for (int i = 0; i < x.size(); 
++i) { result[i] = exp(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto exp(const T &x) { + return xsimd::exp(x); } template - auto log(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::log(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = log(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log(const T &x) { + return xsimd::log(x); } template - auto log2(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::log2(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = log2(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log2(const T &x) { + return xsimd::log2(x); } template - auto log10(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::log10(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = log10(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log10(const T &x) { + return xsimd::log10(x); } template - auto sqrt(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::sqrt(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = sqrt(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrt(const T &x) { + return xsimd::sqrt(x); } template - auto cbrt(const T &x) -> T { - // using Scalar = typename T::value_type; - // T result; - // for (int i = 0; i < x.size(); ++i) { result[i] = cbrt(static_cast(x[i])); } - // return result; - + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cbrt(const T &x) { return xsimd::cbrt(x); } template - auto abs(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::abs(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { 
result[i] = abs(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto abs(const T &x) { + return xsimd::abs(x); } template - auto floor(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::floor(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = floor(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto floor(const T &x) { + return xsimd::floor(x); } template - auto ceil(const T &x) -> T { - using Scalar = typename T::value_type; - IF_FLOATING(T) { return xsimd::ceil(x); } - else { - T result; - for (int i = 0; i < x.size(); ++i) { result[i] = ceil(static_cast(x[i])); } - return result; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ceil(const T &x) { + return xsimd::ceil(x); } } // namespace librapid diff --git a/librapid/vendor/xsimd b/librapid/vendor/xsimd new file mode 160000 index 00000000..0eb0bfdf --- /dev/null +++ b/librapid/vendor/xsimd @@ -0,0 +1 @@ +Subproject commit 0eb0bfdfa2082994b7141525c765bb80c09aecec From 034c77033951b0bf34fc77559f6daa23c6476c7b Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sat, 5 Aug 2023 23:36:19 -0700 Subject: [PATCH 06/29] xsimd updates --- librapid/include/librapid/math/vectorImpl.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/librapid/include/librapid/math/vectorImpl.hpp b/librapid/include/librapid/math/vectorImpl.hpp index e4b3e7c6..a53f2dea 100644 --- a/librapid/include/librapid/math/vectorImpl.hpp +++ b/librapid/include/librapid/math/vectorImpl.hpp @@ -287,7 +287,9 @@ namespace librapid { "Index {} out of bounds for Vector of length {}", index, length); - static_assert(false, "Not implemented"); + const int64_t packetIndex = index / packetWidth; + const int64_t elementIndex = index % packetWidth; + return data[packetIndex][elementIndex]; } LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum() const -> Scalar { From 
36a598f7abc7b878894700cfcfd067589da02ef4 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sat, 5 Aug 2023 23:46:34 -0700 Subject: [PATCH 07/29] Set LIBRAPID_NATIVE_ARCH to ON by default --- CMakeLists.txt | 2 +- docs/source/cmakeIntegration.md | 2 +- librapid/include/librapid/core/config.hpp | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 684fc3b8..c3b90784 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ option(LIBRAPID_USE_OPENCL "Search for OpenCL and use it if possible" ON) option(LIBRAPID_USE_CUDA "Attempt to use CUDA" ON) option(LIBRAPID_USE_MULTIPREC "Include MPIR and MPFR in the LibRapid build" OFF) option(LIBRAPID_FAST_MATH "Use potentially less accurate operations to increase performance" OFF) -option(LIBRAPID_NATIVE_ARCH "Use the native architecture of the system" OFF) +option(LIBRAPID_NATIVE_ARCH "Use the native architecture of the system" ON) option(LIBRAPID_CUDA_DOUBLE_VECTOR_WIDTH "Preferred vector width for vectorised kernels" 2) option(LIBRAPID_CUDA_FLOAT_VECTOR_WIDTH "Preferred vector width for vectorised kernels" 4) diff --git a/docs/source/cmakeIntegration.md b/docs/source/cmakeIntegration.md index df548f6a..b6ecff63 100644 --- a/docs/source/cmakeIntegration.md +++ b/docs/source/cmakeIntegration.md @@ -172,7 +172,7 @@ but may cause some functions to return slightly incorrect results due to lower p ### ``LIBRAPID_NATIVE_ARCH`` ``` -DEFAULT: OFF +DEFAULT: ON ``` Enabling this flag compiles librapid with the most advanced instruction set available on the system. 
This can lead to diff --git a/librapid/include/librapid/core/config.hpp b/librapid/include/librapid/core/config.hpp index 55bef30e..67d3580b 100644 --- a/librapid/include/librapid/core/config.hpp +++ b/librapid/include/librapid/core/config.hpp @@ -226,6 +226,7 @@ #else # define LIBRAPID_ARCH 0 # define LIBRAPID_ARCH_NAME "None" +# define LIBRAPID_DEFAULT_MEM_ALIGN 32 #endif // Instruction set detection // Check for 32bit vs 64bit From 7c8211a2b60fa2d1e91c0902bd64b6f5ab7bcffc Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 00:18:22 -0700 Subject: [PATCH 08/29] Bug fix in sinh --- librapid/include/librapid/simd/vecOps.hpp | 2 +- test/test-arrayOps.cpp | 42 +++++++++++------------ 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/librapid/include/librapid/simd/vecOps.hpp b/librapid/include/librapid/simd/vecOps.hpp index cc3d285f..a4f1f6ad 100644 --- a/librapid/include/librapid/simd/vecOps.hpp +++ b/librapid/include/librapid/simd/vecOps.hpp @@ -48,7 +48,7 @@ namespace librapid { template LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sinh(const T &x) { - return xsimd::asinh(x); + return xsimd::sinh(x); } template diff --git a/test/test-arrayOps.cpp b/test/test-arrayOps.cpp index 1518f575..a70958d4 100644 --- a/test/test-arrayOps.cpp +++ b/test/test-arrayOps.cpp @@ -9,10 +9,10 @@ using CPU = lrc::backend::CPU; using OPENCL = lrc::backend::OpenCL; using CUDA = lrc::backend::CUDA; -#define SCALAR float -#define BACKEND CPU +// #define SCALAR float +// #define BACKEND CPU -#define TEST_OP(NAME) \ +#define TEST_OP(NAME, SCALAR) \ auto NAME##X = lrc::NAME(x).eval(); \ for (int i = 0; i < NAME##X.shape().size(); ++i) { \ REQUIRE(lrc::isClose((SCALAR)NAME##X(i), (SCALAR)lrc::NAME((SCALAR)x(i)), tolerance)); \ @@ -24,25 +24,25 @@ using CUDA = lrc::backend::CUDA; /* Valid range for all functions */ \ auto x = lrc::linspace(0.1, 0.5, 100, false); \ \ - TEST_OP(sin); \ - TEST_OP(cos); \ - TEST_OP(tan); \ - TEST_OP(asin); \ - TEST_OP(acos); \ - 
TEST_OP(atan); \ - TEST_OP(sinh); \ - TEST_OP(cosh); \ - TEST_OP(tanh); \ + TEST_OP(sin, SCALAR); \ + TEST_OP(cos, SCALAR); \ + TEST_OP(tan, SCALAR); \ + TEST_OP(asin, SCALAR); \ + TEST_OP(acos, SCALAR); \ + TEST_OP(atan, SCALAR); \ + TEST_OP(sinh, SCALAR); \ + TEST_OP(cosh, SCALAR); \ + TEST_OP(tanh, SCALAR); \ \ - TEST_OP(exp); \ - TEST_OP(log); \ - TEST_OP(log2); \ - TEST_OP(log10); \ - TEST_OP(sqrt); \ - TEST_OP(cbrt); \ - TEST_OP(abs); \ - TEST_OP(floor); \ - TEST_OP(ceil); \ + TEST_OP(exp, SCALAR); \ + TEST_OP(log, SCALAR); \ + TEST_OP(log2, SCALAR); \ + TEST_OP(log10, SCALAR); \ + TEST_OP(sqrt, SCALAR); \ + TEST_OP(cbrt, SCALAR); \ + TEST_OP(abs, SCALAR); \ + TEST_OP(floor, SCALAR); \ + TEST_OP(ceil, SCALAR); \ } TRIG_TEST_IMPL(float, CPU) From eb6a65ed7bb686277752a4a72c2ef091b02e133e Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 10:12:00 -0700 Subject: [PATCH 09/29] Native Arch does not work on MacOS --- CMakeLists.txt | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c3b90784..82ab6473 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,13 @@ option(LIBRAPID_NO_WINDOWS_H "Don't include the Windows.h header" OFF) option(LIBRAPID_MKL_CONFIG_PATH "Path to the 'MKLConfig.cmake' file" "") + +# Native arch does not work on MacOS +if (IS_MACOS AND LIBRAPID_NATIVE_ARCH) + message(WARNING "[ LIBRAPID ] Native architecture optimisation is not currently supported on MacOS") + set(LIBRAPID_NATIVE_ARCH OFF) +endif () + MACRO(SUBDIRLIST result curdir) FILE(GLOB children RELATIVE ${curdir} ${curdir}/*) SET(dirlist "") @@ -486,12 +493,12 @@ if (LIBRAPID_NATIVE_ARCH) target_compile_options(${module_name} PUBLIC ${LIBRAPID_ARCH_FLAGS}) target_compile_definitions(${module_name} PUBLIC LIBRAPID_NATIVE_ARCH) -# OptimizeForArchitecture() -# target_compile_options(${module_name} PUBLIC ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) -# target_compile_definitions(${module_name} 
PUBLIC LIBRAPID_NATIVE_ARCH) -# set(LIBRAPID_ARCH_FLAGS ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) -# message(STATUS "[ LIBRAPID ] Additional Definitions: ${Vc_DEFINITIONS}") -# message(STATUS "[ LIBRAPID ] Supported flags: ${Vc_ARCHITECTURE_FLAGS}") + # OptimizeForArchitecture() + # target_compile_options(${module_name} PUBLIC ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) + # target_compile_definitions(${module_name} PUBLIC LIBRAPID_NATIVE_ARCH) + # set(LIBRAPID_ARCH_FLAGS ${Vc_DEFINITIONS} ${Vc_ARCHITECTURE_FLAGS}) + # message(STATUS "[ LIBRAPID ] Additional Definitions: ${Vc_DEFINITIONS}") + # message(STATUS "[ LIBRAPID ] Supported flags: ${Vc_ARCHITECTURE_FLAGS}") endif () # Add defines for CUDA vector widths From addf6427217499fc7cda548531e8044e043eb266 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 14:12:47 -0700 Subject: [PATCH 10/29] Update CMakeLists.txt --- CMakeLists.txt | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 82ab6473..d2ee2b66 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,13 +47,6 @@ option(LIBRAPID_NO_WINDOWS_H "Don't include the Windows.h header" OFF) option(LIBRAPID_MKL_CONFIG_PATH "Path to the 'MKLConfig.cmake' file" "") - -# Native arch does not work on MacOS -if (IS_MACOS AND LIBRAPID_NATIVE_ARCH) - message(WARNING "[ LIBRAPID ] Native architecture optimisation is not currently supported on MacOS") - set(LIBRAPID_NATIVE_ARCH OFF) -endif () - MACRO(SUBDIRLIST result curdir) FILE(GLOB children RELATIVE ${curdir} ${curdir}/*) SET(dirlist "") @@ -131,6 +124,12 @@ if (LIBRAPID_STRICT AND LIBRAPID_QUIET) message(FATAL_ERROR "LIBRAPID_STRICT and LIBRAPID_QUIET cannot be enabled at the same time") endif () +# Native arch does not work on MacOS +if (IS_MACOS AND LIBRAPID_NATIVE_ARCH) + message(WARNING "[ LIBRAPID ] Native architecture optimisation is not currently supported on MacOS") + set(LIBRAPID_NATIVE_ARCH OFF) +endif () + if (LIBRAPID_STRICT) # 
Enable all warnings and treat them as errors if (MSVC) From 9aaa467c83b79f32d7cfacd0d6ce92f71c82a536 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 14:36:50 -0700 Subject: [PATCH 11/29] MacOS memory alignment --- CMakeLists.txt | 6 ----- librapid/include/librapid/array/storage.hpp | 27 ++++++++++++--------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d2ee2b66..9fe59eaf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,12 +124,6 @@ if (LIBRAPID_STRICT AND LIBRAPID_QUIET) message(FATAL_ERROR "LIBRAPID_STRICT and LIBRAPID_QUIET cannot be enabled at the same time") endif () -# Native arch does not work on MacOS -if (IS_MACOS AND LIBRAPID_NATIVE_ARCH) - message(WARNING "[ LIBRAPID ] Native architecture optimisation is not currently supported on MacOS") - set(LIBRAPID_NATIVE_ARCH OFF) -endif () - if (LIBRAPID_STRICT) # Enable all warnings and treat them as errors if (MSVC) diff --git a/librapid/include/librapid/array/storage.hpp b/librapid/include/librapid/array/storage.hpp index 0af75de7..ae903093 100644 --- a/librapid/include/librapid/array/storage.hpp +++ b/librapid/include/librapid/array/storage.hpp @@ -341,12 +341,12 @@ namespace librapid { #if defined(LIBRAPID_BLAS_MKLBLAS) mkl_free(ptr); -#else -# if defined(LIBRAPID_NATIVE_ARCH) && defined(LIBRAPID_MSVC) +#elif defined(LIBRAPID_APPLE) + free(ptr); +#elif defined(LIBRAPID_NATIVE_ARCH) && defined(LIBRAPID_MSVC) _aligned_free(ptr); -# else +#else free(ptr); -# endif #endif } @@ -367,20 +367,23 @@ namespace librapid { #if defined(LIBRAPID_BLAS_MKLBLAS) // MKL has its own memory allocation function auto ptr = static_cast(mkl_malloc(size * sizeof(T), 64)); -#else - // Force aligned memory -# if defined(LIBRAPID_APPLE) - // No memory allignment. 
It breaks everything for some reason - auto ptr = static_cast(std::malloc(size * sizeof(T))); -# elif defined(LIBRAPID_MSVC) || defined(LIBRAPID_MINGW) +#elif defined(LIBRAPID_APPLE) + // Use posix_memalign + void *_ptr; + auto err = posix_memalign(&_ptr, global::memoryAlignment, size * sizeof(T)); + LIBRAPID_ASSERT(err == 0, "posix_memalign failed with error code {}", err); + auto ptr = static_cast(_ptr); +#elif defined(LIBRAPID_MSVC) || defined(LIBRAPID_MINGW) auto ptr = static_cast(_aligned_malloc(size * sizeof(T), global::memoryAlignment)); -# else +#else auto ptr = static_cast( std::aligned_alloc(global::memoryAlignment, size * sizeof(T))); -# endif #endif + LIBRAPID_ASSERT( + ptr != nullptr, "Failed to allocate {} bytes of memory", size * sizeof(T)); + // If the type cannot be trivially constructed, we need to // initialize each value if constexpr (!typetraits::TriviallyDefaultConstructible::value && From e5435ccaedfb5815c5f5c2d43edce6accfb882ca Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 15:07:49 -0700 Subject: [PATCH 12/29] Run an example that's erroring for more information --- .github/workflows/continuous-integration.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/continuous-integration.yaml b/.github/workflows/continuous-integration.yaml index b60fce27..59e9a369 100644 --- a/.github/workflows/continuous-integration.yaml +++ b/.github/workflows/continuous-integration.yaml @@ -298,6 +298,11 @@ jobs: CC: ${{ matrix.cc }} CXX: ${{ matrix.cxx }} + - name: Run example-array-1 (Debug) + run: | + cd buildDebug + ./examples/example-array-1 + - name: Run Tests (Debug) run: | cd buildDebug From b453f6bae7c8724b99921f4b08e6cec8a87ed0ed Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 22:02:30 -0700 Subject: [PATCH 13/29] Attempt at MacOS segfault fix --- CMakeLists.txt | 6 ++++++ librapid/include/librapid/array/arrayContainer.hpp | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git 
a/CMakeLists.txt b/CMakeLists.txt index 9fe59eaf..3f9c852b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,6 +124,12 @@ if (LIBRAPID_STRICT AND LIBRAPID_QUIET) message(FATAL_ERROR "LIBRAPID_STRICT and LIBRAPID_QUIET cannot be enabled at the same time") endif () +# SIMD instructions do not currently work on MacOS +if (IS_MACOS AND LIBRAPID_NATIVE_ARCH) + message(WARNING "SIMD instructions are not currently supported on MacOS. Disabling LIBRAPID_NATIVE_ARCH") + set(LIBRAPID_NATIVE_ARCH OFF) +endif () + if (LIBRAPID_STRICT) # Enable all warnings and treat them as errors if (MSVC) diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index 358025c2..a03d14b8 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -668,7 +668,11 @@ namespace librapid { template auto ArrayContainer::packet(size_t index) const -> Packet { +#if defined(LIBRAPID_NATIVE_ARCH) return xsimd::load_aligned(m_storage.begin() + index); +#else + return xsimd::load_unaligned(m_storage.begin() + index); +#endif } template @@ -679,7 +683,11 @@ namespace librapid { template void ArrayContainer::writePacket(size_t index, const Packet &value) { +#if defined(LIBRAPID_NATIVE_ARCH) value.store_aligned(m_storage.begin() + index); +#else + value.store_unaligned(m_storage.begin() + index); +#endif } template From 78a911265931e8b645b9b0c20995b901374858de Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 22:05:31 -0700 Subject: [PATCH 14/29] Update example for more debug info --- CMakeLists.txt | 8 ++++---- examples/example-array-1.cpp | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f9c852b..1c57e6aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -125,10 +125,10 @@ if (LIBRAPID_STRICT AND LIBRAPID_QUIET) endif () # SIMD instructions do not currently work on MacOS -if (IS_MACOS 
AND LIBRAPID_NATIVE_ARCH) - message(WARNING "SIMD instructions are not currently supported on MacOS. Disabling LIBRAPID_NATIVE_ARCH") - set(LIBRAPID_NATIVE_ARCH OFF) -endif () +#if (IS_MACOS AND LIBRAPID_NATIVE_ARCH) +# message(WARNING "SIMD instructions are not currently supported on MacOS. Disabling LIBRAPID_NATIVE_ARCH") +# set(LIBRAPID_NATIVE_ARCH OFF) +#endif () if (LIBRAPID_STRICT) # Enable all warnings and treat them as errors diff --git a/examples/example-array-1.cpp b/examples/example-array-1.cpp index 4209c2dd..4dfa22aa 100644 --- a/examples/example-array-1.cpp +++ b/examples/example-array-1.cpp @@ -2,42 +2,63 @@ namespace lrc = librapid; +template +void printHelper(Args... args) { + fmt::print(args...); + std::cout << std::flush; // Flush the output buffer +} + auto main() -> int { fmt::print("LibRapid Example -- Array 1\n"); // Create a vector with 10 elements + printHelper("Creating Vector"); lrc::Array myVector(lrc::Shape({5})); // Fill the vector with values + printHelper("Filling Vector"); for (int i = 0; i < 5; i++) { myVector[i] = i; } // Print the vector + printHelper("Printing Vector"); fmt::print("Vector: {}\n", myVector); // Create a matrix with 3x5 elements + printHelper("Creating Matrix"); lrc::Array myMatrix(lrc::Shape({3, 5})); + printHelper("Filling Matrix"); for (int i = 0; i < 3; i++) { for (int j = 0; j < 5; j++) { myMatrix[i][j] = i * 5 + j; } } + printHelper("Printing Matrix"); fmt::print("Matrix:\n{}\n", myMatrix); // Do some simple calculations + printHelper("Adding Vectors"); fmt::print("My Vector + My Vector: {}\n", myVector + myVector); + + printHelper("Adding Elements"); fmt::print("[0] + [4]: {}\n", myVector[0].get() + myVector[4].get()); fmt::print("\n"); + printHelper("Combined Operations"); fmt::print("M + M * M:\n{}\n", myMatrix + myMatrix * myMatrix); // Add a vector to a row of a matrix + printHelper("Vector + Matrix[2]"); fmt::print("My Vector + My matrix [2]: {}\n", myVector + myMatrix[2]); // Compare two arrays + 
printHelper("Creating Vectors"); lrc::Array leftVector(lrc::Shape({5})); lrc::Array rightVector(lrc::Shape({5})); + printHelper("Comma Initializing Vectors"); leftVector << 1, 2, 3, 4, 5; rightVector << 5, 4, 3, 2, 1; + printHelper("Less Than"); fmt::print("{} < {} --> {}\n", leftVector, rightVector, leftVector < rightVector); + printHelper("Greater or Equal"); fmt::print("{} >= {} --> {}\n", leftVector, rightVector, leftVector >= rightVector); return 0; From 5cbe302465d7d720aa8540377c99dbd01fb617e2 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 22:56:44 -0700 Subject: [PATCH 15/29] Did I find the magical error? --- .../include/librapid/array/arrayContainer.hpp | 7 ++++- librapid/include/librapid/array/storage.hpp | 28 ++++++++++--------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index a03d14b8..e0255fcc 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -668,7 +668,12 @@ namespace librapid { template auto ArrayContainer::packet(size_t index) const -> Packet { -#if defined(LIBRAPID_NATIVE_ARCH) +#if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_OSX) + // On MacOS (and other platforms??) we cannot use aligned loads in arrays due to one + // annoying edge case. Normally, all SIMD loads will be aligned to a 64-byte boundary. + // Say, however, this array is a sub-array of a larger array. If the outer dimension + // of the larger array does not result in a 64-byte alignment, the data of *this* array + // will not be correctly aligned, hence causing a segfault. 
return xsimd::load_aligned(m_storage.begin() + index); #else return xsimd::load_unaligned(m_storage.begin() + index); diff --git a/librapid/include/librapid/array/storage.hpp b/librapid/include/librapid/array/storage.hpp index ae903093..47253b23 100644 --- a/librapid/include/librapid/array/storage.hpp +++ b/librapid/include/librapid/array/storage.hpp @@ -28,19 +28,21 @@ namespace librapid { template class Storage { public: - using Scalar = Scalar_; - using RawPointer = Scalar *; - using ConstRawPointer = const Scalar *; - using Pointer = std::shared_ptr; - using ConstPointer = std::shared_ptr; - using Reference = Scalar &; - using ConstReference = const Scalar &; - using SizeType = size_t; - using DifferenceType = ptrdiff_t; - using Iterator = RawPointer; - using ConstIterator = ConstRawPointer; - using ReverseIterator = std::reverse_iterator; - using ConstReverseIterator = std::reverse_iterator; + using Scalar = Scalar_; + using Packet = typename typetraits::TypeInfo::Packet; + static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; + using RawPointer = Scalar *; + using ConstRawPointer = const Scalar *; + using Pointer = std::shared_ptr; + using ConstPointer = std::shared_ptr; + using Reference = Scalar &; + using ConstReference = const Scalar &; + using SizeType = size_t; + using DifferenceType = ptrdiff_t; + using Iterator = RawPointer; + using ConstIterator = ConstRawPointer; + using ReverseIterator = std::reverse_iterator; + using ConstReverseIterator = std::reverse_iterator; /// Default constructor Storage() = default; From 8abb36bc9b041171847c52a6ef8129cf717f28d0 Mon Sep 17 00:00:00 2001 From: Toby Davis Date: Sun, 6 Aug 2023 23:17:01 -0700 Subject: [PATCH 16/29] Fix MacOS segfault --- .github/workflows/continuous-integration.yaml | 5 ----- examples/example-array-1.cpp | 21 ------------------- 2 files changed, 26 deletions(-) diff --git a/.github/workflows/continuous-integration.yaml b/.github/workflows/continuous-integration.yaml index 
59e9a369..b60fce27 100644 --- a/.github/workflows/continuous-integration.yaml +++ b/.github/workflows/continuous-integration.yaml @@ -298,11 +298,6 @@ jobs: CC: ${{ matrix.cc }} CXX: ${{ matrix.cxx }} - - name: Run example-array-1 (Debug) - run: | - cd buildDebug - ./examples/example-array-1 - - name: Run Tests (Debug) run: | cd buildDebug diff --git a/examples/example-array-1.cpp b/examples/example-array-1.cpp index 4dfa22aa..4209c2dd 100644 --- a/examples/example-array-1.cpp +++ b/examples/example-array-1.cpp @@ -2,63 +2,42 @@ namespace lrc = librapid; -template -void printHelper(Args... args) { - fmt::print(args...); - std::cout << std::flush; // Flush the output buffer -} - auto main() -> int { fmt::print("LibRapid Example -- Array 1\n"); // Create a vector with 10 elements - printHelper("Creating Vector"); lrc::Array myVector(lrc::Shape({5})); // Fill the vector with values - printHelper("Filling Vector"); for (int i = 0; i < 5; i++) { myVector[i] = i; } // Print the vector - printHelper("Printing Vector"); fmt::print("Vector: {}\n", myVector); // Create a matrix with 3x5 elements - printHelper("Creating Matrix"); lrc::Array myMatrix(lrc::Shape({3, 5})); - printHelper("Filling Matrix"); for (int i = 0; i < 3; i++) { for (int j = 0; j < 5; j++) { myMatrix[i][j] = i * 5 + j; } } - printHelper("Printing Matrix"); fmt::print("Matrix:\n{}\n", myMatrix); // Do some simple calculations - printHelper("Adding Vectors"); fmt::print("My Vector + My Vector: {}\n", myVector + myVector); - - printHelper("Adding Elements"); fmt::print("[0] + [4]: {}\n", myVector[0].get() + myVector[4].get()); fmt::print("\n"); - printHelper("Combined Operations"); fmt::print("M + M * M:\n{}\n", myMatrix + myMatrix * myMatrix); // Add a vector to a row of a matrix - printHelper("Vector + Matrix[2]"); fmt::print("My Vector + My matrix [2]: {}\n", myVector + myMatrix[2]); // Compare two arrays - printHelper("Creating Vectors"); lrc::Array leftVector(lrc::Shape({5})); lrc::Array 
rightVector(lrc::Shape({5})); - printHelper("Comma Initializing Vectors"); leftVector << 1, 2, 3, 4, 5; rightVector << 5, 4, 3, 2, 1; - printHelper("Less Than"); fmt::print("{} < {} --> {}\n", leftVector, rightVector, leftVector < rightVector); - printHelper("Greater or Equal"); fmt::print("{} >= {} --> {}\n", leftVector, rightVector, leftVector >= rightVector); return 0; From 0a8bbe32b6117ca8fe2764978c1f0ca35d623114 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Thu, 10 Aug 2023 00:41:52 -0700 Subject: [PATCH 17/29] Slowly porting code to C++20 and beyond --- CMakeLists.txt | 6 +++- .../include/librapid/array/arrayContainer.hpp | 36 +++++++++++++++++-- .../include/librapid/core/helperMacros.hpp | 15 ++++---- .../include/librapid/core/librapidPch.hpp | 1 + librapid/include/librapid/math/coreMath.hpp | 16 +++++++++ librapid/include/librapid/math/half.hpp | 4 ++- librapid/include/librapid/utils/time.hpp | 3 +- 7 files changed, 68 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c57e6aa..1907147a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,7 +2,11 @@ cmake_minimum_required(VERSION 3.16) project(librapid) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") cmake_policy(SET CMP0077 NEW) -set(CMAKE_CXX_STANDARD 17) + +# LibRapid requires C++20 or later +if (CMAKE_CXX_STANDARD LESS 20) + message(FATAL_ERROR "LibRapid requires C++20 or later") +endif () # Extract version information file(READ "version.txt" ver) diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index e0255fcc..f6398cd0 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -843,8 +843,40 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -LIBRAPID_SIMPLE_IO_IMPL(typename ShapeType_ COMMA typename StorageType_, - librapid::array::ArrayContainer) +// LIBRAPID_SIMPLE_IO_IMPL(typename ShapeType_ COMMA 
typename StorageType_, +// librapid::array::ArrayContainer) + +template +struct fmt::formatter> { + char formatStr[32] = {'{', ':'}; + constexpr auto parse(format_parse_context &ctx) -> format_parse_context::iterator { + auto it = ctx.begin(); + uint64_t index = 0; + for (; it != ctx.end(); ++it) { + if (*it == '}') break; + formatStr[index++] += *it; + } + formatStr[index] = '}'; + return it; + } + + template + auto format(const librapid::array::ArrayContainer &object, + FormatContext &ctx) { + try { + // return fmt::format_to(ctx.out(), object.str(formatStr)); + return fmt::format_to(ctx.out(), "Hello, World"); + } catch (std::exception &e) { return fmt::format_to(ctx.out(), e.what()); } + } +}; + +template +std::ostream &operator<<(std::ostream &os, + const librapid::array::ArrayContainer &object) { + os << object.str(); + return os; +} + LIBRAPID_SIMPLE_IO_NORANGE(typename ShapeType_ COMMA typename StorageType_, librapid::array::ArrayContainer) #endif // FMT_API diff --git a/librapid/include/librapid/core/helperMacros.hpp b/librapid/include/librapid/core/helperMacros.hpp index e61236b0..8065a64c 100644 --- a/librapid/include/librapid/core/helperMacros.hpp +++ b/librapid/include/librapid/core/helperMacros.hpp @@ -10,17 +10,16 @@ #define LIBRAPID_SIMPLE_IO_IMPL(TEMPLATE_, TYPE_) \ template \ struct fmt::formatter { \ - std::string formatStr = "{}"; \ + char formatStr[32] = {'{', ':'}; \ \ - template \ - constexpr auto parse(ParseContext &ctx) { \ - formatStr = "{:"; \ - auto it = ctx.begin(); \ + constexpr auto parse(format_parse_context &ctx) -> format_parse_context::iterator { \ + auto it = ctx.begin(); \ + uint64_t index = 0; \ for (; it != ctx.end(); ++it) { \ if (*it == '}') break; \ - formatStr += *it; \ + formatStr[index++] += *it; \ } \ - formatStr += "}"; \ + formatStr[index] = '}'; \ return it; \ } \ \ @@ -69,7 +68,7 @@ } #define LIBRAPID_SIMPLE_IO_NORANGE(TEMPLATE, TYPE) \ - template \ + template \ struct fmt::is_range : std::false_type {}; namespace 
librapid::typetraits { diff --git a/librapid/include/librapid/core/librapidPch.hpp b/librapid/include/librapid/core/librapidPch.hpp index d8457401..79f76913 100644 --- a/librapid/include/librapid/core/librapidPch.hpp +++ b/librapid/include/librapid/core/librapidPch.hpp @@ -49,6 +49,7 @@ #include #include #include +#include #include #include #include diff --git a/librapid/include/librapid/math/coreMath.hpp b/librapid/include/librapid/math/coreMath.hpp index ac5d1c1a..7ace3ee2 100644 --- a/librapid/include/librapid/math/coreMath.hpp +++ b/librapid/include/librapid/math/coreMath.hpp @@ -209,6 +209,22 @@ namespace librapid { } } + // Return 10 raised to a given power. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the exponential. + /// \tparam T Data type + /// \param val Input value + /// \return 10 raised to the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto exp10(T val) { + // C++ standard does not implement exp10 + + if constexpr (std::is_integral_v) { + return std::pow(10.0, static_cast(val)); + } else { + return std::pow(10.0, val); + } + } + /// Return the natural logarithm of a given value. Note that, for integer values, this function /// will cast the input value to a floating point type before calculating the logarithm. 
/// \tparam T Data type diff --git a/librapid/include/librapid/math/half.hpp b/librapid/include/librapid/math/half.hpp index 97ef4aee..da98f3fd 100644 --- a/librapid/include/librapid/math/half.hpp +++ b/librapid/include/librapid/math/half.hpp @@ -635,7 +635,9 @@ namespace librapid { } std::string half::str(const std::string &format) const { - return fmt::format(format, static_cast(*this)); + // return fmt::vformat(format, fmt::make_wformat_args(detail::halfToFloat(m_value.m_bits))); + + return std::vformat(format, std::make_format_args(detail::halfToFloat(m_value.m_bits))); } LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator+(const half &lhs, diff --git a/librapid/include/librapid/utils/time.hpp b/librapid/include/librapid/utils/time.hpp index 8966e55a..b5de8124 100644 --- a/librapid/include/librapid/utils/time.hpp +++ b/librapid/include/librapid/utils/time.hpp @@ -67,7 +67,8 @@ namespace librapid { static double divisor[] = {1000, 1000, 1000, 60, 60, 24, 365, 1e300}; for (int i = 0; i < numUnits; ++i) { - if (ns < divisor[i]) return std::operator+(fmt::format(format, ns), prefix[i]); + // if (ns < divisor[i]) return std::operator+(fmt::format(format, ns), prefix[i]); + if (ns < divisor[i]) return fmt::vformat(format, fmt::make_format_args(ns)) + prefix[i]; ns /= divisor[i]; } return fmt::format("{}ns", time * ns); From 60466186ce158beba1be6e0a8dcd3294a13e89d9 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Fri, 11 Aug 2023 01:14:49 -0700 Subject: [PATCH 18/29] Fix array formatting. 
Still need to do the rest --- .../include/librapid/array/arrayContainer.hpp | 78 +++++++--- librapid/include/librapid/array/arrayView.hpp | 25 ++-- .../librapid/array/arrayViewString.hpp | 139 ++++++++---------- librapid/include/librapid/array/fill.hpp | 25 ++-- .../librapid/array/pseudoConstructors.hpp | 5 +- 5 files changed, 148 insertions(+), 124 deletions(-) diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index f6398cd0..f656181f 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -339,6 +339,10 @@ namespace librapid { /// \return A string representation of the array container LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + template + void fmtStr(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const; + private: ShapeType m_shape; // The shape type of the array StorageType m_storage; // The storage container of the array @@ -804,6 +808,14 @@ namespace librapid { std::string ArrayContainer::str(const std::string &format) const { return ArrayView(*this).str(format); } + + template + template + void ArrayContainer::fmtStr(const fmt::formatter &format, + char bracket, char separator, + Ctx &ctx) const { + ArrayView(*this).fmtStr(format, bracket, separator, ctx); + } } // namespace array namespace detail { @@ -843,37 +855,63 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -// LIBRAPID_SIMPLE_IO_IMPL(typename ShapeType_ COMMA typename StorageType_, -// librapid::array::ArrayContainer) - template struct fmt::formatter> { - char formatStr[32] = {'{', ':'}; - constexpr auto parse(format_parse_context &ctx) -> format_parse_context::iterator { - auto it = ctx.begin(); - uint64_t index = 0; - for (; it != ctx.end(); ++it) { - if (*it == '}') break; - formatStr[index++] += *it; - } - formatStr[index] = '}'; - return it; + using Type = librapid::array::ArrayContainer; + 
using Scalar = typename librapid::typetraits::TypeInfo::Scalar; + using Formatter = fmt::formatter; + Formatter m_formatter; + char m_bracket; + char m_separator; + + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + // Custom format options: + // - 'r' for round brackets + // - 's' for square brackets + // - 'c' for curly brackets + // - 'a' for angle brackets + // - 'p' for pipe brackets + // - "-," for comma separator + // - "-;" for semicolon separator + // - "-:" for colon separator + // - "-|" for pipe separator + // - "-_" for underscore separator + + auto it = ctx.begin(), end = ctx.end(); + if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) { + m_bracket = *it++; + } else { + m_bracket = 's'; + } + + if (it != end && *it == '-') { + ++it; + if (it != end) { + m_separator = *it++; + } else { + m_separator = ','; + } + } else { + m_separator = ' '; + } + + ctx.advance_to(it); + + return m_formatter.parse(ctx); } template - auto format(const librapid::array::ArrayContainer &object, - FormatContext &ctx) { - try { - // return fmt::format_to(ctx.out(), object.str(formatStr)); - return fmt::format_to(ctx.out(), "Hello, World"); - } catch (std::exception &e) { return fmt::format_to(ctx.out(), e.what()); } + FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { + val.fmtStr(m_formatter, m_bracket, m_separator, ctx); + return ctx.out(); } }; template std::ostream &operator<<(std::ostream &os, const librapid::array::ArrayContainer &object) { - os << object.str(); + os << "NOT IMPLEMENTED!"; // object.str(); return os; } diff --git a/librapid/include/librapid/array/arrayView.hpp b/librapid/include/librapid/array/arrayView.hpp index ef8dde1c..88ae7783 100644 --- a/librapid/include/librapid/array/arrayView.hpp +++ b/librapid/include/librapid/array/arrayView.hpp @@ -15,11 +15,11 @@ namespace librapid { } // namespace typetraits namespace array { - template + template 
class ArrayView { public: // using ArrayType = T; - using BaseType = typename std::decay_t; + using BaseType = typename std::decay_t; using Scalar = typename typetraits::TypeInfo::Scalar; using Reference = BaseType &; using ConstReference = const BaseType &; @@ -34,11 +34,11 @@ namespace librapid { /// Copy an ArrayView object /// \param array The array to copy - explicit ArrayView(T &array); + explicit ArrayView(ArrayViewType &array); /// Copy an ArrayView object (not const) /// \param array The array to copy - explicit ArrayView(T &&array) = delete; + explicit ArrayView(ArrayViewType &&array) = delete; /// Copy an ArrayView object (const) /// \param other The array to copy @@ -71,9 +71,9 @@ namespace librapid { /// Access a sub-array of this ArrayView. /// \param index The index of the sub-array. /// \return An ArrayView from this - const ArrayView operator[](int64_t index) const; + const ArrayView operator[](int64_t index) const; - ArrayView operator[](int64_t index); + ArrayView operator[](int64_t index); /// Since even scalars are represented as an ArrayView object, it can be difficult to /// operate on them directly. 
This allows you to extract the scalar value stored by a @@ -139,15 +139,19 @@ namespace librapid { /// \return A std::string representation of this ArrayView LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + template + void fmtStr(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const; + private: - T &m_ref; + ArrayViewType &m_ref; ShapeType m_shape; StrideType m_stride; int64_t m_offset = 0; }; - template - ArrayView::ArrayView(T &array) : + template + ArrayView::ArrayView(ArrayViewType &array) : m_ref(array), m_shape(array.shape()), m_stride(array.shape()) {} template @@ -324,7 +328,8 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::array::ArrayView) +// LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::array::ArrayView) + LIBRAPID_SIMPLE_IO_NORANGE(typename T, librapid::array::ArrayView) #endif // FMT_API diff --git a/librapid/include/librapid/array/arrayViewString.hpp b/librapid/include/librapid/array/arrayViewString.hpp index e560e46d..3a74a3d3 100644 --- a/librapid/include/librapid/array/arrayViewString.hpp +++ b/librapid/include/librapid/array/arrayViewString.hpp @@ -3,100 +3,77 @@ namespace librapid { namespace detail { - /// Count the width of a value for use in a String representation of an Array. The returned - /// tuple contains the length of the value before and after the central point. For floating - /// point values, the central point is the decimal point. 
For integer values, the central - /// point is the end of the value - /// \tparam T The type of the value to count the width of - /// \param val The value to count the width of - /// \param format The format string to use when converting the value to a String - /// \return The relevant widths of the value - template, int> = 0> - LIBRAPID_INLINE std::pair countWidth(const T &val, - const std::string &format) { - std::string str = fmt::format(format, val); - auto point = str.find('.'); - if (point == std::string::npos) { return {str.size(), 0}; } - return {point, str.size() - point}; - } + template + void arrayViewToString(const array::ArrayView &view, + const fmt::formatter &formatter, char bracket, + char separator, int64_t indent, Ctx &ctx) { + char bracketCharOpen, bracketCharClose; - template, int> = 0> - LIBRAPID_INLINE std::pair countWidth(const T &val, - const std::string &format) { - std::string str = fmt::format(format, val); - return {str.size(), 0}; - } + switch (bracket) { + case 'r': + bracketCharOpen = '('; + bracketCharClose = ')'; + break; + case 's': + bracketCharOpen = '['; + bracketCharClose = ']'; + break; + case 'c': + bracketCharOpen = '{'; + bracketCharClose = '}'; + break; + case 'a': + bracketCharOpen = '<'; + bracketCharClose = '>'; + break; + case 'p': + bracketCharOpen = '|'; + bracketCharClose = '|'; + break; + default: + bracketCharOpen = '['; + bracketCharClose = ']'; + break; + } + + // Separator char is already the correct character - template - std::vector> countColumnWidths(const array::ArrayView &view, - const std::string &format) { if (view.ndim() == 0) { - // Scalar - return {countWidth(view.scalar(0), format)}; + formatter.format(view.scalar(0), ctx); } else if (view.ndim() == 1) { - // Vector - std::vector> widths(view.shape()[0]); - for (int64_t i = 0; i < static_cast(view.shape()[0]); ++i) { - widths[i] = countWidth(view.scalar(i), format); - } - return widths; - } else { - // General - std::vector> widths = - 
countColumnWidths(view[0], format); - for (int64_t i = 1; i < static_cast(view.shape()[0]); ++i) { - auto subWidths = countColumnWidths(view[i], format); - for (int64_t j = 0; j < static_cast(widths.size()); ++j) { - widths[j].first = ::librapid::max(widths[j].first, subWidths[j].first); - widths[j].second = ::librapid::max(widths[j].second, subWidths[j].second); + fmt::format_to(ctx.out(), "{}", bracketCharOpen); + for (int64_t i = 0; i < static_cast(view.shape()[0]); i++) { + formatter.format(view.scalar(i), ctx); + if (i != view.shape()[0] - 1) { + if (separator == ' ') { + fmt::format_to(ctx.out(), " "); + } else { + fmt::format_to(ctx.out(), "{} ", separator); + } } } - return widths; - } - } - - template - std::string arrayViewToString(const array::ArrayView &view, const std::string &format, - const std::vector> &widths, - int64_t indent) { - if (view.ndim() == 0) { return fmt::format(format, view.scalar(0)); } - - if (view.ndim() == 1) { - std::string str = "["; + fmt::format_to(ctx.out(), "{}", bracketCharClose); + } else { + fmt::format_to(ctx.out(), "{}", bracketCharOpen); for (int64_t i = 0; i < static_cast(view.shape()[0]); i++) { - std::pair width = detail::countWidth(view.scalar(i), format); - str += fmt::format("{:>{}}{}{:>{}}", - "", - widths[i].first - width.first, - fmt::format(format, view.scalar(i)), - "", - widths[i].second - width.second); - if (i != view.shape()[0] - 1) { str += " "; } - } - str += "]"; - return str; - } - - std::string str = "["; - for (int64_t i = 0; i < static_cast(view.shape()[0]); i++) { - if (i > 0) str += std::string(indent + 1, ' '); - str += arrayViewToString(view[i], format, widths, indent + 1); - if (i != view.shape()[0] - 1) { - str += "\n"; - if (view.ndim() > 2) { str += "\n"; } + if (i > 0) fmt::format_to(ctx.out(), "{}", std::string(indent + 1, ' ')); + arrayViewToString(view[i], formatter, bracket, separator, indent + 1, ctx); + if (i != view.shape()[0] - 1) { + fmt::format_to(ctx.out(), "{}\n", separator); + 
if (view.ndim() > 2) { fmt::format_to(ctx.out(), "\n"); } + } } + fmt::format_to(ctx.out(), "{}", bracketCharClose); } - str += "]"; - return str; } } // namespace detail namespace array { - template - auto ArrayView::str(const std::string &format) const -> std::string { - std::vector> widths = - detail::countColumnWidths(*this, format); - return detail::arrayViewToString(*this, format, widths, 0); + template + template + void ArrayView::fmtStr(const fmt::formatter &format, char bracket, + char separator, Ctx &ctx) const { + detail::arrayViewToString(*this, format, bracket, separator, 0, ctx); } } // namespace array } // namespace librapid diff --git a/librapid/include/librapid/array/fill.hpp b/librapid/include/librapid/array/fill.hpp index 062dc588..a33f1285 100644 --- a/librapid/include/librapid/array/fill.hpp +++ b/librapid/include/librapid/array/fill.hpp @@ -19,11 +19,13 @@ namespace librapid { if (parallel) { #pragma omp parallel for for (int64_t i = 0; i < shape.size(); ++i) { - data[i] = random(lower, upper); + data[i] = random(static_cast(lower), + static_cast(upper)); } } else { for (int64_t i = 0; i < shape.size(); ++i) { - data[i] = random(lower, upper); + data[i] = random(static_cast(lower), + static_cast(upper)); } } } @@ -102,15 +104,16 @@ namespace librapid { // reseed is controlled by the random module, so we don't need to worry about it here } - cuda::runKernel("fill", - std::is_same_v ? "fillRandomHalf" : "fillRandom", - elements, - dst.storage().data().get(), - elements, - static_cast(lower), - static_cast(upper), - seeds.storage().data().get(), - numSeeds); + cuda::runKernel( + "fill", + std::is_same_v ? 
"fillRandomHalf" : "fillRandom", + elements, + dst.storage().data().get(), + elements, + static_cast(lower), + static_cast(upper), + seeds.storage().data().get(), + numSeeds); } template diff --git a/librapid/include/librapid/array/pseudoConstructors.hpp b/librapid/include/librapid/array/pseudoConstructors.hpp index 5c947f87..c22d661e 100644 --- a/librapid/include/librapid/array/pseudoConstructors.hpp +++ b/librapid/include/librapid/array/pseudoConstructors.hpp @@ -216,8 +216,9 @@ namespace librapid { return result; } - template - Array random(const ShapeType &shape, Lower lower, Upper upper) { + template + Array random(const ShapeType &shape, Lower lower = 0, Upper upper = 1) { Array result(shape); fillRandom(result, lower, upper); return result; From 0cb3f69a25c9de03b3922bffa88e3b93fc664e3d Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Fri, 11 Aug 2023 18:43:46 -0700 Subject: [PATCH 19/29] Update clang-format and use spaces, not tabs --- librapid/include/librapid/array/array.hpp | 4 +- .../include/librapid/array/arrayContainer.hpp | 1686 ++++--- .../include/librapid/array/arrayFromData.hpp | 168 +- .../include/librapid/array/arrayIterator.hpp | 190 +- .../include/librapid/array/arrayTypeDef.hpp | 191 +- librapid/include/librapid/array/arrayView.hpp | 683 +-- .../librapid/array/arrayViewString.hpp | 140 +- librapid/include/librapid/array/assignOps.hpp | 1008 +++-- .../librapid/array/commaInitializer.hpp | 82 +- librapid/include/librapid/array/fill.hpp | 310 +- .../librapid/array/fourierTransform.hpp | 207 +- librapid/include/librapid/array/function.hpp | 537 +-- .../librapid/array/linalg/arrayMultiply.hpp | 1374 +++--- .../include/librapid/array/linalg/compat.hpp | 52 +- .../librapid/array/linalg/level2/gemv.cl | 56 +- .../librapid/array/linalg/level2/gemv.cu | 24 +- .../librapid/array/linalg/level2/gemv.hpp | 434 +- .../librapid/array/linalg/level3/geam.hpp | 694 ++- .../librapid/array/linalg/level3/gemm.cl | 98 +- .../librapid/array/linalg/level3/gemm.cu | 
50 +- .../librapid/array/linalg/level3/gemm.hpp | 572 +-- .../include/librapid/array/linalg/linalg.hpp | 26 +- .../librapid/array/linalg/transpose.hpp | 1386 +++--- .../include/librapid/array/operations.hpp | 2110 ++++----- .../librapid/array/pseudoConstructors.hpp | 442 +- librapid/include/librapid/array/sizetype.hpp | 678 +-- librapid/include/librapid/array/storage.hpp | 1706 +++---- .../include/librapid/array/strideTools.hpp | 90 +- librapid/include/librapid/autodiff/dual.hpp | 880 ++-- librapid/include/librapid/core/config.hpp | 410 +- librapid/include/librapid/core/core.hpp | 20 +- librapid/include/librapid/core/cudaConfig.hpp | 132 +- librapid/include/librapid/core/debugTrap.hpp | 114 +- librapid/include/librapid/core/forward.hpp | 184 +- .../include/librapid/core/genericConfig.hpp | 304 +- librapid/include/librapid/core/global.hpp | 98 +- librapid/include/librapid/core/gnuConfig.hpp | 304 +- .../include/librapid/core/helperMacros.hpp | 112 +- .../include/librapid/core/librapidPch.hpp | 24 +- librapid/include/librapid/core/literals.hpp | 8 +- librapid/include/librapid/core/msvcConfig.hpp | 302 +- .../include/librapid/core/openclConfig.hpp | 22 +- librapid/include/librapid/core/preMain.hpp | 35 +- librapid/include/librapid/core/traits.hpp | 1382 +++--- librapid/include/librapid/core/typetraits.hpp | 90 +- .../include/librapid/core/warningSuppress.hpp | 10 +- .../librapid/cuda/cudaKernelProcesor.hpp | 176 +- .../include/librapid/cuda/cudaStorage.hpp | 1100 ++--- librapid/include/librapid/cuda/exception.h | 60 +- librapid/include/librapid/cuda/helper_cuda.h | 704 +-- .../librapid/cuda/helper_cuda_drvapi.h | 552 +-- .../include/librapid/cuda/helper_cusolver.h | 94 +- .../include/librapid/cuda/helper_functions.h | 4 +- librapid/include/librapid/cuda/helper_image.h | 1390 +++--- librapid/include/librapid/cuda/helper_math.h | 816 ++-- .../include/librapid/cuda/helper_string.h | 506 +-- .../include/librapid/cuda/kernel_header.h | 30 +- 
librapid/include/librapid/cuda/kernels/abs.cu | 4 +- .../librapid/cuda/kernels/activations.cu | 16 +- .../librapid/cuda/kernels/arithmetic.cu | 137 +- .../librapid/cuda/kernels/expLogPow.cu | 36 +- .../include/librapid/cuda/kernels/fill.cu | 210 +- .../librapid/cuda/kernels/floorCeilRound.cu | 8 +- .../librapid/cuda/kernels/kernelHelper.cuh | 4 +- .../include/librapid/cuda/kernels/negate.cu | 76 +- .../librapid/cuda/kernels/trigonometry.cu | 10 +- .../librapid/cuda/kernels/vectorOps.cuh | 708 +-- librapid/include/librapid/cuda/nvrtc_helper.h | 246 +- .../include/librapid/math/compileTime.hpp | 16 +- librapid/include/librapid/math/complex.hpp | 4000 ++++++++--------- librapid/include/librapid/math/constants.hpp | 204 +- librapid/include/librapid/math/coreMath.hpp | 888 ++-- librapid/include/librapid/math/fastMath.hpp | 2 +- librapid/include/librapid/math/half.hpp | 1476 +++--- librapid/include/librapid/math/multiprec.hpp | 1752 ++++---- librapid/include/librapid/math/random.hpp | 116 +- librapid/include/librapid/math/round.hpp | 264 +- .../librapid/math/utilityFunctions.hpp | 282 +- librapid/include/librapid/math/vector.hpp | 24 +- .../include/librapid/math/vectorForward.hpp | 236 +- librapid/include/librapid/math/vectorImpl.hpp | 1886 ++++---- librapid/include/librapid/ml/activations.hpp | 374 +- .../include/librapid/opencl/kernels/abs.cl | 28 +- .../librapid/opencl/kernels/activations.cl | 64 +- .../librapid/opencl/kernels/arithmetic.cl | 232 +- .../include/librapid/opencl/kernels/dual.cl | 8 +- .../librapid/opencl/kernels/expLogPow.cl | 44 +- .../include/librapid/opencl/kernels/fill.cl | 180 +- .../librapid/opencl/kernels/floorCeilRound.cl | 36 +- .../include/librapid/opencl/kernels/negate.cl | 8 +- .../librapid/opencl/kernels/transpose.cl | 32 +- .../librapid/opencl/kernels/trigonometry.cl | 36 +- .../librapid/opencl/openclConfigure.hpp | 14 +- .../librapid/opencl/openclErrorIdentifier.hpp | 4 +- .../librapid/opencl/openclKernelProcessor.hpp | 74 +- 
.../include/librapid/opencl/openclStorage.hpp | 810 ++-- librapid/include/librapid/simd/vecOps.hpp | 194 +- .../include/librapid/utils/cacheLineSize.hpp | 10 +- librapid/include/librapid/utils/map.hpp | 218 +- librapid/include/librapid/utils/memUtils.hpp | 204 +- librapid/include/librapid/utils/time.hpp | 292 +- librapid/src/cacheLineSize.cpp | 76 +- librapid/src/compat.cpp | 178 +- librapid/src/cudaKernelProcessor.cpp | 74 +- librapid/src/fastMath.cpp | 50 +- librapid/src/global.cpp | 98 +- librapid/src/helper_cuda.cpp | 414 +- librapid/src/literals.cpp | 2 +- librapid/src/multiprecCasting.cpp | 32 +- librapid/src/multiprecExpLogPow.cpp | 26 +- librapid/src/multiprecFloorCeil.cpp | 6 +- librapid/src/multiprecHypot.cpp | 4 +- librapid/src/multiprecModAbs.cpp | 40 +- librapid/src/multiprecToString.cpp | 76 +- librapid/src/multiprecTrig.cpp | 64 +- librapid/src/openclConfigure.cpp | 560 +-- librapid/src/openclErrorIdentifier.cpp | 162 +- librapid/src/preMain.cpp | 52 +- test/test-array.cpp | 1224 ++--- test/test-arrayOps.cpp | 62 +- test/test-arrayView.cpp | 108 +- test/test-complex.cpp | 644 +-- test/test-cudaStorage.cpp | 256 +- test/test-fixedStorage.cpp | 330 +- test/test-mathUtilities.cpp | 42 +- test/test-multiprecision.cpp | 148 +- test/test-openCLStorage.cpp | 258 +- test/test-pseudoConstructors.cpp | 98 +- test/test-sigmoid.cpp | 78 +- test/test-sizetype.cpp | 98 +- test/test-storage.cpp | 388 +- test/test-vector.cpp | 6 +- 132 files changed, 23019 insertions(+), 22959 deletions(-) diff --git a/librapid/include/librapid/array/array.hpp b/librapid/include/librapid/array/array.hpp index ab25d50d..2cb35815 100644 --- a/librapid/include/librapid/array/array.hpp +++ b/librapid/include/librapid/array/array.hpp @@ -6,11 +6,11 @@ #include "storage.hpp" #if defined(LIBRAPID_HAS_OPENCL) -# include "../OpenCL/openclStorage.hpp" +# include "../OpenCL/openclStorage.hpp" #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) -# include "../cuda/cudaStorage.hpp" +# 
include "../cuda/cudaStorage.hpp" #endif // LIBRAPID_HAS_CUDA #include "arrayTypeDef.hpp" diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index f656181f..a6c3bb21 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -2,109 +2,109 @@ #define LIBRAPID_ARRAY_ARRAY_CONTAINER_HPP namespace librapid { - namespace detail { - template - struct SubscriptType { - using Scalar = T; - using Direct = const Scalar &; - using Ref = Scalar &; - }; - - template - struct SubscriptType> { - using Scalar = T; - using Direct = const Scalar &; - using Ref = Scalar &; - }; - - template - struct SubscriptType> { - using Scalar = T; - using Direct = const Scalar &; - using Ref = Scalar &; - }; + namespace detail { + template + struct SubscriptType { + using Scalar = T; + using Direct = const Scalar &; + using Ref = Scalar &; + }; + + template + struct SubscriptType> { + using Scalar = T; + using Direct = const Scalar &; + using Ref = Scalar &; + }; + + template + struct SubscriptType> { + using Scalar = T; + using Direct = const Scalar &; + using Ref = Scalar &; + }; #if defined(LIBRAPID_HAS_OPENCL) - template - struct SubscriptType> { - using Scalar = T; - using Direct = const OpenCLRef; - using Ref = OpenCLRef; - }; + template + struct SubscriptType> { + using Scalar = T; + using Direct = const OpenCLRef; + using Ref = OpenCLRef; + }; #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - template - struct SubscriptType> { - using Scalar = T; - using Direct = const detail::CudaRef; - using Ref = detail::CudaRef; - }; + template + struct SubscriptType> { + using Scalar = T; + using Direct = const detail::CudaRef; + using Ref = detail::CudaRef; + }; #endif // LIBRAPID_HAS_CUDA - } // namespace detail - - namespace typetraits { - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayContainer; - 
using Scalar = typename TypeInfo::Scalar; - using Packet = std::false_type; - using Backend = typename TypeInfo::Backend; - static constexpr int64_t packetWidth = 1; - static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; - static constexpr bool supportsLogical = TypeInfo::supportsLogical; - static constexpr bool supportsBinary = TypeInfo::supportsBinary; - static constexpr bool allowVectorisation = TypeInfo::packetWidth > 1; + } // namespace detail + + namespace typetraits { + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayContainer; + using Scalar = typename TypeInfo::Scalar; + using Packet = std::false_type; + using Backend = typename TypeInfo::Backend; + static constexpr int64_t packetWidth = 1; + static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; + static constexpr bool supportsLogical = TypeInfo::supportsLogical; + static constexpr bool supportsBinary = TypeInfo::supportsBinary; + static constexpr bool allowVectorisation = TypeInfo::packetWidth > 1; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; + static constexpr int64_t cudaPacketWidth = 1; #endif // LIBRAPID_HAS_CUDA - static constexpr bool canAlign = false; - static constexpr int64_t canMemcpy = false; - }; - - /// Evaluates as true if the input type is an ArrayContainer instance - /// \tparam T Input type - template - struct IsArrayContainer : std::false_type {}; - - template - struct IsArrayContainer, StorageScalar>> - : std::true_type {}; - - LIBRAPID_DEFINE_AS_TYPE( - typename SizeType COMMA size_t dims COMMA typename StorageScalar, - array::ArrayContainer COMMA StorageScalar>); - } // namespace typetraits - - namespace array { - template - class ArrayContainer { - public: - using StorageType = StorageType_; - using ShapeType = ShapeType_; - 
using StrideType = Stride; - using SizeType = typename ShapeType::SizeType; - using Scalar = typename StorageType::Scalar; - using Packet = typename typetraits::TypeInfo::Packet; - using Backend = typename typetraits::TypeInfo::Backend; - using Iterator = detail::ArrayIterator>; - - using DirectSubscriptType = typename detail::SubscriptType::Direct; - using DirectRefSubscriptType = typename detail::SubscriptType::Ref; - - /// Default constructor - ArrayContainer(); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer(const std::initializer_list &data); - - template - explicit LIBRAPID_ALWAYS_INLINE ArrayContainer(const std::vector &data); - - // clang-format off + static constexpr bool canAlign = false; + static constexpr int64_t canMemcpy = false; + }; + + /// Evaluates as true if the input type is an ArrayContainer instance + /// \tparam T Input type + template + struct IsArrayContainer : std::false_type {}; + + template + struct IsArrayContainer, StorageScalar>> + : std::true_type {}; + + LIBRAPID_DEFINE_AS_TYPE( + typename SizeType COMMA size_t dims COMMA typename StorageScalar, + array::ArrayContainer COMMA StorageScalar>); + } // namespace typetraits + + namespace array { + template + class ArrayContainer { + public: + using StorageType = StorageType_; + using ShapeType = ShapeType_; + using StrideType = Stride; + using SizeType = typename ShapeType::SizeType; + using Scalar = typename StorageType::Scalar; + using Packet = typename typetraits::TypeInfo::Packet; + using Backend = typename typetraits::TypeInfo::Backend; + using Iterator = detail::ArrayIterator>; + + using DirectSubscriptType = typename detail::SubscriptType::Direct; + using DirectRefSubscriptType = typename detail::SubscriptType::Ref; + + /// Default constructor + ArrayContainer(); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer(const std::initializer_list &data); + + template + explicit LIBRAPID_ALWAYS_INLINE ArrayContainer(const std::vector &data); + + // clang-format off #define 
SINIT(SUB_TYPE) std::initializer_list #define SVEC(SUB_TYPE) std::vector @@ -126,797 +126,783 @@ namespace librapid { #undef SINIT #undef SVEC - // clang-format on - - /// Constructs an array container from a shape - /// \param shape The shape of the array container - LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const ShapeType &shape); - - /// Create an array container from a shape and a scalar value. The scalar value - /// represents the value the memory is initialized with. \param shape The shape of the - /// array container \param value The value to initialize the memory with - LIBRAPID_ALWAYS_INLINE ArrayContainer(const ShapeType &shape, const Scalar &value); - - /// Allows for a fixed-size array to be constructed with a fill value - /// \param value The value to fill the array with - LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const Scalar &value); - - /// Construct an array container from a shape, which is moved, not copied. - /// \param shape The shape of the array container - LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(ShapeType &&shape); - - /// \brief Reference an existing array container - /// - /// This constructor does not copy the data, but instead references the data of the - /// input array container. This means that the input array container must outlive the - /// constructed array container. Please use ``ArrayContainer::copy()`` if you want to - /// copy the data. - /// \param other The array container to reference - LIBRAPID_ALWAYS_INLINE ArrayContainer(const ArrayContainer &other) = default; - - /// Construct an array container from a temporary array container. - /// \param other The array container to move. 
- LIBRAPID_ALWAYS_INLINE ArrayContainer(ArrayContainer &&other) noexcept = default; - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer(const Transpose &trans); - - template - LIBRAPID_ALWAYS_INLINE - ArrayContainer(const linalg::ArrayMultiply &multiply); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer & - assign(const detail::Function &function); - - /// Construct an array container from a function object. This will assign the result of - /// the function to the array container, evaluating it accordingly. - /// \tparam desc The assignment descriptor - /// \tparam Functor_ The function type - /// \tparam Args The argument types of the function - /// \param function The function to assign - template - LIBRAPID_ALWAYS_INLINE ArrayContainer( - const detail::Function &function) LIBRAPID_RELEASE_NOEXCEPT; - - /// \brief Reference an existing array container - /// - /// This assignment operator does not copy the data, but instead references the data of - /// the input array container. This means that the input array container must outlive - /// the constructed array container. Please use ``ArrayContainer::copy()`` if you want - /// to copy the data. - /// - /// \param other The array container to reference - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator=(const ArrayContainer &other) = default; - - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator=(const Scalar &value); - - /// Assign a temporary array container to this array container. - /// \param other The array container to move. - /// \return A reference to this array container. - LIBRAPID_ALWAYS_INLINE ArrayContainer & - operator=(ArrayContainer &&other) noexcept = default; - - /// Assign a function object to this array container. This will assign the result of - /// the function to the array container, evaluating it accordingly. 
- /// \tparam Functor_ The function type - /// \tparam Args The argument types of the function - /// \param function The function to assign - /// \return A reference to this array container. - template - LIBRAPID_ALWAYS_INLINE ArrayContainer & - operator=(const detail::Function &function); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer & - operator=(const Transpose &transpose); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer & - operator=(const linalg::ArrayMultiply &multiply); - - /// Allow ArrayContainer objects to be initialized with a comma separated list of - /// values. This makes use of the CommaInitializer class - /// \tparam T The type of the values - /// \param value The value to set in the Array object - /// \return The comma initializer object - template - detail::CommaInitializer operator<<(const T &value); - - // template - // LIBRAPID_NODISCARD auto cast() const; - - LIBRAPID_NODISCARD ArrayContainer copy() const; - - /// Access a sub-array of this ArrayContainer instance. The sub-array will reference - /// the same memory as this ArrayContainer instance. - /// \param index The index of the sub-array - /// \return A reference to the sub-array (ArrayView) - /// \see ArrayView - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index); - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE DirectSubscriptType - operator()(Indices... indices) const; - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE DirectRefSubscriptType - operator()(Indices... indices); - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar get() const; - - /// Return the number of dimensions of the ArrayContainer object - /// \return Number of dimensions of the ArrayContainer - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE typename ShapeType::SizeType - ndim() const noexcept; - - /// Return the shape of the array container. This is an immutable reference. 
- /// \return The shape of the array container. - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const ShapeType &shape() const noexcept; - - /// Return the StorageType object of the ArrayContainer - /// \return The StorageType object of the ArrayContainer - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const StorageType &storage() const noexcept; - - /// Return the StorageType object of the ArrayContainer - /// \return The StorageType object of the ArrayContainer - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StorageType &storage() noexcept; - - /// Return a Packet object from the array's storage at a specific index. - /// \param index The index to get the packet from - /// \return A Packet object from the array's storage at a specific index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const; - - /// Return a Scalar from the array's storage at a specific index. - /// \param index The index to get the scalar from - /// \return A Scalar from the array's storage at a specific index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const; - - /// Write a Packet object to the array's storage at a specific index - /// \param index The index to write the packet to - /// \param value The value to write to the array's storage - LIBRAPID_ALWAYS_INLINE void writePacket(size_t index, const Packet &value); - - /// Write a Scalar to the array's storage at a specific index - /// \param index The index to write the scalar to - /// \param value The value to write to the array's storage - LIBRAPID_ALWAYS_INLINE void write(size_t index, const Scalar &value); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator+=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator-=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator*=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator/=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer 
&operator%=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator&=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator|=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator^=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator<<=(const T &other); - - template - LIBRAPID_ALWAYS_INLINE ArrayContainer &operator>>=(const T &other); - - /// \brief Return an iterator to the beginning of the array container - /// \return Iterator - LIBRAPID_INLINE Iterator begin() const noexcept; - - /// \brief Return an iterator to the end of the array container - /// \return Iterator - LIBRAPID_INLINE Iterator end() const noexcept; - - /// \brief Return an iterator to the beginning of the array container - /// \return Iterator - LIBRAPID_INLINE Iterator begin(); - - /// \brief Return an iterator to the end of the array container - /// \return Iterator - LIBRAPID_INLINE Iterator end(); - - /// Return a string representation of the array container - /// \format The format to use for the string representation - /// \return A string representation of the array container - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; - - template - void fmtStr(const fmt::formatter &format, char bracket, char separator, - Ctx &ctx) const; - - private: - ShapeType m_shape; // The shape type of the array - StorageType m_storage; // The storage container of the array - }; - - template - ArrayContainer::ArrayContainer() : - m_shape(StorageType_::template defaultShape()) {} - - template - template - ArrayContainer::ArrayContainer( - const std::initializer_list &data) : - m_shape({data.size()}), - m_storage(StorageType::fromData(data)) {} - - template - template - ArrayContainer::ArrayContainer(const std::vector &data) : - m_shape({data.size()}), m_storage(StorageType::fromData(data)) {} - - template - ArrayContainer::ArrayContainer(const ShapeType &shape) : - m_shape(shape), 
m_storage(shape.size()) { - static_assert(!typetraits::IsFixedStorage::value, - "For a compile-time-defined shape, " - "the storage type must be " - "a FixedStorage object"); - } - - template - ArrayContainer::ArrayContainer(const ShapeType &shape, - const Scalar &value) : - m_shape(shape), - m_storage(shape.size(), value) { - static_assert(typetraits::IsStorage::value || - typetraits::IsOpenCLStorage::value || - typetraits::IsCudaStorage::value, - "For a runtime-defined shape, " - "the storage type must be " - "either a Storage or a " - "CudaStorage object"); - static_assert(!typetraits::IsFixedStorage::value, - "For a compile-time-defined shape, " - "the storage type must be " - "a FixedStorage object"); - } - - template - ArrayContainer::ArrayContainer(const Scalar &value) : - m_shape(detail::shapeFromFixedStorage(m_storage)), m_storage(value) { - static_assert(typetraits::IsFixedStorage::value, - "For a compile-time-defined shape, " - "the storage type must be " - "a FixedStorage object"); - } - - template - ArrayContainer::ArrayContainer(ShapeType_ &&shape) : - m_shape(std::forward(shape)), m_storage(m_shape.size()) {} - - template - template - ArrayContainer::ArrayContainer( - const array::Transpose &trans) { - *this = trans; - } - - template - template - ArrayContainer::ArrayContainer( - const linalg::ArrayMultiply &multiply) { - *this = multiply; - } - - template - template - auto ArrayContainer::assign( - const detail::Function &function) -> ArrayContainer & { - using FunctionType = detail::Function; - m_storage.resize(function.shape().size(), 0); - if constexpr (std::is_same_v || - std::is_same_v) { - detail::assign(*this, function); - } else { + // clang-format on + + /// Constructs an array container from a shape + /// \param shape The shape of the array container + LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const ShapeType &shape); + + /// Create an array container from a shape and a scalar value. 
The scalar value + /// represents the value the memory is initialized with. \param shape The shape of the + /// array container \param value The value to initialize the memory with + LIBRAPID_ALWAYS_INLINE ArrayContainer(const ShapeType &shape, const Scalar &value); + + /// Allows for a fixed-size array to be constructed with a fill value + /// \param value The value to fill the array with + LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(const Scalar &value); + + /// Construct an array container from a shape, which is moved, not copied. + /// \param shape The shape of the array container + LIBRAPID_ALWAYS_INLINE explicit ArrayContainer(ShapeType &&shape); + + /// \brief Reference an existing array container + /// + /// This constructor does not copy the data, but instead references the data of the + /// input array container. This means that the input array container must outlive the + /// constructed array container. Please use ``ArrayContainer::copy()`` if you want to + /// copy the data. + /// \param other The array container to reference + LIBRAPID_ALWAYS_INLINE ArrayContainer(const ArrayContainer &other) = default; + + /// Construct an array container from a temporary array container. + /// \param other The array container to move. + LIBRAPID_ALWAYS_INLINE ArrayContainer(ArrayContainer &&other) noexcept = default; + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer(const Transpose &trans); + + template + LIBRAPID_ALWAYS_INLINE + ArrayContainer(const linalg::ArrayMultiply &multiply); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer & + assign(const detail::Function &function); + + /// Construct an array container from a function object. This will assign the result of + /// the function to the array container, evaluating it accordingly. 
+ /// \tparam desc The assignment descriptor + /// \tparam Functor_ The function type + /// \tparam Args The argument types of the function + /// \param function The function to assign + template + LIBRAPID_ALWAYS_INLINE ArrayContainer( + const detail::Function &function) LIBRAPID_RELEASE_NOEXCEPT; + + /// \brief Reference an existing array container + /// + /// This assignment operator does not copy the data, but instead references the data of + /// the input array container. This means that the input array container must outlive + /// the constructed array container. Please use ``ArrayContainer::copy()`` if you want + /// to copy the data. + /// + /// \param other The array container to reference + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator=(const ArrayContainer &other) = default; + + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator=(const Scalar &value); + + /// Assign a temporary array container to this array container. + /// \param other The array container to move. + /// \return A reference to this array container. + LIBRAPID_ALWAYS_INLINE ArrayContainer & + operator=(ArrayContainer &&other) noexcept = default; + + /// Assign a function object to this array container. This will assign the result of + /// the function to the array container, evaluating it accordingly. + /// \tparam Functor_ The function type + /// \tparam Args The argument types of the function + /// \param function The function to assign + /// \return A reference to this array container. + template + LIBRAPID_ALWAYS_INLINE ArrayContainer & + operator=(const detail::Function &function); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer & + operator=(const Transpose &transpose); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer & + operator=(const linalg::ArrayMultiply &multiply); + + /// Allow ArrayContainer objects to be initialized with a comma separated list of + /// values. 
This makes use of the CommaInitializer class + /// \tparam T The type of the values + /// \param value The value to set in the Array object + /// \return The comma initializer object + template + detail::CommaInitializer operator<<(const T &value); + + // template + // LIBRAPID_NODISCARD auto cast() const; + + LIBRAPID_NODISCARD ArrayContainer copy() const; + + /// Access a sub-array of this ArrayContainer instance. The sub-array will reference + /// the same memory as this ArrayContainer instance. + /// \param index The index of the sub-array + /// \return A reference to the sub-array (ArrayView) + /// \see ArrayView + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index); + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE DirectSubscriptType + operator()(Indices... indices) const; + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE DirectRefSubscriptType + operator()(Indices... indices); + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar get() const; + + /// Return the number of dimensions of the ArrayContainer object + /// \return Number of dimensions of the ArrayContainer + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE typename ShapeType::SizeType + ndim() const noexcept; + + /// Return the shape of the array container. This is an immutable reference. + /// \return The shape of the array container. 
+ LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const ShapeType &shape() const noexcept; + + /// Return the StorageType object of the ArrayContainer + /// \return The StorageType object of the ArrayContainer + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const StorageType &storage() const noexcept; + + /// Return the StorageType object of the ArrayContainer + /// \return The StorageType object of the ArrayContainer + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StorageType &storage() noexcept; + + /// Return a Packet object from the array's storage at a specific index. + /// \param index The index to get the packet from + /// \return A Packet object from the array's storage at a specific index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const; + + /// Return a Scalar from the array's storage at a specific index. + /// \param index The index to get the scalar from + /// \return A Scalar from the array's storage at a specific index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const; + + /// Write a Packet object to the array's storage at a specific index + /// \param index The index to write the packet to + /// \param value The value to write to the array's storage + LIBRAPID_ALWAYS_INLINE void writePacket(size_t index, const Packet &value); + + /// Write a Scalar to the array's storage at a specific index + /// \param index The index to write the scalar to + /// \param value The value to write to the array's storage + LIBRAPID_ALWAYS_INLINE void write(size_t index, const Scalar &value); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator+=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator-=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator*=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator/=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator%=(const T &other); + + template + 
LIBRAPID_ALWAYS_INLINE ArrayContainer &operator&=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator|=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator^=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator<<=(const T &other); + + template + LIBRAPID_ALWAYS_INLINE ArrayContainer &operator>>=(const T &other); + + /// \brief Return an iterator to the beginning of the array container + /// \return Iterator + LIBRAPID_INLINE Iterator begin() const noexcept; + + /// \brief Return an iterator to the end of the array container + /// \return Iterator + LIBRAPID_INLINE Iterator end() const noexcept; + + /// \brief Return an iterator to the beginning of the array container + /// \return Iterator + LIBRAPID_INLINE Iterator begin(); + + /// \brief Return an iterator to the end of the array container + /// \return Iterator + LIBRAPID_INLINE Iterator end(); + + template + void str(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const; + + private: + ShapeType m_shape; // The shape type of the array + StorageType m_storage; // The storage container of the array + }; + + template + ArrayContainer::ArrayContainer() : + m_shape(StorageType_::template defaultShape()) {} + + template + template + ArrayContainer::ArrayContainer( + const std::initializer_list &data) : + m_shape({data.size()}), + m_storage(StorageType::fromData(data)) {} + + template + template + ArrayContainer::ArrayContainer(const std::vector &data) : + m_shape({data.size()}), m_storage(StorageType::fromData(data)) {} + + template + ArrayContainer::ArrayContainer(const ShapeType &shape) : + m_shape(shape), m_storage(shape.size()) { + static_assert(!typetraits::IsFixedStorage::value, + "For a compile-time-defined shape, " + "the storage type must be " + "a FixedStorage object"); + } + + template + ArrayContainer::ArrayContainer(const ShapeType &shape, + const Scalar &value) : + m_shape(shape), + 
m_storage(shape.size(), value) { + static_assert(typetraits::IsStorage::value || + typetraits::IsOpenCLStorage::value || + typetraits::IsCudaStorage::value, + "For a runtime-defined shape, " + "the storage type must be " + "either a Storage or a " + "CudaStorage object"); + static_assert(!typetraits::IsFixedStorage::value, + "For a compile-time-defined shape, " + "the storage type must be " + "a FixedStorage object"); + } + + template + ArrayContainer::ArrayContainer(const Scalar &value) : + m_shape(detail::shapeFromFixedStorage(m_storage)), m_storage(value) { + static_assert(typetraits::IsFixedStorage::value, + "For a compile-time-defined shape, " + "the storage type must be " + "a FixedStorage object"); + } + + template + ArrayContainer::ArrayContainer(ShapeType_ &&shape) : + m_shape(std::forward(shape)), m_storage(m_shape.size()) {} + + template + template + ArrayContainer::ArrayContainer( + const array::Transpose &trans) { + *this = trans; + } + + template + template + ArrayContainer::ArrayContainer( + const linalg::ArrayMultiply &multiply) { + *this = multiply; + } + + template + template + auto ArrayContainer::assign( + const detail::Function &function) -> ArrayContainer & { + using FunctionType = detail::Function; + m_storage.resize(function.shape().size(), 0); + if constexpr (std::is_same_v || + std::is_same_v) { + detail::assign(*this, function); + } else { #if !defined(LIBRAPID_OPTIMISE_SMALL_ARRAYS) - if (m_storage.size() > global::multithreadThreshold && global::numThreads > 1) - detail::assignParallel(*this, function); - else + if (m_storage.size() > global::multithreadThreshold && global::numThreads > 1) + detail::assignParallel(*this, function); + else #endif // LIBRAPID_OPTIMISE_SMALL_ARRAYS - detail::assign(*this, function); - } - return *this; - } - - template - template - ArrayContainer::ArrayContainer( - const detail::Function &function) LIBRAPID_RELEASE_NOEXCEPT - : m_shape(function.shape()), - m_storage(m_shape.size()) { - assign(function); - 
} - - template - template - auto ArrayContainer::operator=( - const detail::Function &function) -> ArrayContainer & { - return assign(function); - } - - template - template - auto ArrayContainer::operator=( - const Transpose &transpose) -> ArrayContainer & { - m_shape = transpose.shape(); - m_storage.resize(m_shape.size(), 0); - transpose.applyTo(*this); - return *this; - } - - template - template - auto ArrayContainer::operator=( - const linalg::ArrayMultiply &arrayMultiply) -> ArrayContainer & { - m_shape = arrayMultiply.shape(); - m_storage.resize(m_shape.size(), 0); - arrayMultiply.applyTo(*this); - return *this; - } - - template - auto ArrayContainer::operator=(const Scalar &value) - -> ArrayContainer & { - LIBRAPID_ASSERT(m_shape.ndim() == 0, "Cannot assign a scalar to an array"); - m_storage[0] = value; - return *this; - } - - template - template - auto ArrayContainer::operator<<(const T &value) - -> detail::CommaInitializer { - return detail::CommaInitializer(*this, static_cast(value)); - } - - template - auto ArrayContainer::copy() const -> ArrayContainer { - ArrayContainer res(m_shape); - res.m_storage = m_storage.copy(); - return res; - } - - template - auto ArrayContainer::operator[](int64_t index) const { - LIBRAPID_ASSERT( - index >= 0 && index < static_cast(m_shape[0]), - "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", - index, - m_shape[0]); - - if constexpr (typetraits::IsOpenCLStorage::value) { + detail::assign(*this, function); + } + return *this; + } + + template + template + ArrayContainer::ArrayContainer( + const detail::Function &function) LIBRAPID_RELEASE_NOEXCEPT + : m_shape(function.shape()), + m_storage(m_shape.size()) { + assign(function); + } + + template + template + auto ArrayContainer::operator=( + const detail::Function &function) -> ArrayContainer & { + return assign(function); + } + + template + template + auto ArrayContainer::operator=( + const Transpose &transpose) -> ArrayContainer & { + 
m_shape = transpose.shape(); + m_storage.resize(m_shape.size(), 0); + transpose.applyTo(*this); + return *this; + } + + template + template + auto ArrayContainer::operator=( + const linalg::ArrayMultiply &arrayMultiply) -> ArrayContainer & { + m_shape = arrayMultiply.shape(); + m_storage.resize(m_shape.size(), 0); + arrayMultiply.applyTo(*this); + return *this; + } + + template + auto ArrayContainer::operator=(const Scalar &value) + -> ArrayContainer & { + LIBRAPID_ASSERT(m_shape.ndim() == 0, "Cannot assign a scalar to an array"); + m_storage[0] = value; + return *this; + } + + template + template + auto ArrayContainer::operator<<(const T &value) + -> detail::CommaInitializer { + return detail::CommaInitializer(*this, static_cast(value)); + } + + template + auto ArrayContainer::copy() const -> ArrayContainer { + ArrayContainer res(m_shape); + res.m_storage = m_storage.copy(); + return res; + } + + template + auto ArrayContainer::operator[](int64_t index) const { + LIBRAPID_ASSERT( + index >= 0 && index < static_cast(m_shape[0]), + "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", + index, + m_shape[0]); + + if constexpr (typetraits::IsOpenCLStorage::value) { #if defined(LIBRAPID_HAS_OPENCL) - ArrayContainer res; - res.m_shape = m_shape.subshape(1, ndim()); - auto subSize = res.shape().size(); - int64_t storageSize = sizeof(typename StorageType_::Scalar); - cl_buffer_region region {index * subSize * storageSize, subSize * storageSize}; - res.m_storage = - StorageType_(m_storage.data().createSubBuffer( - StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, ®ion), - subSize, - false); - return res; + ArrayContainer res; + res.m_shape = m_shape.subshape(1, ndim()); + auto subSize = res.shape().size(); + int64_t storageSize = sizeof(typename StorageType_::Scalar); + cl_buffer_region region {index * subSize * storageSize, subSize * storageSize}; + res.m_storage = + StorageType_(m_storage.data().createSubBuffer( + 
StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, ®ion), + subSize, + false); + return res; #else - LIBRAPID_ERROR("OpenCL support not enabled"); + LIBRAPID_ERROR("OpenCL support not enabled"); #endif // LIBRAPID_HAS_OPENCL - } else if constexpr (typetraits::IsCudaStorage::value) { + } else if constexpr (typetraits::IsCudaStorage::value) { #if defined(LIBRAPID_HAS_CUDA) - ArrayContainer res; - res.m_shape = m_shape.subshape(1, ndim()); - auto subSize = res.shape().size(); - Scalar *begin = m_storage.begin().get() + index * subSize; - res.m_storage = StorageType_(begin, subSize, false); - return res; + ArrayContainer res; + res.m_shape = m_shape.subshape(1, ndim()); + auto subSize = res.shape().size(); + Scalar *begin = m_storage.begin().get() + index * subSize; + res.m_storage = StorageType_(begin, subSize, false); + return res; #else - LIBRAPID_ERROR("CUDA support not enabled"); + LIBRAPID_ERROR("CUDA support not enabled"); #endif // LIBRAPID_HAS_CUDA - } else if constexpr (typetraits::IsFixedStorage::value) { - return ArrayView(*this)[index]; - } else { - ArrayContainer res; - res.m_shape = m_shape.subshape(1, ndim()); - auto subSize = res.shape().size(); - Scalar *begin = m_storage.begin() + index * subSize; - Scalar *end = begin + subSize; - res.m_storage = StorageType_(begin, end, false); - return res; - } - } - - template - auto ArrayContainer::operator[](int64_t index) { - LIBRAPID_ASSERT( - index >= 0 && index < static_cast(m_shape[0]), - "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", - index, - m_shape[0]); - - if constexpr (typetraits::IsOpenCLStorage::value) { + } else if constexpr (typetraits::IsFixedStorage::value) { + return ArrayView(*this)[index]; + } else { + ArrayContainer res; + res.m_shape = m_shape.subshape(1, ndim()); + auto subSize = res.shape().size(); + Scalar *begin = m_storage.begin() + index * subSize; + Scalar *end = begin + subSize; + res.m_storage = StorageType_(begin, end, false); + return 
res; + } + } + + template + auto ArrayContainer::operator[](int64_t index) { + LIBRAPID_ASSERT( + index >= 0 && index < static_cast(m_shape[0]), + "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", + index, + m_shape[0]); + + if constexpr (typetraits::IsOpenCLStorage::value) { #if defined(LIBRAPID_HAS_OPENCL) - ArrayContainer res; - res.m_shape = m_shape.subshape(1, ndim()); - auto subSize = res.shape().size(); - int64_t storageSize = sizeof(typename StorageType_::Scalar); - cl_buffer_region region {index * subSize * storageSize, subSize * storageSize}; - res.m_storage.set( - StorageType_(m_storage.data().createSubBuffer( - StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, ®ion), - subSize, - false)); - return res; + ArrayContainer res; + res.m_shape = m_shape.subshape(1, ndim()); + auto subSize = res.shape().size(); + int64_t storageSize = sizeof(typename StorageType_::Scalar); + cl_buffer_region region {index * subSize * storageSize, subSize * storageSize}; + res.m_storage.set( + StorageType_(m_storage.data().createSubBuffer( + StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, ®ion), + subSize, + false)); + return res; #else - LIBRAPID_ERROR("OpenCL support not enabled"); + LIBRAPID_ERROR("OpenCL support not enabled"); #endif // LIBRAPID_HAS_OPENCL - } else if constexpr (typetraits::IsCudaStorage::value) { + } else if constexpr (typetraits::IsCudaStorage::value) { #if defined(LIBRAPID_HAS_CUDA) - ArrayContainer res; - res.m_shape = m_shape.subshape(1, ndim()); - auto subSize = res.shape().size(); - Scalar *begin = m_storage.begin().get() + index * subSize; - res.m_storage.set(StorageType_(begin, subSize, false)); - return res; + ArrayContainer res; + res.m_shape = m_shape.subshape(1, ndim()); + auto subSize = res.shape().size(); + Scalar *begin = m_storage.begin().get() + index * subSize; + res.m_storage.set(StorageType_(begin, subSize, false)); + return res; #else - LIBRAPID_ERROR("CUDA support not enabled"); + 
LIBRAPID_ERROR("CUDA support not enabled"); #endif // LIBRAPID_HAS_CUDA - } else if constexpr (typetraits::IsFixedStorage::value) { - return ArrayView(*this)[index]; - } else { - ArrayContainer res; - res.m_shape = m_shape.subshape(1, ndim()); - auto subSize = res.shape().size(); - Scalar *begin = m_storage.begin() + index * subSize; - Scalar *end = begin + subSize; - res.m_storage.set(StorageType_(begin, end, false)); - return res; - } - } - - template - template - auto ArrayContainer::operator()(Indices... indices) const - -> DirectSubscriptType { - LIBRAPID_ASSERT( - m_shape.ndim() == sizeof...(Indices), - "ArrayContainer::operator() called with {} indices, but array has {} dimensions", - sizeof...(Indices), - m_shape.ndim()); - - int64_t index = 0; - for (int64_t i : {indices...}) { - LIBRAPID_ASSERT( - i >= 0 && i < static_cast(m_shape[index]), - "Index {} out of bounds in ArrayContainer::operator() with dimension={}", - i, - m_shape[index]); - index = index * m_shape[index] + i; - } - return m_storage[index]; - } - - template - template - auto ArrayContainer::operator()(Indices... 
indices) - -> DirectRefSubscriptType { - LIBRAPID_ASSERT( - m_shape.ndim() == sizeof...(Indices), - "ArrayContainer::operator() called with {} indices, but array has {} dimensions", - sizeof...(Indices), - m_shape.ndim()); - - int64_t index = 0; - int64_t count = 0; - for (int64_t i : {indices...}) { - LIBRAPID_ASSERT( - i >= 0 && i < static_cast(m_shape[count]), - "Index {} out of bounds in ArrayContainer::operator() with dimension={}", - i, - m_shape[index]); - index = index * m_shape[count++] + i; - } - return m_storage[index]; - } - - template - auto ArrayContainer::get() const -> Scalar { - LIBRAPID_ASSERT(m_shape.ndim() == 0, - "Can only cast a scalar ArrayView to a salar object"); - return scalar(0); - } - - template - auto ArrayContainer::ndim() const noexcept -> - typename ShapeType_::SizeType { - return m_shape.ndim(); - } - - template - auto ArrayContainer::shape() const noexcept -> const ShapeType & { - return m_shape; - } - - template - auto ArrayContainer::storage() const noexcept - -> const StorageType & { - return m_storage; - } - - template - auto ArrayContainer::storage() noexcept -> StorageType & { - return m_storage; - } - - template - auto ArrayContainer::packet(size_t index) const -> Packet { + } else if constexpr (typetraits::IsFixedStorage::value) { + return ArrayView(*this)[index]; + } else { + ArrayContainer res; + res.m_shape = m_shape.subshape(1, ndim()); + auto subSize = res.shape().size(); + Scalar *begin = m_storage.begin() + index * subSize; + Scalar *end = begin + subSize; + res.m_storage.set(StorageType_(begin, end, false)); + return res; + } + } + + template + template + auto ArrayContainer::operator()(Indices... 
indices) const + -> DirectSubscriptType { + LIBRAPID_ASSERT( + m_shape.ndim() == sizeof...(Indices), + "ArrayContainer::operator() called with {} indices, but array has {} dimensions", + sizeof...(Indices), + m_shape.ndim()); + + int64_t index = 0; + for (int64_t i : {indices...}) { + LIBRAPID_ASSERT( + i >= 0 && i < static_cast(m_shape[index]), + "Index {} out of bounds in ArrayContainer::operator() with dimension={}", + i, + m_shape[index]); + index = index * m_shape[index] + i; + } + return m_storage[index]; + } + + template + template + auto ArrayContainer::operator()(Indices... indices) + -> DirectRefSubscriptType { + LIBRAPID_ASSERT( + m_shape.ndim() == sizeof...(Indices), + "ArrayContainer::operator() called with {} indices, but array has {} dimensions", + sizeof...(Indices), + m_shape.ndim()); + + int64_t index = 0; + int64_t count = 0; + for (int64_t i : {indices...}) { + LIBRAPID_ASSERT( + i >= 0 && i < static_cast(m_shape[count]), + "Index {} out of bounds in ArrayContainer::operator() with dimension={}", + i, + m_shape[index]); + index = index * m_shape[count++] + i; + } + return m_storage[index]; + } + + template + auto ArrayContainer::get() const -> Scalar { + LIBRAPID_ASSERT(m_shape.ndim() == 0, + "Can only cast a scalar ArrayView to a salar object"); + return scalar(0); + } + + template + auto ArrayContainer::ndim() const noexcept -> + typename ShapeType_::SizeType { + return m_shape.ndim(); + } + + template + auto ArrayContainer::shape() const noexcept -> const ShapeType & { + return m_shape; + } + + template + auto ArrayContainer::storage() const noexcept + -> const StorageType & { + return m_storage; + } + + template + auto ArrayContainer::storage() noexcept -> StorageType & { + return m_storage; + } + + template + auto ArrayContainer::packet(size_t index) const -> Packet { #if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_OSX) - // On MacOS (and other platforms??) we cannot use aligned loads in arrays due to one - // annoying edge case. 
Normally, all SIMD loads will be aligned to a 64-byte boundary. - // Say, however, this array is a sub-array of a larger array. If the outer dimension - // of the larger array does not result in a 64-byte alignment, the data of *this* array - // will not be correctly aligned, hence causing a segfault. - return xsimd::load_aligned(m_storage.begin() + index); + // On MacOS (and other platforms??) we cannot use aligned loads in arrays due to one + // annoying edge case. Normally, all SIMD loads will be aligned to a 64-byte boundary. + // Say, however, this array is a sub-array of a larger array. If the outer dimension + // of the larger array does not result in a 64-byte alignment, the data of *this* array + // will not be correctly aligned, hence causing a segfault. + return xsimd::load_aligned(m_storage.begin() + index); #else - return xsimd::load_unaligned(m_storage.begin() + index); + return xsimd::load_unaligned(m_storage.begin() + index); #endif - } + } - template - auto ArrayContainer::scalar(size_t index) const -> Scalar { - return m_storage[index]; - } + template + auto ArrayContainer::scalar(size_t index) const -> Scalar { + return m_storage[index]; + } - template - void ArrayContainer::writePacket(size_t index, - const Packet &value) { + template + void ArrayContainer::writePacket(size_t index, + const Packet &value) { #if defined(LIBRAPID_NATIVE_ARCH) - value.store_aligned(m_storage.begin() + index); + value.store_aligned(m_storage.begin() + index); #else - value.store_unaligned(m_storage.begin() + index); + value.store_unaligned(m_storage.begin() + index); #endif - } - - template - void ArrayContainer::write(size_t index, const Scalar &value) { - m_storage[index] = value; - } - - template - template - auto ArrayContainer::operator+=(const T &value) - -> ArrayContainer & { - *this = *this + value; - return *this; - } - - template - template - auto ArrayContainer::operator-=(const T &value) - -> ArrayContainer & { - *this = *this - value; - return *this; - 
} - - template - template - auto ArrayContainer::operator*=(const T &value) - -> ArrayContainer & { - *this = *this * value; - return *this; - } - - template - template - auto ArrayContainer::operator/=(const T &value) - -> ArrayContainer & { - *this = *this / value; - return *this; - } - - template - template - auto ArrayContainer::operator%=(const T &value) - -> ArrayContainer & { - *this = *this % value; - return *this; - } - - template - template - auto ArrayContainer::operator&=(const T &value) - -> ArrayContainer & { - *this = *this & value; - return *this; - } - - template - template - auto ArrayContainer::operator|=(const T &value) - -> ArrayContainer & { - *this = *this | value; - return *this; - } - - template - template - auto ArrayContainer::operator^=(const T &value) - -> ArrayContainer & { - *this = *this ^ value; - return *this; - } - - template - template - auto ArrayContainer::operator<<=(const T &value) - -> ArrayContainer & { - *this = *this << value; - return *this; - } - - template - template - auto ArrayContainer::operator>>=(const T &value) - -> ArrayContainer & { - *this = *this >> value; - return *this; - } - - template - auto ArrayContainer::begin() const noexcept -> Iterator { - return Iterator(ArrayView(*this), 0); - } - - template - auto ArrayContainer::end() const noexcept -> Iterator { - return Iterator(ArrayView(*this), m_shape[0]); - } - - template - auto ArrayContainer::begin() -> Iterator { - return Iterator(ArrayView(*this), 0); - } - - template - auto ArrayContainer::end() -> Iterator { - return Iterator(ArrayView(*this), m_shape[0]); - } - - template - std::string ArrayContainer::str(const std::string &format) const { - return ArrayView(*this).str(format); - } - - template - template - void ArrayContainer::fmtStr(const fmt::formatter &format, - char bracket, char separator, - Ctx &ctx) const { - ArrayView(*this).fmtStr(format, bracket, separator, ctx); - } - } // namespace array - - namespace detail { - template - struct 
IsArrayType { - static constexpr bool val = false; - }; - - template - struct IsArrayType> { - static constexpr bool val = true; - }; - - template - struct IsArrayType> { - static constexpr bool val = true; - }; - - template - struct IsArrayType> { - static constexpr bool val = true; - }; - - template - struct ContainsArrayType { - static constexpr auto evaluator() { - if constexpr (sizeof...(Types) == 0) - return IsArrayType::val; - else - return IsArrayType::val || ContainsArrayType::val; - }; - - static constexpr bool val = evaluator(); - }; - }; // namespace detail + } + + template + void ArrayContainer::write(size_t index, const Scalar &value) { + m_storage[index] = value; + } + + template + template + auto ArrayContainer::operator+=(const T &value) + -> ArrayContainer & { + *this = *this + value; + return *this; + } + + template + template + auto ArrayContainer::operator-=(const T &value) + -> ArrayContainer & { + *this = *this - value; + return *this; + } + + template + template + auto ArrayContainer::operator*=(const T &value) + -> ArrayContainer & { + *this = *this * value; + return *this; + } + + template + template + auto ArrayContainer::operator/=(const T &value) + -> ArrayContainer & { + *this = *this / value; + return *this; + } + + template + template + auto ArrayContainer::operator%=(const T &value) + -> ArrayContainer & { + *this = *this % value; + return *this; + } + + template + template + auto ArrayContainer::operator&=(const T &value) + -> ArrayContainer & { + *this = *this & value; + return *this; + } + + template + template + auto ArrayContainer::operator|=(const T &value) + -> ArrayContainer & { + *this = *this | value; + return *this; + } + + template + template + auto ArrayContainer::operator^=(const T &value) + -> ArrayContainer & { + *this = *this ^ value; + return *this; + } + + template + template + auto ArrayContainer::operator<<=(const T &value) + -> ArrayContainer & { + *this = *this << value; + return *this; + } + + template + 
template + auto ArrayContainer::operator>>=(const T &value) + -> ArrayContainer & { + *this = *this >> value; + return *this; + } + + template + auto ArrayContainer::begin() const noexcept -> Iterator { + return Iterator(ArrayView(*this), 0); + } + + template + auto ArrayContainer::end() const noexcept -> Iterator { + return Iterator(ArrayView(*this), m_shape[0]); + } + + template + auto ArrayContainer::begin() -> Iterator { + return Iterator(ArrayView(*this), 0); + } + + template + auto ArrayContainer::end() -> Iterator { + return Iterator(ArrayView(*this), m_shape[0]); + } + + template + template + void ArrayContainer::str(const fmt::formatter &format, + char bracket, char separator, + Ctx &ctx) const { + ArrayView(*this).str(format, bracket, separator, ctx); + } + } // namespace array + + namespace detail { + template + struct IsArrayType { + static constexpr bool val = false; + }; + + template + struct IsArrayType> { + static constexpr bool val = true; + }; + + template + struct IsArrayType> { + static constexpr bool val = true; + }; + + template + struct IsArrayType> { + static constexpr bool val = true; + }; + + template + struct ContainsArrayType { + static constexpr auto evaluator() { + if constexpr (sizeof...(Types) == 0) + return IsArrayType::val; + else + return IsArrayType::val || ContainsArrayType::val; + }; + + static constexpr bool val = evaluator(); + }; + }; // namespace detail } // namespace librapid // Support FMT printing #ifdef FMT_API template struct fmt::formatter> { - using Type = librapid::array::ArrayContainer; - using Scalar = typename librapid::typetraits::TypeInfo::Scalar; - using Formatter = fmt::formatter; - Formatter m_formatter; - char m_bracket; - char m_separator; - - template - FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { - // Custom format options: - // - 'r' for round brackets - // - 's' for square brackets - // - 'c' for curly brackets - // - 'a' for angle brackets - // - 'p' for pipe brackets - // - "-," for 
comma separator - // - "-;" for semicolon separator - // - "-:" for colon separator - // - "-|" for pipe separator - // - "-_" for underscore separator - - auto it = ctx.begin(), end = ctx.end(); - if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) { - m_bracket = *it++; - } else { - m_bracket = 's'; - } - - if (it != end && *it == '-') { - ++it; - if (it != end) { - m_separator = *it++; - } else { - m_separator = ','; - } - } else { - m_separator = ' '; - } - - ctx.advance_to(it); - - return m_formatter.parse(ctx); - } - - template - FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { - val.fmtStr(m_formatter, m_bracket, m_separator, ctx); - return ctx.out(); - } + using Type = librapid::array::ArrayContainer; + using Scalar = typename librapid::typetraits::TypeInfo::Scalar; + using Formatter = fmt::formatter; + Formatter m_formatter; + char m_bracket = 's'; + char m_separator = ' '; + + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + // Custom format options: + // - "~r" for round brackets + // - "~s" for square brackets + // - "~c" for curly brackets + // - "~a" for angle brackets + // - "~p" for pipe brackets + // - "-," for comma separator + // - "-;" for semicolon separator + // - "-:" for colon separator + // - "-|" for pipe separator + // - "-_" for underscore separator + + auto it = ctx.begin(), end = ctx.end(); + if (it != end && *it == '~') { + ++it; + if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) { + m_bracket = *it++; + } + } + + if (it != end && *it == '-') { + ++it; + if (it != end) { m_separator = *it++; } + } + + ctx.advance_to(it); + + return m_formatter.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { + val.str(m_formatter, m_bracket, m_separator, ctx); + return ctx.out(); + } }; template -std::ostream &operator<<(std::ostream 
&os, - const librapid::array::ArrayContainer &object) { - os << "NOT IMPLEMENTED!"; // object.str(); - return os; +auto operator<<(std::ostream &os, + const librapid::array::ArrayContainer &object) + -> std::ostream & { + os << fmt::format("{}", object); + return os; } LIBRAPID_SIMPLE_IO_NORANGE(typename ShapeType_ COMMA typename StorageType_, - librapid::array::ArrayContainer) + librapid::array::ArrayContainer) #endif // FMT_API #endif // LIBRAPID_ARRAY_ARRAY_CONTAINER_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/arrayFromData.hpp b/librapid/include/librapid/array/arrayFromData.hpp index db293eb7..472fbe3d 100644 --- a/librapid/include/librapid/array/arrayFromData.hpp +++ b/librapid/include/librapid/array/arrayFromData.hpp @@ -2,68 +2,68 @@ #define LIBRAPID_ARRAY_FROM_DATA_HPP namespace librapid { - /// \brief Create an array from a list of values (possibly multi-dimensional) - /// - /// Create a new array from a potentially nested list of values. It is possible to specify the - /// data type of the Array with the \p Scalar template parameter. If no type is specified, the - /// type will be inferred from the data. The backend on which the Array is created can also be - /// specified with the \p Backend template parameter. If no backend is specified, the Array will - /// be created on the CPU. - /// - /// \tparam Scalar The type of the Array - /// \tparam Backend The backend on which the Array is created - /// \param data The data from which the Array is created - /// \return The created Array - template - auto array::ArrayContainer::fromData(const std::initializer_list &data) - -> ArrayContainer { - LIBRAPID_ASSERT(data.size() > 0, "Array must have at least one element"); - return ArrayContainer(data); - } + /// \brief Create an array from a list of values (possibly multi-dimensional) + /// + /// Create a new array from a potentially nested list of values. 
It is possible to specify the + /// data type of the Array with the \p Scalar template parameter. If no type is specified, the + /// type will be inferred from the data. The backend on which the Array is created can also be + /// specified with the \p Backend template parameter. If no backend is specified, the Array will + /// be created on the CPU. + /// + /// \tparam Scalar The type of the Array + /// \tparam Backend The backend on which the Array is created + /// \param data The data from which the Array is created + /// \return The created Array + template + auto array::ArrayContainer::fromData(const std::initializer_list &data) + -> ArrayContainer { + LIBRAPID_ASSERT(data.size() > 0, "Array must have at least one element"); + return ArrayContainer(data); + } - template - auto array::ArrayContainer::fromData(const std::vector &data) - -> ArrayContainer { - LIBRAPID_ASSERT(data.size() > 0, "Array must have at least one element"); - return ArrayContainer(data); - } + template + auto array::ArrayContainer::fromData(const std::vector &data) + -> ArrayContainer { + LIBRAPID_ASSERT(data.size() > 0, "Array must have at least one element"); + return ArrayContainer(data); + } - template - auto array::ArrayContainer::fromData( - const std::initializer_list> &data) -> ArrayContainer { - LIBRAPID_ASSERT(data.size() > 0, "Cannot create a zero-sized array"); - auto newShape = ShapeType({data.size(), data.begin()->size()}); + template + auto array::ArrayContainer::fromData( + const std::initializer_list> &data) -> ArrayContainer { + LIBRAPID_ASSERT(data.size() > 0, "Cannot create a zero-sized array"); + auto newShape = ShapeType({data.size(), data.begin()->size()}); #if defined(LIBRAPID_ENABLE_ASSERT) - for (size_t i = 0; i < data.size(); ++i) { - LIBRAPID_ASSERT(data.begin()[i].size() == newShape[1], - "Arrays must have consistent shapes"); - } + for (size_t i = 0; i < data.size(); ++i) { + LIBRAPID_ASSERT(data.begin()[i].size() == newShape[1], + "Arrays must have consistent 
shapes"); + } #endif - auto res = ArrayContainer(newShape); - int64_t index = 0; - for (const auto &item : data) res[index++] = fromData(item); - return res; - } + auto res = ArrayContainer(newShape); + int64_t index = 0; + for (const auto &item : data) res[index++] = fromData(item); + return res; + } - template - auto - array::ArrayContainer::fromData(const std::vector> &data) - -> ArrayContainer { - LIBRAPID_ASSERT(data.size() > 0, "Cannot create a zero-sized array"); - auto newShape = ShapeType({data.size(), data.begin()->size()}); + template + auto + array::ArrayContainer::fromData(const std::vector> &data) + -> ArrayContainer { + LIBRAPID_ASSERT(data.size() > 0, "Cannot create a zero-sized array"); + auto newShape = ShapeType({data.size(), data.begin()->size()}); #if defined(LIBRAPID_ENABLE_ASSERT) - for (size_t i = 0; i < data.size(); ++i) { - LIBRAPID_ASSERT(data.begin()[i].size() == newShape[1], - "Arrays must have consistent shapes"); - } + for (size_t i = 0; i < data.size(); ++i) { + LIBRAPID_ASSERT(data.begin()[i].size() == newShape[1], + "Arrays must have consistent shapes"); + } #endif - auto res = ArrayContainer(newShape); - int64_t index = 0; - for (const auto &item : data) res[index++] = fromData(item); - return res; - } + auto res = ArrayContainer(newShape); + int64_t index = 0; + for (const auto &item : data) res[index++] = fromData(item); + return res; + } - //#define HIGHER_DIMENSIONAL_FROM_DATA(TYPE) \ + //#define HIGHER_DIMENSIONAL_FROM_DATA(TYPE) \ // template \ // auto array::ArrayContainer::fromData(const TYPE &data) -> ArrayContainer { \ // LIBRAPID_ASSERT(data.size() > 0, "Cannot create a zero-sized array"); \ @@ -85,40 +85,40 @@ namespace librapid { // } #define HIGHER_DIMENSIONAL_FROM_DATA(TYPE) \ - template \ - auto array::ArrayContainer::fromData(const TYPE &data) -> ArrayContainer { \ - LIBRAPID_ASSERT(data.size() > 0, "Cannot create a zero-sized array"); \ - std::vector tmp(data.size()); \ - int64_t index = 0; \ - for (const auto 
&item : data) tmp[index++] = fromData(item); \ - auto zeroShape = tmp[0].shape(); \ - for (int64_t i = 0; i < data.size(); ++i) \ - LIBRAPID_ASSERT(tmp[i].shape().operator==(zeroShape), \ - "Arrays must have consistent shapes"); \ - auto newShape = ShapeType::zeros(zeroShape.ndim() + 1); \ - newShape[0] = data.size(); \ - for (size_t i = 0; i < zeroShape.ndim(); ++i) { newShape[i + 1] = zeroShape[i]; } \ - auto res = Array(newShape); \ - for (int64_t i = 0; i < data.size(); ++i) res[i] = tmp[i]; \ - return res; \ - } + template \ + auto array::ArrayContainer::fromData(const TYPE &data) -> ArrayContainer { \ + LIBRAPID_ASSERT(data.size() > 0, "Cannot create a zero-sized array"); \ + std::vector tmp(data.size()); \ + int64_t index = 0; \ + for (const auto &item : data) tmp[index++] = fromData(item); \ + auto zeroShape = tmp[0].shape(); \ + for (int64_t i = 0; i < data.size(); ++i) \ + LIBRAPID_ASSERT(tmp[i].shape().operator==(zeroShape), \ + "Arrays must have consistent shapes"); \ + auto newShape = ShapeType::zeros(zeroShape.ndim() + 1); \ + newShape[0] = data.size(); \ + for (size_t i = 0; i < zeroShape.ndim(); ++i) { newShape[i + 1] = zeroShape[i]; } \ + auto res = Array(newShape); \ + for (int64_t i = 0; i < data.size(); ++i) res[i] = tmp[i]; \ + return res; \ + } #define SINIT(SUB_TYPE) std::initializer_list -#define SVEC(SUB_TYPE) std::vector +#define SVEC(SUB_TYPE) std::vector - HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(Scalar)))) - HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(Scalar))))) - HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar)))))) - HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar))))))) - HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar)))))))) - HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar))))))))) + HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(Scalar)))) + HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(Scalar))))) + 
HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar)))))) + HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar))))))) + HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar)))))))) + HIGHER_DIMENSIONAL_FROM_DATA(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(SINIT(Scalar))))))))) - HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(Scalar)))) - HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(Scalar))))) - HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar)))))) - HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar))))))) - HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar)))))))) - HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar))))))))) + HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(Scalar)))) + HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(Scalar))))) + HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar)))))) + HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar))))))) + HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar)))))))) + HIGHER_DIMENSIONAL_FROM_DATA(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(SVEC(Scalar))))))))) #undef SINIT #undef HIGHER_DIMENSIONAL_FROM_DATA diff --git a/librapid/include/librapid/array/arrayIterator.hpp b/librapid/include/librapid/array/arrayIterator.hpp index 998fa1c4..0c500af8 100644 --- a/librapid/include/librapid/array/arrayIterator.hpp +++ b/librapid/include/librapid/array/arrayIterator.hpp @@ -2,101 +2,101 @@ #define LIBRAPID_ARRAY_ITERATOR_HPP namespace librapid::detail { - template - class ArrayIterator { - public: - using IndexType = int64_t; - - /// Default constructor should never be used - ArrayIterator() = delete; - - explicit ArrayIterator(const T &array); - - explicit ArrayIterator(const T &array, IndexType index); - - /// Copy an ArrayIterator object (const) - /// \param other The array to copy - ArrayIterator(const ArrayIterator &other) = default; - - /// 
Constructs an ArrayIterator from a temporary instance - /// \param other The ArrayIterator to move - ArrayIterator(ArrayIterator &&other) = default; - - /// Assigns another ArrayIterator object to this ArrayIterator. - /// \param other The ArrayIterator to assign. - /// \return A reference to this - ArrayIterator &operator=(const ArrayIterator &other) = default; - - ArrayIterator &operator++(); - bool operator==(const ArrayIterator &other) const; - bool operator!=(const ArrayIterator &other) const; - - auto operator*() const; - auto operator*(); - - ArrayIterator> begin() const noexcept; - ArrayIterator> end() const noexcept; - - ArrayIterator> begin(); - ArrayIterator> end(); - - private: - T m_array; - IndexType m_index; - }; - - template - ArrayIterator::ArrayIterator(const T &array) : m_array(array), m_index(0) {} - - template - ArrayIterator::ArrayIterator(const T &array, IndexType index) : - m_array(array), m_index(index) {} - - template - ArrayIterator &ArrayIterator::operator++() { - ++m_index; - return *this; - } - - template - bool ArrayIterator::operator==(const ArrayIterator &other) const { - return m_index == other.m_index; - } - - template - bool ArrayIterator::operator!=(const ArrayIterator &other) const { - return !(this->operator==(other)); - } - - template - auto ArrayIterator::operator*() const { - return m_array[m_index]; - } - - template - auto ArrayIterator::operator*() { - return m_array[m_index]; - } - - template - auto ArrayIterator::begin() const noexcept -> ArrayIterator> { - return ArrayIterator>(*this, 0); - } - - template - auto ArrayIterator::end() const noexcept -> ArrayIterator> { - return ArrayIterator>(*this, m_array.shape()[0]); - } - - template - auto ArrayIterator::begin() -> ArrayIterator> { - return ArrayIterator>(*this, 0); - } - - template - auto ArrayIterator::end() -> ArrayIterator> { - return ArrayIterator>(*this, m_array.shape()[0]); - } + template + class ArrayIterator { + public: + using IndexType = int64_t; + + /// 
Default constructor should never be used + ArrayIterator() = delete; + + explicit ArrayIterator(const T &array); + + explicit ArrayIterator(const T &array, IndexType index); + + /// Copy an ArrayIterator object (const) + /// \param other The array to copy + ArrayIterator(const ArrayIterator &other) = default; + + /// Constructs an ArrayIterator from a temporary instance + /// \param other The ArrayIterator to move + ArrayIterator(ArrayIterator &&other) = default; + + /// Assigns another ArrayIterator object to this ArrayIterator. + /// \param other The ArrayIterator to assign. + /// \return A reference to this + ArrayIterator &operator=(const ArrayIterator &other) = default; + + ArrayIterator &operator++(); + bool operator==(const ArrayIterator &other) const; + bool operator!=(const ArrayIterator &other) const; + + auto operator*() const; + auto operator*(); + + ArrayIterator> begin() const noexcept; + ArrayIterator> end() const noexcept; + + ArrayIterator> begin(); + ArrayIterator> end(); + + private: + T m_array; + IndexType m_index; + }; + + template + ArrayIterator::ArrayIterator(const T &array) : m_array(array), m_index(0) {} + + template + ArrayIterator::ArrayIterator(const T &array, IndexType index) : + m_array(array), m_index(index) {} + + template + ArrayIterator &ArrayIterator::operator++() { + ++m_index; + return *this; + } + + template + bool ArrayIterator::operator==(const ArrayIterator &other) const { + return m_index == other.m_index; + } + + template + bool ArrayIterator::operator!=(const ArrayIterator &other) const { + return !(this->operator==(other)); + } + + template + auto ArrayIterator::operator*() const { + return m_array[m_index]; + } + + template + auto ArrayIterator::operator*() { + return m_array[m_index]; + } + + template + auto ArrayIterator::begin() const noexcept -> ArrayIterator> { + return ArrayIterator>(*this, 0); + } + + template + auto ArrayIterator::end() const noexcept -> ArrayIterator> { + return ArrayIterator>(*this, 
m_array.shape()[0]); + } + + template + auto ArrayIterator::begin() -> ArrayIterator> { + return ArrayIterator>(*this, 0); + } + + template + auto ArrayIterator::end() -> ArrayIterator> { + return ArrayIterator>(*this, m_array.shape()[0]); + } } // namespace librapid::detail #endif // LIBRAPID_ARRAY_ITERATOR_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/arrayTypeDef.hpp b/librapid/include/librapid/array/arrayTypeDef.hpp index f3e9b14f..fd00bada 100644 --- a/librapid/include/librapid/array/arrayTypeDef.hpp +++ b/librapid/include/librapid/array/arrayTypeDef.hpp @@ -2,84 +2,141 @@ #define LIBRAPID_ARRAY_TYPE_DEF_HPP namespace librapid { - namespace detail { - template - struct TypeDefStorageEvaluator { - using Type = T; - }; + namespace detail { + template + struct TypeDefStorageEvaluator { + using Type = T; + }; - template - struct TypeDefStorageEvaluator { - using Type = Storage; - }; + template + struct TypeDefStorageEvaluator { + using Type = Storage; + }; - template - struct TypeDefStorageEvaluator { - using Type = OpenCLStorage; - }; + template + struct TypeDefStorageEvaluator { + using Type = OpenCLStorage; + }; - template - struct TypeDefStorageEvaluator { - using Type = CudaStorage; - }; - } // namespace detail + template + struct TypeDefStorageEvaluator { + using Type = CudaStorage; + }; + } // namespace detail - /// An easier to use definition than ArrayContainer. In this case, StorageType can be - /// `backend::CPU`, `backend::CUDA` or any Storage interface - /// \tparam Scalar The scalar type of the array. - /// \tparam StorageType The storage type of the array. - template - using Array = - array::ArrayContainer, - typename detail::TypeDefStorageEvaluator::Type>; + /// An easier to use definition than ArrayContainer. In this case, StorageType can be + /// `backend::CPU`, `backend::CUDA` or any Storage interface + /// \tparam Scalar The scalar type of the array. + /// \tparam StorageType The storage type of the array. 
+ template + using Array = + array::ArrayContainer, + typename detail::TypeDefStorageEvaluator::Type>; - /// A definition for fixed-size array objects. - /// \tparam Scalar The scalar type of the array. - /// \tparam Dimensions The dimensions of the array. - /// \see Array - template - using ArrayF = array::ArrayContainer, FixedStorage>; + /// A definition for fixed-size array objects. + /// \tparam Scalar The scalar type of the array. + /// \tparam Dimensions The dimensions of the array. + /// \see Array + template + using ArrayF = array::ArrayContainer, FixedStorage>; - /// A reference type for Array objects. Use this to accept Array objects as parameters since - /// the compiler cannot determine the templates tingle for the Array typedef. For more - /// granularity, you can also accept a raw ArrayContainer object. \tparam StorageType The - /// storage type of the array. \see Array \see ArrayF \see Function \see FunctionRef - template - using ArrayRef = array::ArrayContainer, StorageType>; + /// A reference type for Array objects. Use this to accept Array objects as parameters since + /// the compiler cannot determine the templates tingle for the Array typedef. For more + /// granularity, you can also accept a raw ArrayContainer object. \tparam StorageType The + /// storage type of the array. \see Array \see ArrayF \see Function \see FunctionRef + template + using ArrayRef = array::ArrayContainer, StorageType>; - /// A reference type for Array Function objects. Use this to accept Function objects as - /// parameters since the compiler cannot determine the templates for the typedef by default. - /// Additionally, this can be used to store references to Function objects. - /// \tparam Inputs The argument types to the function (template...) - /// \see Array - /// \see ArrayF - /// \see ArrayRef - /// \see Function - template - using FunctionRef = detail::Function; + /// A reference type for Array Function objects. 
Use this to accept Function objects as + /// parameters since the compiler cannot determine the templates for the typedef by default. + /// Additionally, this can be used to store references to Function objects. + /// \tparam Inputs The argument types to the function (template...) + /// \see Array + /// \see ArrayF + /// \see ArrayRef + /// \see Function + template + using FunctionRef = detail::Function; - namespace array { - /// An intermediate type to represent a slice or view of an array. - /// \tparam T The type of the array. - template - class ArrayView; + namespace array { + /// An intermediate type to represent a slice or view of an array. + /// \tparam T The type of the array. + template + class ArrayView; - template - class Transpose; - } // namespace array + template + class Transpose; + } // namespace array - namespace linalg { - template - class ArrayMultiply; - } + namespace linalg { + template + class ArrayMultiply; + } - template - using IsArrayType = std::integral_constant< - bool, (typetraits::TypeInfo::type == detail::LibRapidType::ArrayContainer) || - (typetraits::TypeInfo::type == detail::LibRapidType::ArrayView) || - (typetraits::TypeInfo::type == detail::LibRapidType::ArrayFunction)>; + template + using IsArrayType = std::integral_constant< + bool, (typetraits::TypeInfo::type == detail::LibRapidType::ArrayContainer) || + (typetraits::TypeInfo::type == detail::LibRapidType::ArrayView) || + (typetraits::TypeInfo::type == detail::LibRapidType::ArrayFunction)>; + +#define ARRAY_TYPE_FMT_IML(TEMPLATE_, TYPE_) \ + template \ + struct fmt::formatter { \ + using Type = TYPE_; \ + using Scalar = typename librapid::typetraits::TypeInfo::Scalar; \ + using Formatter = fmt::formatter; \ + Formatter m_formatter; \ + char m_bracket = 's'; \ + char m_separator = ' '; \ + \ + template \ + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { \ + /* Custom format options: */ \ + /* - "~r" for round brackets */ \ + /* - "~s" for square brackets */ \ + /* 
- "~c" for curly brackets */ \ + /* - "~a" for angle brackets */ \ + /* - "~p" for pipe brackets */ \ + /* - "-," for comma separator */ \ + /* - "-;" for semicolon separator */ \ + /* - "-:" for colon separator */ \ + /* - "-|" for pipe separator */ \ + /* - "-_" for underscore separator */ \ + \ + auto it = ctx.begin(), end = ctx.end(); \ + if (it != end && *it == '~') { \ + ++it; \ + if (it != end && \ + (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) { \ + m_bracket = *it++; \ + } \ + } \ + \ + if (it != end && *it == '-') { \ + ++it; \ + if (it != end) { m_separator = *it++; } \ + } \ + \ + ctx.advance_to(it); \ + \ + return m_formatter.parse(ctx); \ + } \ + \ + template \ + FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const \ + -> decltype(ctx.out()) { \ + val.str(m_formatter, m_bracket, m_separator, ctx); \ + return ctx.out(); \ + } \ + }; \ + \ + template \ + auto operator<<(std::ostream &os, const TYPE_ &object) -> std::ostream & { \ + os << fmt::format("{}", object); \ + return os; \ + } } // namespace librapid #endif // LIBRAPID_ARRAY_TYPE_DEF_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/arrayView.hpp b/librapid/include/librapid/array/arrayView.hpp index 88ae7783..b95eb3a5 100644 --- a/librapid/include/librapid/array/arrayView.hpp +++ b/librapid/include/librapid/array/arrayView.hpp @@ -2,333 +2,370 @@ #define LIBRAPID_ARRAY_ARRAY_VIEW_HPP namespace librapid { - namespace typetraits { - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayView; - using Scalar = typename TypeInfo>::Scalar; - using Backend = typename TypeInfo>::Backend; - static constexpr bool allowVectorisation = false; - }; - - LIBRAPID_DEFINE_AS_TYPE(typename T, array::ArrayView); - } // namespace typetraits - - namespace array { - template - class ArrayView { - public: - // using ArrayType = T; - using BaseType = typename std::decay_t; - using Scalar = typename 
typetraits::TypeInfo::Scalar; - using Reference = BaseType &; - using ConstReference = const BaseType &; - using Backend = typename typetraits::TypeInfo::Backend; - using ArrayType = Array; - using StrideType = typename ArrayType::StrideType; - using ShapeType = typename ArrayType::ShapeType; - using Iterator = detail::ArrayIterator; - - /// Default constructor should never be used - ArrayView() = delete; - - /// Copy an ArrayView object - /// \param array The array to copy - explicit ArrayView(ArrayViewType &array); - - /// Copy an ArrayView object (not const) - /// \param array The array to copy - explicit ArrayView(ArrayViewType &&array) = delete; - - /// Copy an ArrayView object (const) - /// \param other The array to copy - ArrayView(const ArrayView &other) = default; - - /// Constructs an ArrayView from a temporary instance - /// \param other The ArrayView to move - ArrayView(ArrayView &&other) = default; - - /// Assigns another ArrayView object to this ArrayView. - /// \param other The ArrayView to assign. - /// \return A reference to this - ArrayView &operator=(const ArrayView &other) = default; - - /// Assigns a temporary ArrayView to this ArrayView. - /// \param other The ArrayView to move. - /// \return A reference to this ArrayView. - // ArrayView &operator=(ArrayView &&other) noexcept = default; - - /// Assign a scalar value to this ArrayView. This function should only be used to - /// assign to a zero-dimensional "scalar" ArrayView, and will throw an error if used - /// incorrectly. - /// \param scalar The scalar value to assign - /// \return A reference to this - ArrayView &operator=(const Scalar &scalar); - - template - ArrayView &operator=(const ArrayRef &other); - - /// Access a sub-array of this ArrayView. - /// \param index The index of the sub-array. 
- /// \return An ArrayView from this - const ArrayView operator[](int64_t index) const; - - ArrayView operator[](int64_t index); - - /// Since even scalars are represented as an ArrayView object, it can be difficult to - /// operate on them directly. This allows you to extract the scalar value stored by a - /// zero-dimensional ArrayView object - /// \tparam CAST Type to cast to - /// \return The scalar represented by the ArrayView object - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE CAST get() const; - - /// Same functionality as "get", except slightly less robust for user-defined types. - /// \tparam CAST Type to cast to - /// \return The scalar represented by the ArrayView object - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator CAST() const; - - /// Access the underlying shape of this ArrayView - /// \return Shape object - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const; - - /// Access the stride of this ArrayView - /// \return Stride object - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StrideType stride() const; - - /// Access the offset of this ArrayView. This is the offset, in elements, from the - /// referenced Array's first element. - /// \return Offset - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t offset() const; - - /// Set the Shape of this ArrayView to something else. Intended for internal use only. - /// \param shape The new shape of this ArrayView - void setShape(const ShapeType &shape); - - /// Set the Stride of this ArrayView to something else. Intended for internal use only. - /// \param stride The new stride of this ArrayView - void setStride(const StrideType &stride); - - /// Set the offset of this ArrayView object. Intended for internal use only. 
- /// \param offset The new offset of this ArrayView - void setOffset(const int64_t &offset); - - /// Returns the number of dimensions of this ArrayView - /// \return Number of dimensions - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const; - - /// Return the Scalar at a given index in this ArrayView. This is intended for use - /// internally, but can be used externally too. - /// \param index The index of the Scalar to return - /// \return Scalar at the given index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalar(int64_t index) const; - - /// Evaluate the contents of this ArrayView object and return an Array instance from - /// it. Depending on your use case, this may result in more performant code, but the new - /// Array will not reference the original data in the ArrayView. - /// \return A new Array instance - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ArrayType eval() const; - - LIBRAPID_NODISCARD Iterator begin() const; - LIBRAPID_NODISCARD Iterator end() const; - - /// Cast an ArrayView to a std::string, aligning items down the columns. 
A format - /// string can also be specified, which will be used to format the items to strings - /// \param format The format string - /// \return A std::string representation of this ArrayView - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; - - template - void fmtStr(const fmt::formatter &format, char bracket, char separator, - Ctx &ctx) const; - - private: - ArrayViewType &m_ref; - ShapeType m_shape; - StrideType m_stride; - int64_t m_offset = 0; - }; - - template - ArrayView::ArrayView(ArrayViewType &array) : - m_ref(array), m_shape(array.shape()), m_stride(array.shape()) {} - - template - ArrayView &ArrayView::operator=(const Scalar &scalar) { - LIBRAPID_ASSERT(m_shape.ndim() == 0, "Cannot assign to a non-scalar ArrayView."); - m_ref.storage()[m_offset] = static_cast(scalar); - return *this; - } - - template - template - ArrayView &ArrayView::operator=(const ArrayRef &other) { - LIBRAPID_ASSERT(m_shape.operator==(other.shape()), - "Cannot assign to a non-scalar ArrayView."); - - ShapeType coord = ShapeType::zeros(m_shape.ndim()); - int64_t d = 0, p = 0; - int64_t idim = 0, adim = 0; - const int64_t ndim = m_shape.ndim(); - - do { - m_ref.storage()[p + m_offset] = other.scalar(d++); - - for (idim = 0; idim < ndim; ++idim) { - adim = ndim - idim - 1; - if (++coord[adim] == m_shape[adim]) { - coord[adim] = 0; - p = p - (m_shape[adim] - 1) * m_stride[adim]; - } else { - p = p + m_stride[adim]; - break; - } - } - } while (idim < ndim); - } - - template - auto ArrayView::operator[](int64_t index) const -> const ArrayView { - LIBRAPID_ASSERT( - index >= 0 && index < static_cast(m_shape[0]), - "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", - index, - m_shape[0]); - ArrayView view(m_ref); - const auto stride = Stride(m_shape); - view.setShape(m_shape.subshape(1, ndim())); - if (ndim() == 1) - view.setStride(Stride({1})); - else - view.setStride(stride.subshape(1, ndim())); - view.setOffset(m_offset + index 
* stride[0]); - return view; - } - - template - auto ArrayView::operator[](int64_t index) -> ArrayView { - LIBRAPID_ASSERT( - index >= 0 && index < static_cast(m_shape[0]), - "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", - index, - m_shape[0]); - ArrayView view(m_ref); - const auto stride = Stride(m_shape); - view.setShape(m_shape.subshape(1, ndim())); - if (ndim() == 1) - view.setStride(Stride({1})); - else - view.setStride(stride.subshape(1, ndim())); - view.setOffset(m_offset + index * stride[0]); - return view; - } - - template - template - CAST ArrayView::get() const { - LIBRAPID_ASSERT(m_shape.ndim() == 0, - "Can only cast a scalar ArrayView to a salar object"); - return scalar(0); - } - - template - template - ArrayView::operator CAST() const { - return get(); - } - - template - auto ArrayView::shape() const -> ShapeType { - return m_shape; - } - - template - auto ArrayView::stride() const -> StrideType { - return m_stride; - } - - template - auto ArrayView::offset() const -> int64_t { - return m_offset; - } - - template - void ArrayView::setShape(const ShapeType &shape) { - m_shape = shape; - } - - template - void ArrayView::setStride(const StrideType &stride) { - m_stride = stride; - } - - template - void ArrayView::setOffset(const int64_t &offset) { - m_offset = offset; - } - - template - auto ArrayView::ndim() const -> int64_t { - return m_shape.ndim(); - } - - template - auto ArrayView::scalar(int64_t index) const -> auto { - if (ndim() == 0) return m_ref.scalar(m_offset); - - ShapeType tmp = ShapeType::zeros(ndim()); - tmp[ndim() - 1] = index % m_shape[ndim() - 1]; - for (int64_t i = ndim() - 2; i >= 0; --i) { - index /= m_shape[i + 1]; - tmp[i] = index % m_shape[i]; - } - int64_t offset = 0; - for (int64_t i = 0; i < ndim(); ++i) { offset += tmp[i] * m_stride[i]; } - return m_ref.scalar(m_offset + offset); - } - - template - auto ArrayView::eval() const -> ArrayType { - ArrayType res(m_shape); - ShapeType coord = 
ShapeType::zeros(m_shape.ndim()); - int64_t d = 0, p = 0; - int64_t idim = 0, adim = 0; - const int64_t ndim = m_shape.ndim(); - - do { - res.storage()[d++] = m_ref.scalar(p + m_offset); - - for (idim = 0; idim < ndim; ++idim) { - adim = ndim - idim - 1; - if (++coord[adim] == m_shape[adim]) { - coord[adim] = 0; - p = p - (m_shape[adim] - 1) * m_stride[adim]; - } else { - p = p + m_stride[adim]; - break; - } - } - } while (idim < ndim); - - return res; - } - - template - auto ArrayView::begin() const -> Iterator { - return Iterator(*this, 0); - } - - template - auto ArrayView::end() const -> Iterator { - return Iterator(*this, m_shape[0]); - } - } // namespace array + namespace typetraits { + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayView; + using Scalar = typename TypeInfo>::Scalar; + using Backend = typename TypeInfo>::Backend; + static constexpr bool allowVectorisation = false; + }; + + LIBRAPID_DEFINE_AS_TYPE(typename T, array::ArrayView); + } // namespace typetraits + + namespace array { + template + class ArrayView { + public: + // using ArrayType = T; + using BaseType = typename std::decay_t; + using Scalar = typename typetraits::TypeInfo::Scalar; + using Reference = BaseType &; + using ConstReference = const BaseType &; + using Backend = typename typetraits::TypeInfo::Backend; + using ArrayType = Array; + using StrideType = typename ArrayType::StrideType; + using ShapeType = typename ArrayType::ShapeType; + using Iterator = detail::ArrayIterator; + + /// Default constructor should never be used + ArrayView() = delete; + + /// Copy an ArrayView object + /// \param array The array to copy + explicit ArrayView(ArrayViewType &array); + + /// Copy an ArrayView object (not const) + /// \param array The array to copy + explicit ArrayView(ArrayViewType &&array) = delete; + + /// Copy an ArrayView object (const) + /// \param other The array to copy + ArrayView(const ArrayView &other) = default; + + /// 
Constructs an ArrayView from a temporary instance + /// \param other The ArrayView to move + ArrayView(ArrayView &&other) = default; + + /// Assigns another ArrayView object to this ArrayView. + /// \param other The ArrayView to assign. + /// \return A reference to this + ArrayView &operator=(const ArrayView &other) = default; + + /// Assigns a temporary ArrayView to this ArrayView. + /// \param other The ArrayView to move. + /// \return A reference to this ArrayView. + // ArrayView &operator=(ArrayView &&other) noexcept = default; + + /// Assign a scalar value to this ArrayView. This function should only be used to + /// assign to a zero-dimensional "scalar" ArrayView, and will throw an error if used + /// incorrectly. + /// \param scalar The scalar value to assign + /// \return A reference to this + ArrayView &operator=(const Scalar &scalar); + + template + ArrayView &operator=(const ArrayRef &other); + + /// Access a sub-array of this ArrayView. + /// \param index The index of the sub-array. + /// \return An ArrayView from this + const ArrayView operator[](int64_t index) const; + + ArrayView operator[](int64_t index); + + /// Since even scalars are represented as an ArrayView object, it can be difficult to + /// operate on them directly. This allows you to extract the scalar value stored by a + /// zero-dimensional ArrayView object + /// \tparam CAST Type to cast to + /// \return The scalar represented by the ArrayView object + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE CAST get() const; + + /// Same functionality as "get", except slightly less robust for user-defined types. 
+ /// \tparam CAST Type to cast to + /// \return The scalar represented by the ArrayView object + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator CAST() const; + + /// Access the underlying shape of this ArrayView + /// \return Shape object + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const; + + /// Access the stride of this ArrayView + /// \return Stride object + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StrideType stride() const; + + /// Access the offset of this ArrayView. This is the offset, in elements, from the + /// referenced Array's first element. + /// \return Offset + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t offset() const; + + /// Set the Shape of this ArrayView to something else. Intended for internal use only. + /// \param shape The new shape of this ArrayView + void setShape(const ShapeType &shape); + + /// Set the Stride of this ArrayView to something else. Intended for internal use only. + /// \param stride The new stride of this ArrayView + void setStride(const StrideType &stride); + + /// Set the offset of this ArrayView object. Intended for internal use only. + /// \param offset The new offset of this ArrayView + void setOffset(const int64_t &offset); + + /// Returns the number of dimensions of this ArrayView + /// \return Number of dimensions + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const; + + /// Return the Scalar at a given index in this ArrayView. This is intended for use + /// internally, but can be used externally too. + /// \param index The index of the Scalar to return + /// \return Scalar at the given index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalar(int64_t index) const; + + /// Evaluate the contents of this ArrayView object and return an Array instance from + /// it. Depending on your use case, this may result in more performant code, but the new + /// Array will not reference the original data in the ArrayView. 
+ /// \return A new Array instance + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ArrayType eval() const; + + LIBRAPID_NODISCARD Iterator begin() const; + LIBRAPID_NODISCARD Iterator end() const; + + template + void str(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const; + + private: + ArrayViewType &m_ref; + ShapeType m_shape; + StrideType m_stride; + int64_t m_offset = 0; + }; + + template + ArrayView::ArrayView(ArrayViewType &array) : + m_ref(array), m_shape(array.shape()), m_stride(array.shape()) {} + + template + ArrayView &ArrayView::operator=(const Scalar &scalar) { + LIBRAPID_ASSERT(m_shape.ndim() == 0, "Cannot assign to a non-scalar ArrayView."); + m_ref.storage()[m_offset] = static_cast(scalar); + return *this; + } + + template + template + ArrayView &ArrayView::operator=(const ArrayRef &other) { + LIBRAPID_ASSERT(m_shape.operator==(other.shape()), + "Cannot assign to a non-scalar ArrayView."); + + ShapeType coord = ShapeType::zeros(m_shape.ndim()); + int64_t d = 0, p = 0; + int64_t idim = 0, adim = 0; + const int64_t ndim = m_shape.ndim(); + + do { + m_ref.storage()[p + m_offset] = other.scalar(d++); + + for (idim = 0; idim < ndim; ++idim) { + adim = ndim - idim - 1; + if (++coord[adim] == m_shape[adim]) { + coord[adim] = 0; + p = p - (m_shape[adim] - 1) * m_stride[adim]; + } else { + p = p + m_stride[adim]; + break; + } + } + } while (idim < ndim); + } + + template + auto ArrayView::operator[](int64_t index) const -> const ArrayView { + LIBRAPID_ASSERT( + index >= 0 && index < static_cast(m_shape[0]), + "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", + index, + m_shape[0]); + ArrayView view(m_ref); + const auto stride = Stride(m_shape); + view.setShape(m_shape.subshape(1, ndim())); + if (ndim() == 1) + view.setStride(Stride({1})); + else + view.setStride(stride.subshape(1, ndim())); + view.setOffset(m_offset + index * stride[0]); + return view; + } + + template + auto 
ArrayView::operator[](int64_t index) -> ArrayView { + LIBRAPID_ASSERT( + index >= 0 && index < static_cast(m_shape[0]), + "Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}", + index, + m_shape[0]); + ArrayView view(m_ref); + const auto stride = Stride(m_shape); + view.setShape(m_shape.subshape(1, ndim())); + if (ndim() == 1) + view.setStride(Stride({1})); + else + view.setStride(stride.subshape(1, ndim())); + view.setOffset(m_offset + index * stride[0]); + return view; + } + + template + template + CAST ArrayView::get() const { + LIBRAPID_ASSERT(m_shape.ndim() == 0, + "Can only cast a scalar ArrayView to a salar object"); + return scalar(0); + } + + template + template + ArrayView::operator CAST() const { + return get(); + } + + template + auto ArrayView::shape() const -> ShapeType { + return m_shape; + } + + template + auto ArrayView::stride() const -> StrideType { + return m_stride; + } + + template + auto ArrayView::offset() const -> int64_t { + return m_offset; + } + + template + void ArrayView::setShape(const ShapeType &shape) { + m_shape = shape; + } + + template + void ArrayView::setStride(const StrideType &stride) { + m_stride = stride; + } + + template + void ArrayView::setOffset(const int64_t &offset) { + m_offset = offset; + } + + template + auto ArrayView::ndim() const -> int64_t { + return m_shape.ndim(); + } + + template + auto ArrayView::scalar(int64_t index) const -> auto { + if (ndim() == 0) return m_ref.scalar(m_offset); + + ShapeType tmp = ShapeType::zeros(ndim()); + tmp[ndim() - 1] = index % m_shape[ndim() - 1]; + for (int64_t i = ndim() - 2; i >= 0; --i) { + index /= m_shape[i + 1]; + tmp[i] = index % m_shape[i]; + } + int64_t offset = 0; + for (int64_t i = 0; i < ndim(); ++i) { offset += tmp[i] * m_stride[i]; } + return m_ref.scalar(m_offset + offset); + } + + template + auto ArrayView::eval() const -> ArrayType { + ArrayType res(m_shape); + ShapeType coord = ShapeType::zeros(m_shape.ndim()); + int64_t d = 0, p = 
0; + int64_t idim = 0, adim = 0; + const int64_t ndim = m_shape.ndim(); + + do { + res.storage()[d++] = m_ref.scalar(p + m_offset); + + for (idim = 0; idim < ndim; ++idim) { + adim = ndim - idim - 1; + if (++coord[adim] == m_shape[adim]) { + coord[adim] = 0; + p = p - (m_shape[adim] - 1) * m_stride[adim]; + } else { + p = p + m_stride[adim]; + break; + } + } + } while (idim < ndim); + + return res; + } + + template + auto ArrayView::begin() const -> Iterator { + return Iterator(*this, 0); + } + + template + auto ArrayView::end() const -> Iterator { + return Iterator(*this, m_shape[0]); + } + } // namespace array } // namespace librapid // Support FMT printing #ifdef FMT_API -// LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::array::ArrayView) +template +struct fmt::formatter> { + using Type = librapid::array::ArrayView; + using Scalar = typename librapid::typetraits::TypeInfo::Scalar; + using Formatter = fmt::formatter; + Formatter m_formatter; + char m_bracket = 's'; + char m_separator = ' '; + + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + // Same formatting options as for the ArrayContainer type + + auto it = ctx.begin(), end = ctx.end(); + if (it != end && *it == '~') { + ++it; + if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) { + m_bracket = *it++; + } + } + + if (it != end && *it == '-') { + ++it; + if (it != end) { m_separator = *it++; } + } + + ctx.advance_to(it); + + return m_formatter.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { + val.str(m_formatter, m_bracket, m_separator, ctx); + return ctx.out(); + } +}; + +template +auto operator<<(std::ostream &os, const librapid::array::ArrayView &object) + -> std::ostream & { + os << fmt::format("{}", object); + return os; +} LIBRAPID_SIMPLE_IO_NORANGE(typename T, librapid::array::ArrayView) #endif // FMT_API diff --git 
a/librapid/include/librapid/array/arrayViewString.hpp b/librapid/include/librapid/array/arrayViewString.hpp index 3a74a3d3..488ffbca 100644 --- a/librapid/include/librapid/array/arrayViewString.hpp +++ b/librapid/include/librapid/array/arrayViewString.hpp @@ -2,80 +2,80 @@ #define LIBRAPID_ARRAY_ARRAY_VIEW_STRING_HPP namespace librapid { - namespace detail { - template - void arrayViewToString(const array::ArrayView &view, - const fmt::formatter &formatter, char bracket, - char separator, int64_t indent, Ctx &ctx) { - char bracketCharOpen, bracketCharClose; + namespace detail { + template + void arrayViewToString(const array::ArrayView &view, + const fmt::formatter &formatter, char bracket, + char separator, int64_t indent, Ctx &ctx) { + char bracketCharOpen, bracketCharClose; - switch (bracket) { - case 'r': - bracketCharOpen = '('; - bracketCharClose = ')'; - break; - case 's': - bracketCharOpen = '['; - bracketCharClose = ']'; - break; - case 'c': - bracketCharOpen = '{'; - bracketCharClose = '}'; - break; - case 'a': - bracketCharOpen = '<'; - bracketCharClose = '>'; - break; - case 'p': - bracketCharOpen = '|'; - bracketCharClose = '|'; - break; - default: - bracketCharOpen = '['; - bracketCharClose = ']'; - break; - } + switch (bracket) { + case 'r': + bracketCharOpen = '('; + bracketCharClose = ')'; + break; + case 's': + bracketCharOpen = '['; + bracketCharClose = ']'; + break; + case 'c': + bracketCharOpen = '{'; + bracketCharClose = '}'; + break; + case 'a': + bracketCharOpen = '<'; + bracketCharClose = '>'; + break; + case 'p': + bracketCharOpen = '|'; + bracketCharClose = '|'; + break; + default: + bracketCharOpen = '['; + bracketCharClose = ']'; + break; + } - // Separator char is already the correct character + // Separator char is already the correct character - if (view.ndim() == 0) { - formatter.format(view.scalar(0), ctx); - } else if (view.ndim() == 1) { - fmt::format_to(ctx.out(), "{}", bracketCharOpen); - for (int64_t i = 0; i < 
static_cast(view.shape()[0]); i++) { - formatter.format(view.scalar(i), ctx); - if (i != view.shape()[0] - 1) { - if (separator == ' ') { - fmt::format_to(ctx.out(), " "); - } else { - fmt::format_to(ctx.out(), "{} ", separator); - } - } - } - fmt::format_to(ctx.out(), "{}", bracketCharClose); - } else { - fmt::format_to(ctx.out(), "{}", bracketCharOpen); - for (int64_t i = 0; i < static_cast(view.shape()[0]); i++) { - if (i > 0) fmt::format_to(ctx.out(), "{}", std::string(indent + 1, ' ')); - arrayViewToString(view[i], formatter, bracket, separator, indent + 1, ctx); - if (i != view.shape()[0] - 1) { - fmt::format_to(ctx.out(), "{}\n", separator); - if (view.ndim() > 2) { fmt::format_to(ctx.out(), "\n"); } - } - } - fmt::format_to(ctx.out(), "{}", bracketCharClose); - } - } - } // namespace detail + if (view.ndim() == 0) { + formatter.format(view.scalar(0), ctx); + } else if (view.ndim() == 1) { + fmt::format_to(ctx.out(), "{}", bracketCharOpen); + for (int64_t i = 0; i < static_cast(view.shape()[0]); i++) { + formatter.format(view.scalar(i), ctx); + if (i != view.shape()[0] - 1) { + if (separator == ' ') { + fmt::format_to(ctx.out(), " "); + } else { + fmt::format_to(ctx.out(), "{} ", separator); + } + } + } + fmt::format_to(ctx.out(), "{}", bracketCharClose); + } else { + fmt::format_to(ctx.out(), "{}", bracketCharOpen); + for (int64_t i = 0; i < static_cast(view.shape()[0]); i++) { + if (i > 0) fmt::format_to(ctx.out(), "{}", std::string(indent + 1, ' ')); + arrayViewToString(view[i], formatter, bracket, separator, indent + 1, ctx); + if (i != view.shape()[0] - 1) { + fmt::format_to(ctx.out(), "{}\n", separator); + if (view.ndim() > 2) { fmt::format_to(ctx.out(), "\n"); } + } + } + fmt::format_to(ctx.out(), "{}", bracketCharClose); + } + } + } // namespace detail - namespace array { - template - template - void ArrayView::fmtStr(const fmt::formatter &format, char bracket, - char separator, Ctx &ctx) const { - detail::arrayViewToString(*this, format, bracket, 
separator, 0, ctx); - } - } // namespace array + namespace array { + template + template + void ArrayView::str(const fmt::formatter &format, char bracket, + char separator, Ctx &ctx) const { + detail::arrayViewToString(*this, format, bracket, separator, 0, ctx); + } + } // namespace array } // namespace librapid #endif // LIBRAPID_ARRAY_ARRAY_VIEW_STRING_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/assignOps.hpp b/librapid/include/librapid/array/assignOps.hpp index 30ca3696..1092c05f 100644 --- a/librapid/include/librapid/array/assignOps.hpp +++ b/librapid/include/librapid/array/assignOps.hpp @@ -2,528 +2,524 @@ #define LIBRAPID_ARRAY_ASSIGN_OPS_HPP namespace librapid { - // All assignment operators are forward declared in "forward.hpp" so they can be used - // elsewhere. They are defined here. - - namespace detail { - /// Trivial array assignment operator -- assignment can be done with a single vectorised - /// loop over contiguous data. - /// \tparam ShapeType_ The shape type of the array container - /// \tparam StorageScalar The scalar type of the storage object - /// \tparam Functor_ The function type - /// \tparam Args The argument types of the function - /// \param lhs The array container to assign to - /// \param function The function to assign - template>::value, - int>> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const detail::Function &function) { - using Function = detail::Function; - using Scalar = - typename array::ArrayContainer>::Scalar; - constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; - constexpr bool allowVectorisation = - typetraits::TypeInfo< - detail::Function>::allowVectorisation && - Function::argsAreSameType; - - const int64_t size = function.shape().size(); - const int64_t vectorSize = size - (size % packetWidth); - - // Ensure the function can actually be assigned to the array container - // static_assert( - // typetraits::IsSame::Scalar>, - // "Function return 
type must be the same as the array container's scalar type"); - LIBRAPID_ASSERT(lhs.shape() == function.shape(), "Shapes must be equal"); - - if constexpr (allowVectorisation) { - for (int64_t index = 0; index < vectorSize; index += packetWidth) { - lhs.writePacket(index, function.packet(index)); - } - - // Assign the remaining elements - for (int64_t index = vectorSize; index < size; ++index) { - lhs.write(index, function.scalar(index)); - } - } else { - // Assign the remaining elements - for (int64_t index = 0; index < size; ++index) { - lhs.write(index, function.scalar(index)); - } - } - } - - /// Trivial assignment with fixed-size arrays - /// \tparam ShapeType_ The shape type of the array container - /// \tparam StorageScalar The scalar type of the storage object - /// \tparam StorageSize The size of the storage object - /// \tparam Functor_ The function type - /// \tparam Args The argument types of the function - template>::value, - int>> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const detail::Function &function) { - using Function = detail::Function; - using Scalar = - typename array::ArrayContainer>::Scalar; - constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; - constexpr int64_t elements = ::librapid::product(); - constexpr int64_t vectorSize = elements - (elements % packetWidth); - constexpr bool allowVectorisation = - typetraits::TypeInfo< - detail::Function>::allowVectorisation && - Function::argsAreSameType; - - // Ensure the function can actually be assigned to the array container - static_assert( - typetraits::IsSame::Scalar>, - "Function return type must be the same as the array container's scalar type"); - LIBRAPID_ASSERT(lhs.shape() == function.shape(), "Shapes must be equal"); - - if constexpr (allowVectorisation) { - for (int64_t index = 0; index < vectorSize; index += packetWidth) { - lhs.writePacket(index, function.packet(index)); - } - - // Assign the remaining elements - for (int64_t index = 
vectorSize; index < elements; ++index) { - lhs.write(index, function.scalar(index)); - } - } else { - // Assign the remaining elements - for (int64_t index = 0; index < elements; ++index) { - lhs.write(index, function.scalar(index)); - } - } - } - - /// Trivial assignment with parallel execution - /// \tparam ShapeType_ The shape type of the array container - /// \tparam StorageScalar The scalar type of the storage object - /// \tparam Functor_ The function type - /// \tparam Args The argument types of the function - /// \param lhs The array container to assign to - /// \param function The function to assign - /// \see assign(array::ArrayContainer> - /// &lhs, const detail::Function &function) - template>::value, - int>> - LIBRAPID_ALWAYS_INLINE void assignParallel( - array::ArrayContainer> &lhs, - const detail::Function &function) { - using Function = detail::Function; - using Scalar = - typename array::ArrayContainer>::Scalar; - constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; - - constexpr bool allowVectorisation = - typetraits::TypeInfo< - detail::Function>::allowVectorisation && - Function::argsAreSameType; - - const int64_t size = function.shape().size(); - const int64_t vectorSize = size - (size % packetWidth); - - // Ensure the function can actually be assigned to the array container - // static_assert( - // typetraits::IsSame::Scalar>, - // "Function return type must be the same as the array container's scalar type"); - LIBRAPID_ASSERT(lhs.shape() == function.shape(), "Shapes must be equal"); - - if constexpr (allowVectorisation) { + // All assignment operators are forward declared in "forward.hpp" so they can be used + // elsewhere. They are defined here. + + namespace detail { + /// Trivial array assignment operator -- assignment can be done with a single vectorised + /// loop over contiguous data. 
+ /// \tparam ShapeType_ The shape type of the array container + /// \tparam StorageScalar The scalar type of the storage object + /// \tparam Functor_ The function type + /// \tparam Args The argument types of the function + /// \param lhs The array container to assign to + /// \param function The function to assign + template>::value, + int>> + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer> &lhs, + const detail::Function &function) { + using Function = detail::Function; + using Scalar = + typename array::ArrayContainer>::Scalar; + constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; + constexpr bool allowVectorisation = + typetraits::TypeInfo< + detail::Function>::allowVectorisation && + Function::argsAreSameType; + + const int64_t size = function.shape().size(); + const int64_t vectorSize = size - (size % packetWidth); + + // Ensure the function can actually be assigned to the array container + // static_assert( + // typetraits::IsSame::Scalar>, + // "Function return type must be the same as the array container's scalar type"); + LIBRAPID_ASSERT(lhs.shape() == function.shape(), "Shapes must be equal"); + + if constexpr (allowVectorisation) { + for (int64_t index = 0; index < vectorSize; index += packetWidth) { + lhs.writePacket(index, function.packet(index)); + } + + // Assign the remaining elements + for (int64_t index = vectorSize; index < size; ++index) { + lhs.write(index, function.scalar(index)); + } + } else { + // Assign the remaining elements + for (int64_t index = 0; index < size; ++index) { + lhs.write(index, function.scalar(index)); + } + } + } + + /// Trivial assignment with fixed-size arrays + /// \tparam ShapeType_ The shape type of the array container + /// \tparam StorageScalar The scalar type of the storage object + /// \tparam StorageSize The size of the storage object + /// \tparam Functor_ The function type + /// \tparam Args The argument types of the function + template>::value, + int>> + LIBRAPID_ALWAYS_INLINE 
void + assign(array::ArrayContainer> &lhs, + const detail::Function &function) { + using Function = detail::Function; + using Scalar = + typename array::ArrayContainer>::Scalar; + constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; + constexpr int64_t elements = ::librapid::product(); + constexpr int64_t vectorSize = elements - (elements % packetWidth); + constexpr bool allowVectorisation = + typetraits::TypeInfo< + detail::Function>::allowVectorisation && + Function::argsAreSameType; + + // Ensure the function can actually be assigned to the array container + static_assert( + typetraits::IsSame::Scalar>, + "Function return type must be the same as the array container's scalar type"); + LIBRAPID_ASSERT(lhs.shape() == function.shape(), "Shapes must be equal"); + + if constexpr (allowVectorisation) { + for (int64_t index = 0; index < vectorSize; index += packetWidth) { + lhs.writePacket(index, function.packet(index)); + } + + // Assign the remaining elements + for (int64_t index = vectorSize; index < elements; ++index) { + lhs.write(index, function.scalar(index)); + } + } else { + // Assign the remaining elements + for (int64_t index = 0; index < elements; ++index) { + lhs.write(index, function.scalar(index)); + } + } + } + + /// Trivial assignment with parallel execution + /// \tparam ShapeType_ The shape type of the array container + /// \tparam StorageScalar The scalar type of the storage object + /// \tparam Functor_ The function type + /// \tparam Args The argument types of the function + /// \param lhs The array container to assign to + /// \param function The function to assign + /// \see assign(array::ArrayContainer> + /// &lhs, const detail::Function &function) + template>::value, + int>> + LIBRAPID_ALWAYS_INLINE void + assignParallel(array::ArrayContainer> &lhs, + const detail::Function &function) { + using Function = detail::Function; + using Scalar = + typename array::ArrayContainer>::Scalar; + constexpr int64_t packetWidth = 
typetraits::TypeInfo::packetWidth; + + constexpr bool allowVectorisation = + typetraits::TypeInfo< + detail::Function>::allowVectorisation && + Function::argsAreSameType; + + const int64_t size = function.shape().size(); + const int64_t vectorSize = size - (size % packetWidth); + + // Ensure the function can actually be assigned to the array container + // static_assert( + // typetraits::IsSame::Scalar>, + // "Function return type must be the same as the array container's scalar type"); + LIBRAPID_ASSERT(lhs.shape() == function.shape(), "Shapes must be equal"); + + if constexpr (allowVectorisation) { #pragma omp parallel for shared(vectorSize, lhs, function) default(none) \ num_threads(int(global::numThreads)) - for (int64_t index = 0; index < vectorSize; index += packetWidth) { - lhs.writePacket(index, function.packet(index)); - } - - // Assign the remaining elements - for (int64_t index = vectorSize; index < size; ++index) { - lhs.write(index, function.scalar(index)); - } - } else { + for (int64_t index = 0; index < vectorSize; index += packetWidth) { + lhs.writePacket(index, function.packet(index)); + } + + // Assign the remaining elements + for (int64_t index = vectorSize; index < size; ++index) { + lhs.write(index, function.scalar(index)); + } + } else { #pragma omp parallel for shared(vectorSize, lhs, function, size) default(none) \ num_threads(int(global::numThreads)) - for (int64_t index = 0; index < size; ++index) { - lhs.write(index, function.scalar(index)); - } - } - } - - /// Trivial assignment with fixed-size arrays and parallel execution - /// \tparam ShapeType_ The shape type of the array container - /// \tparam StorageScalar The scalar type of the storage object - /// \tparam StorageSize The size of the storage object - /// \tparam Functor_ The function type - /// \tparam Args The argument types of the function - template>::value, - int>> - LIBRAPID_ALWAYS_INLINE void assignParallel( - array::ArrayContainer> &lhs, - const detail::Function &function) 
{ - using Function = detail::Function; - using Scalar = - typename array::ArrayContainer>::Scalar; - constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; - - constexpr bool allowVectorisation = - typetraits::TypeInfo< - detail::Function>::allowVectorisation && - Function::argsAreSameType; - - constexpr int64_t size = ::librapid::product(); - constexpr int64_t vectorSize = size - (size % packetWidth); - - // Ensure the function can actually be assigned to the array container - static_assert( - typetraits::IsSame::Scalar>, - "Function return type must be the same as the array container's scalar type"); - LIBRAPID_ASSERT(lhs.shape() == function.shape(), "Shapes must be equal"); - - if constexpr (allowVectorisation) { + for (int64_t index = 0; index < size; ++index) { + lhs.write(index, function.scalar(index)); + } + } + } + + /// Trivial assignment with fixed-size arrays and parallel execution + /// \tparam ShapeType_ The shape type of the array container + /// \tparam StorageScalar The scalar type of the storage object + /// \tparam StorageSize The size of the storage object + /// \tparam Functor_ The function type + /// \tparam Args The argument types of the function + template>::value, + int>> + LIBRAPID_ALWAYS_INLINE void assignParallel( + array::ArrayContainer> &lhs, + const detail::Function &function) { + using Function = detail::Function; + using Scalar = + typename array::ArrayContainer>::Scalar; + constexpr int64_t packetWidth = typetraits::TypeInfo::packetWidth; + + constexpr bool allowVectorisation = + typetraits::TypeInfo< + detail::Function>::allowVectorisation && + Function::argsAreSameType; + + constexpr int64_t size = ::librapid::product(); + constexpr int64_t vectorSize = size - (size % packetWidth); + + // Ensure the function can actually be assigned to the array container + static_assert( + typetraits::IsSame::Scalar>, + "Function return type must be the same as the array container's scalar type"); + LIBRAPID_ASSERT(lhs.shape() == 
function.shape(), "Shapes must be equal"); + + if constexpr (allowVectorisation) { #pragma omp parallel for shared(vectorSize, lhs, function) default(none) \ num_threads(int(global::numThreads)) - for (int64_t index = 0; index < vectorSize; index += packetWidth) { - lhs.writePacket(index, function.packet(index)); - } - - // Assign the remaining elements - for (int64_t index = vectorSize; index < size; ++index) { - lhs.write(index, function.scalar(index)); - } - } else { + for (int64_t index = 0; index < vectorSize; index += packetWidth) { + lhs.writePacket(index, function.packet(index)); + } + + // Assign the remaining elements + for (int64_t index = vectorSize; index < size; ++index) { + lhs.write(index, function.scalar(index)); + } + } else { #pragma omp parallel for shared(vectorSize, lhs, function, size) default(none) \ num_threads(int(global::numThreads)) - for (int64_t index = vectorSize; index < size; ++index) { - lhs.write(index, function.scalar(index)); - } - } - } - } // namespace detail - - /* - * Since we cannot (reasonably) generate the kernels at runtime (ease of development, - * performance, etc.), operations such as (a + b) + c cannot be made into a singe kernel. - * Therefore, we must employ a recursive evaluator to evaluate the expression tree. - * - * Unfortunately, this is surprisingly difficult to do with the setup used by the CPU side of - * things. - * - * We can approach this problem as follows: - * 1. Create a templated function to call the kernel - * 2. Create a function with two specialisations - * - One for an array::ArrayContainer of some kind (this is the base case) - * - One for an Expression (this is the recursive case) - * - The base case returns the array::ArrayContainer's storage object - * - The recursive case returns the result of calling the templated function with the - * Expression's left and right children - * 3. 
Call the templated function with the result of the recursive function - * - * This will be slower than a single kernel call, but it saves us from having to generate one - * each time, improving performance in the long run (hopefully). - */ + for (int64_t index = vectorSize; index < size; ++index) { + lhs.write(index, function.scalar(index)); + } + } + } + } // namespace detail + + /* + * Since we cannot (reasonably) generate the kernels at runtime (ease of development, + * performance, etc.), operations such as (a + b) + c cannot be made into a singe kernel. + * Therefore, we must employ a recursive evaluator to evaluate the expression tree. + * + * Unfortunately, this is surprisingly difficult to do with the setup used by the CPU side of + * things. + * + * We can approach this problem as follows: + * 1. Create a templated function to call the kernel + * 2. Create a function with two specialisations + * - One for an array::ArrayContainer of some kind (this is the base case) + * - One for an Expression (this is the recursive case) + * - The base case returns the array::ArrayContainer's storage object + * - The recursive case returns the result of calling the templated function with the + * Expression's left and right children + * 3. Call the templated function with the result of the recursive function + * + * This will be slower than a single kernel call, but it saves us from having to generate one + * each time, improving performance in the long run (hopefully). 
+ */ #if defined(LIBRAPID_HAS_OPENCL) - namespace opencl { - template::type != - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dataSourceExtractor(const T &obj) { - return obj.storage().data(); - } - - template::type == - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dataSourceExtractor(const T &obj) { - return obj; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto & - openCLTupleEvaluatorImpl(const T &scalar) { - return scalar; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const - array::ArrayContainer> & - openCLTupleEvaluatorImpl( - const array::ArrayContainer> &container) { - return container; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - openCLTupleEvaluatorImpl(const detail::Function &function) { - array::ArrayContainer< - decltype(function.shape()), - OpenCLStorage::Scalar>> - result(function.shape()); - assign(result, function); - return result; - } - - template - LIBRAPID_ALWAYS_INLINE void - openCLTupleEvaluator(std::index_sequence, const std::string &kernelBase, - cl::Buffer &dst, - const detail::Function &function) { - using Scalar = typename detail::Function::Scalar; - runLinearKernel( - kernelBase, - function.shape().size(), - dst, - dataSourceExtractor(openCLTupleEvaluatorImpl(std::get(function.args())))...); - } - } // namespace opencl - - namespace detail { - template>::value, - int>> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const detail::Function &function) { - // Unfortunately, as we are not generating the kernels at runtime, we can't use - // temporary-free evaluation. 
Instead, we must recursively evaluate each sub-operation - // until a final result is computed - - constexpr const char *filename = typetraits::TypeInfo::filename; - const char *kernelBase = typetraits::TypeInfo::getKernelName(function.args()); - using Scalar = - typename array::ArrayContainer>::Scalar; - - const auto args = function.args(); - constexpr size_t argSize = std::tuple_size::value; - ::librapid::opencl::openCLTupleEvaluator( - std::make_index_sequence(), kernelBase, lhs.storage().data(), function); - } - } // namespace detail + namespace opencl { + template::type != + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dataSourceExtractor(const T &obj) { + return obj.storage().data(); + } + + template::type == + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dataSourceExtractor(const T &obj) { + return obj; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto & + openCLTupleEvaluatorImpl(const T &scalar) { + return scalar; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const + array::ArrayContainer> & + openCLTupleEvaluatorImpl( + const array::ArrayContainer> &container) { + return container; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + openCLTupleEvaluatorImpl(const detail::Function &function) { + array::ArrayContainer< + decltype(function.shape()), + OpenCLStorage::Scalar>> + result(function.shape()); + assign(result, function); + return result; + } + + template + LIBRAPID_ALWAYS_INLINE void + openCLTupleEvaluator(std::index_sequence, const std::string &kernelBase, + cl::Buffer &dst, + const detail::Function &function) { + using Scalar = typename detail::Function::Scalar; + runLinearKernel( + kernelBase, + function.shape().size(), + dst, + dataSourceExtractor(openCLTupleEvaluatorImpl(std::get(function.args())))...); + } + } // namespace opencl + + namespace detail { + template>::value, + 
int>> + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer> &lhs, + const detail::Function &function) { + // Unfortunately, as we are not generating the kernels at runtime, we can't use + // temporary-free evaluation. Instead, we must recursively evaluate each sub-operation + // until a final result is computed + + constexpr const char *filename = typetraits::TypeInfo::filename; + const char *kernelBase = typetraits::TypeInfo::getKernelName(function.args()); + using Scalar = + typename array::ArrayContainer>::Scalar; + + const auto args = function.args(); + constexpr size_t argSize = std::tuple_size::value; + ::librapid::opencl::openCLTupleEvaluator( + std::make_index_sequence(), kernelBase, lhs.storage().data(), function); + } + } // namespace detail #endif // LIBRAPID_HAS_CUDA #if defined(LIBRAPID_HAS_CUDA) - namespace cuda { - template::type != - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dataSourceExtractor(const T &obj) { - return obj.storage().begin().get(); - } - - template::type == - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto &dataSourceExtractor(const T &obj) { - return obj; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto & - cudaTupleEvaluatorImpl(const T &scalar) { - return scalar; - } - - /// Helper for "evaluating" an array::ArrayContainer - /// \tparam ShapeType The shape type of the array::ArrayContainer - /// \tparam StorageScalar The scalar type of the array::ArrayContainer's storage object - /// \param container The array::ArrayContainer to evaluate - /// \return The array::ArrayContainer itself - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const - array::ArrayContainer> & - cudaTupleEvaluatorImpl( - const array::ArrayContainer> &container) { - return container; - } - - /// Helper for evaluating an expression - /// \tparam descriptor The descriptor of the expression - /// \tparam 
Functor The function type of the expression - /// \tparam Args The argument types of the expression - /// \param function The expression to evaluate - /// \return The result of evaluating the expression - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - cudaTupleEvaluatorImpl(const detail::Function &function) { - array::ArrayContainer< - decltype(function.shape()), - CudaStorage::Scalar>> - result(function.shape()); - assign(result, function); - return result; - } - - template - struct CudaVectorExtractor { - static constexpr auto tester() { - using ScalarType = typename typetraits::TypeInfo>::Scalar; - constexpr bool allowVectorisation = typetraits::TypeInfo::allowVectorisation; - - if constexpr (std::is_same_v && allowVectorisation) { - return CUDA_FLOAT_VECTOR_TYPE {}; - } else if constexpr (std::is_same_v && allowVectorisation) { - return CUDA_DOUBLE_VECTOR_TYPE {}; - } else { - return ScalarType {}; - } - } - - using Scalar = decltype(tester()); - }; - - template - struct CudaCanVectorise { - public: - template - static constexpr auto extractFirst() { - return First {}; - } - - template - static constexpr bool edgeCases() { - if constexpr (typetraits::TypeInfo::type == detail::LibRapidType::Dual) { - return false; - } else { - return true; - } - } - - static constexpr bool supportsVectorisation = - (typetraits::TypeInfo>::allowVectorisation && ...); - using First = decltype(extractFirst()); - static constexpr bool dtypesAreSame = - (std::is_same_v>::Scalar, - typename typetraits::TypeInfo>::Scalar> && - ...); - static constexpr bool dtypeSupportsVectorisation = - ((typetraits::TypeInfo>::Scalar>::Scalar>::cudaPacketWidth > 1) && - ...); - static constexpr bool fitsEdgeCases = - (edgeCases>::Scalar>() && ...); - - public: - static constexpr bool value = - supportsVectorisation && dtypesAreSame && dtypeSupportsVectorisation && fitsEdgeCases; - }; - - template - struct CudaVectorHelper { - static constexpr bool canVectorise = CudaCanVectorise::value; - 
- template - static constexpr auto extractor() { - using Scalar = typename typetraits::TypeInfo>::Scalar; - if constexpr (typetraits::TypeInfo>::type == - ::librapid::detail::LibRapidType::Scalar) { - return Scalar {}; - } else if constexpr (canVectorise) { - using Type = typename CudaVectorExtractor::Scalar; - return Type {}; - } else { - return Scalar {}; - } - } - }; - - /// Helper for evaluating a tuple - /// \tparam descriptor The descriptor of the Function - /// \tparam Functor The function type of the Function - /// \tparam Args The argument types of the Function - /// \tparam Pointer The pointer type of the destination - /// \tparam I Index sequence for the tuple - /// \param filename The filename of the kernel - /// \param kernelName The name of the kernel - /// \param dst The memory location to assign data to - /// \param function The Function to evaluate - template - LIBRAPID_ALWAYS_INLINE void - cudaTupleEvaluator(std::index_sequence, const std::string &filename, - const std::string &kernelName, Pointer *dst, - const detail::Function &function) { - // I'm not convinced the logic here is infallible, but it does seem to work. - // It is possible that you could use the Pointer type directly to extract the - // packet width, but I'm not sure if that would work for all cases. 
- - using Function = detail::Function; - using Scalar = typename typetraits::TypeInfo::Scalar; - using Helper = CudaVectorHelper; - using PacketType = decltype(Helper::template extractor()); - constexpr int64_t cudaPacketWidth = typetraits::TypeInfo::cudaPacketWidth; - - runKernel::template extractor())...>( - filename, - kernelName, - (function.shape().size() + (cudaPacketWidth - 1)) / cudaPacketWidth, // Round up - (function.shape().size() + (cudaPacketWidth - 1)) / cudaPacketWidth, - dst, - dataSourceExtractor(cudaTupleEvaluatorImpl(std::get(function.args())))...); - } - } // namespace cuda - - namespace detail { - /// Trivial assignment with CUDA execution - /// \tparam ShapeType_ The shape type of the array container - /// \tparam StorageScalar The scalar type of the storage object - /// \tparam Functor_ The function type - /// \tparam Args The argument types of the function - /// \param lhs The array container to assign to - /// \param function The function to assign - template>::value, - int>> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const detail::Function &function) { - // Unfortunately, as we are not generating the kernels at runtime, we can't use - // temporary-free evaluation. 
Instead, we must recursively evaluate each sub-operation - // until a final result is computed - - using Function = detail::Function; - constexpr const char *filename = typetraits::TypeInfo::filename; - const char *kernelName = typetraits::TypeInfo::getKernelName(function.args()); - - using DstType = array::ArrayContainer>; - using Scalar = - decltype(::librapid::cuda::CudaVectorHelper::template extractor()); - - const auto args = function.args(); - constexpr size_t argSize = std::tuple_size::value; - ::librapid::cuda::cudaTupleEvaluator( - std::make_index_sequence(), - filename, - kernelName, - reinterpret_cast(lhs.storage().begin().get()), - function); - } - } // namespace detail + namespace cuda { + template::type != + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dataSourceExtractor(const T &obj) { + return obj.storage().begin().get(); + } + + template::type == + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto &dataSourceExtractor(const T &obj) { + return obj; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto & + cudaTupleEvaluatorImpl(const T &scalar) { + return scalar; + } + + /// Helper for "evaluating" an array::ArrayContainer + /// \tparam ShapeType The shape type of the array::ArrayContainer + /// \tparam StorageScalar The scalar type of the array::ArrayContainer's storage object + /// \param container The array::ArrayContainer to evaluate + /// \return The array::ArrayContainer itself + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const + array::ArrayContainer> & + cudaTupleEvaluatorImpl( + const array::ArrayContainer> &container) { + return container; + } + + /// Helper for evaluating an expression + /// \tparam descriptor The descriptor of the expression + /// \tparam Functor The function type of the expression + /// \tparam Args The argument types of the expression + /// \param function The expression to 
evaluate + /// \return The result of evaluating the expression + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + cudaTupleEvaluatorImpl(const detail::Function &function) { + array::ArrayContainer< + decltype(function.shape()), + CudaStorage::Scalar>> + result(function.shape()); + assign(result, function); + return result; + } + + template + struct CudaVectorExtractor { + static constexpr auto tester() { + using ScalarType = typename typetraits::TypeInfo>::Scalar; + constexpr bool allowVectorisation = typetraits::TypeInfo::allowVectorisation; + + if constexpr (std::is_same_v && allowVectorisation) { + return CUDA_FLOAT_VECTOR_TYPE {}; + } else if constexpr (std::is_same_v && allowVectorisation) { + return CUDA_DOUBLE_VECTOR_TYPE {}; + } else { + return ScalarType {}; + } + } + + using Scalar = decltype(tester()); + }; + + template + struct CudaCanVectorise { + public: + template + static constexpr auto extractFirst() { + return First {}; + } + + template + static constexpr bool edgeCases() { + if constexpr (typetraits::TypeInfo::type == detail::LibRapidType::Dual) { + return false; + } else { + return true; + } + } + + static constexpr bool supportsVectorisation = + (typetraits::TypeInfo>::allowVectorisation && ...); + using First = decltype(extractFirst()); + static constexpr bool dtypesAreSame = + (std::is_same_v>::Scalar, + typename typetraits::TypeInfo>::Scalar> && + ...); + static constexpr bool dtypeSupportsVectorisation = + ((typetraits::TypeInfo>::Scalar>::Scalar>::cudaPacketWidth > 1) && + ...); + static constexpr bool fitsEdgeCases = + (edgeCases>::Scalar>() && ...); + + public: + static constexpr bool value = + supportsVectorisation && dtypesAreSame && dtypeSupportsVectorisation && fitsEdgeCases; + }; + + template + struct CudaVectorHelper { + static constexpr bool canVectorise = CudaCanVectorise::value; + + template + static constexpr auto extractor() { + using Scalar = typename typetraits::TypeInfo>::Scalar; + if constexpr 
(typetraits::TypeInfo>::type == + ::librapid::detail::LibRapidType::Scalar) { + return Scalar {}; + } else if constexpr (canVectorise) { + using Type = typename CudaVectorExtractor::Scalar; + return Type {}; + } else { + return Scalar {}; + } + } + }; + + /// Helper for evaluating a tuple + /// \tparam descriptor The descriptor of the Function + /// \tparam Functor The function type of the Function + /// \tparam Args The argument types of the Function + /// \tparam Pointer The pointer type of the destination + /// \tparam I Index sequence for the tuple + /// \param filename The filename of the kernel + /// \param kernelName The name of the kernel + /// \param dst The memory location to assign data to + /// \param function The Function to evaluate + template + LIBRAPID_ALWAYS_INLINE void + cudaTupleEvaluator(std::index_sequence, const std::string &filename, + const std::string &kernelName, Pointer *dst, + const detail::Function &function) { + // I'm not convinced the logic here is infallible, but it does seem to work. + // It is possible that you could use the Pointer type directly to extract the + // packet width, but I'm not sure if that would work for all cases. 
+ + using Function = detail::Function; + using Scalar = typename typetraits::TypeInfo::Scalar; + using Helper = CudaVectorHelper; + using PacketType = decltype(Helper::template extractor()); + constexpr int64_t cudaPacketWidth = typetraits::TypeInfo::cudaPacketWidth; + + runKernel::template extractor())...>( + filename, + kernelName, + (function.shape().size() + (cudaPacketWidth - 1)) / cudaPacketWidth, // Round up + (function.shape().size() + (cudaPacketWidth - 1)) / cudaPacketWidth, + dst, + dataSourceExtractor(cudaTupleEvaluatorImpl(std::get(function.args())))...); + } + } // namespace cuda + + namespace detail { + /// Trivial assignment with CUDA execution + /// \tparam ShapeType_ The shape type of the array container + /// \tparam StorageScalar The scalar type of the storage object + /// \tparam Functor_ The function type + /// \tparam Args The argument types of the function + /// \param lhs The array container to assign to + /// \param function The function to assign + template>::value, + int>> + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer> &lhs, + const detail::Function &function) { + // Unfortunately, as we are not generating the kernels at runtime, we can't use + // temporary-free evaluation. 
Instead, we must recursively evaluate each sub-operation + // until a final result is computed + + using Function = detail::Function; + constexpr const char *filename = typetraits::TypeInfo::filename; + const char *kernelName = typetraits::TypeInfo::getKernelName(function.args()); + + using DstType = array::ArrayContainer>; + using Scalar = + decltype(::librapid::cuda::CudaVectorHelper::template extractor()); + + const auto args = function.args(); + constexpr size_t argSize = std::tuple_size::value; + ::librapid::cuda::cudaTupleEvaluator( + std::make_index_sequence(), + filename, + kernelName, + reinterpret_cast(lhs.storage().begin().get()), + function); + } + } // namespace detail #endif // LIBRAPID_HAS_CUDA } // namespace librapid diff --git a/librapid/include/librapid/array/commaInitializer.hpp b/librapid/include/librapid/array/commaInitializer.hpp index 222fde0f..8fcede35 100644 --- a/librapid/include/librapid/array/commaInitializer.hpp +++ b/librapid/include/librapid/array/commaInitializer.hpp @@ -2,47 +2,47 @@ #define LIBRAPID_ARRAY_COMMA_INITIALIZER_HPP namespace librapid::detail { - /// Allows for an Array object to be initialized with a comma separated list of values. While - /// this is not particularly useful for large arrays, it is a very quick and easy way to - /// initialize smaller arrays with a few values. - /// \tparam ArrT The type of the Array object to be initialized. - template - class CommaInitializer { - public: - /// The scalar type of the Array object. - using Scalar = typename typetraits::TypeInfo::Scalar; - - CommaInitializer() = delete; - - /// Construct a CommaInitializer from an Array object. - /// \param dst The Array object to initialize. - /// \param val The first value to initialize the Array object with. - template - explicit CommaInitializer(ArrT &dst, const T &val) : m_array(dst) { - next(static_cast(val)); - } - - /// Initialize the next element of the Array object. 
- template - CommaInitializer &operator,(const T &val) { - next(static_cast(val)); - return *this; - } - - private: - /// Initialize the current element of the Array and increment the index. - /// \param other The value to initialize the current element with. - void next(const Scalar &other) { - m_array.storage()[m_index] = other; - ++m_index; - } - - /// The Array object to initialize. - ArrT &m_array; - - /// The current index of the Array object. - int64_t m_index = 0; - }; + /// Allows for an Array object to be initialized with a comma separated list of values. While + /// this is not particularly useful for large arrays, it is a very quick and easy way to + /// initialize smaller arrays with a few values. + /// \tparam ArrT The type of the Array object to be initialized. + template + class CommaInitializer { + public: + /// The scalar type of the Array object. + using Scalar = typename typetraits::TypeInfo::Scalar; + + CommaInitializer() = delete; + + /// Construct a CommaInitializer from an Array object. + /// \param dst The Array object to initialize. + /// \param val The first value to initialize the Array object with. + template + explicit CommaInitializer(ArrT &dst, const T &val) : m_array(dst) { + next(static_cast(val)); + } + + /// Initialize the next element of the Array object. + template + CommaInitializer &operator,(const T &val) { + next(static_cast(val)); + return *this; + } + + private: + /// Initialize the current element of the Array and increment the index. + /// \param other The value to initialize the current element with. + void next(const Scalar &other) { + m_array.storage()[m_index] = other; + ++m_index; + } + + /// The Array object to initialize. + ArrT &m_array; + + /// The current index of the Array object. 
+ int64_t m_index = 0; + }; } // namespace librapid::detail #endif // LIBRAPID_ARRAY_COMMA_INITIALIZER_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/fill.hpp b/librapid/include/librapid/array/fill.hpp index a33f1285..2b280c50 100644 --- a/librapid/include/librapid/array/fill.hpp +++ b/librapid/include/librapid/array/fill.hpp @@ -2,171 +2,171 @@ #define LIBRAPID_ARRAY_FILL_HPP namespace librapid { - template - LIBRAPID_ALWAYS_INLINE void fill(array::ArrayContainer &dst, - const Scalar &value) { - dst = array::ArrayContainer(dst.shape(), value); - } - - template - LIBRAPID_ALWAYS_INLINE void - fillRandom(array::ArrayContainer> &dst, const Lower &lower, - const Upper &upper) { - ShapeType shape = dst.shape(); - auto *data = dst.storage().begin(); - bool parallel = global::numThreads != 1 && shape.size() > global::multithreadThreshold; - - if (parallel) { + template + LIBRAPID_ALWAYS_INLINE void fill(array::ArrayContainer &dst, + const Scalar &value) { + dst = array::ArrayContainer(dst.shape(), value); + } + + template + LIBRAPID_ALWAYS_INLINE void + fillRandom(array::ArrayContainer> &dst, const Lower &lower, + const Upper &upper) { + ShapeType shape = dst.shape(); + auto *data = dst.storage().begin(); + bool parallel = global::numThreads != 1 && shape.size() > global::multithreadThreshold; + + if (parallel) { #pragma omp parallel for - for (int64_t i = 0; i < shape.size(); ++i) { - data[i] = random(static_cast(lower), - static_cast(upper)); - } - } else { - for (int64_t i = 0; i < shape.size(); ++i) { - data[i] = random(static_cast(lower), - static_cast(upper)); - } - } - } - - template - LIBRAPID_ALWAYS_INLINE void - fillRandomGaussian(array::ArrayContainer> &dst, - const Lower &lower, const Upper &upper) { - ShapeType shape = dst.shape(); - auto *data = dst.storage().begin(); - bool parallel = global::numThreads != 1 && shape.size() > global::multithreadThreshold; - - if (parallel) { + for (int64_t i = 0; i < shape.size(); ++i) { + 
data[i] = random(static_cast(lower), + static_cast(upper)); + } + } else { + for (int64_t i = 0; i < shape.size(); ++i) { + data[i] = random(static_cast(lower), + static_cast(upper)); + } + } + } + + template + LIBRAPID_ALWAYS_INLINE void + fillRandomGaussian(array::ArrayContainer> &dst, + const Lower &lower, const Upper &upper) { + ShapeType shape = dst.shape(); + auto *data = dst.storage().begin(); + bool parallel = global::numThreads != 1 && shape.size() > global::multithreadThreshold; + + if (parallel) { #pragma omp parallel for - for (int64_t i = 0; i < shape.size(); ++i) { - data[i] = randomGaussian(); - } - } else { - for (int64_t i = 0; i < shape.size(); ++i) { - data[i] = randomGaussian(); - } - } - } + for (int64_t i = 0; i < shape.size(); ++i) { + data[i] = randomGaussian(); + } + } else { + for (int64_t i = 0; i < shape.size(); ++i) { + data[i] = randomGaussian(); + } + } + } #if defined(LIBRAPID_HAS_OPENCL) - template - LIBRAPID_ALWAYS_INLINE void - fillRandom(array::ArrayContainer> &dst, - const Lower &lower, const Upper &upper) { - ShapeType shape = dst.shape(); - int64_t elements = shape.size(); - - // Initialize a buffer of random seeds - static int64_t numSeeds = 1024; - static bool initialized = false; - static Array seeds(Shape {numSeeds}); - if (global::reseed || !initialized) { - for (int64_t i = 0; i < numSeeds; ++i) { seeds(i) = randint(0, INT64_MAX); } - initialized = true; - - // reseed is controlled by the random module, so we don't need to worry about it here - } - - // Run the kernel - opencl::runLinearKernel("fillRandom", - elements, - dst.storage().data(), - elements, - static_cast(lower), - static_cast(upper), - seeds.storage().data(), - numSeeds); - } + template + LIBRAPID_ALWAYS_INLINE void + fillRandom(array::ArrayContainer> &dst, + const Lower &lower, const Upper &upper) { + ShapeType shape = dst.shape(); + int64_t elements = shape.size(); + + // Initialize a buffer of random seeds + static int64_t numSeeds = 1024; + static bool 
initialized = false; + static Array seeds(Shape {numSeeds}); + if (global::reseed || !initialized) { + for (int64_t i = 0; i < numSeeds; ++i) { seeds(i) = randint(0, INT64_MAX); } + initialized = true; + + // reseed is controlled by the random module, so we don't need to worry about it here + } + + // Run the kernel + opencl::runLinearKernel("fillRandom", + elements, + dst.storage().data(), + elements, + static_cast(lower), + static_cast(upper), + seeds.storage().data(), + numSeeds); + } #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - template - LIBRAPID_ALWAYS_INLINE void - fillRandom(array::ArrayContainer> &dst, - const Lower &lower, const Upper &upper) { - ShapeType shape = dst.shape(); - int64_t elements = shape.size(); - - // Initialize a buffer of random seeds - static int64_t numSeeds = 1024; - static bool initialized = false; - static Array seeds(Shape {numSeeds}); - - if (global::reseed || !initialized) { - for (int64_t i = 0; i < numSeeds; ++i) { seeds(i) = randint(0, INT64_MAX); } - initialized = true; - - // reseed is controlled by the random module, so we don't need to worry about it here - } - - cuda::runKernel( - "fill", - std::is_same_v ? 
"fillRandomHalf" : "fillRandom", - elements, - dst.storage().data().get(), - elements, - static_cast(lower), - static_cast(upper), - seeds.storage().data().get(), - numSeeds); - } - - template - LIBRAPID_ALWAYS_INLINE void - fillRandom(array::ArrayContainer> &dst, const Lower &lower, - const Upper &upper) { - ShapeType shape = dst.shape(); - int64_t elements = shape.size(); - - // Create a pseudo-random number generator - static curandGenerator_t prng; - static bool initialized = false; - - if (!initialized) { - curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_DEFAULT); - curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); - initialized = true; - } - - if (global::reseed) { curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); } - - // Run the kernel - curandGenerateUniform(prng, dst.storage().data().get(), elements); - - // Scale the result to the desired range - dst = dst * (upper - lower) + lower; - } - - template - LIBRAPID_ALWAYS_INLINE void - fillRandom(array::ArrayContainer> &dst, const Lower &lower, - const Upper &upper) { - ShapeType shape = dst.shape(); - int64_t elements = shape.size(); - - // Create a pseudo-random number generator - static curandGenerator_t prng; - static bool initialized = false; - - if (!initialized) { - curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_DEFAULT); - curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); - initialized = true; - } - - if (global::reseed) { curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); } - - // Run the kernel - curandGenerateUniformDouble(prng, dst.storage().data().get(), elements); - - // Scale the result to the desired range - dst = dst * (upper - lower) + lower; - } + template + LIBRAPID_ALWAYS_INLINE void + fillRandom(array::ArrayContainer> &dst, + const Lower &lower, const Upper &upper) { + ShapeType shape = dst.shape(); + int64_t elements = shape.size(); + + // Initialize a buffer of random seeds + static int64_t numSeeds = 1024; + static bool initialized = 
false; + static Array seeds(Shape {numSeeds}); + + if (global::reseed || !initialized) { + for (int64_t i = 0; i < numSeeds; ++i) { seeds(i) = randint(0, INT64_MAX); } + initialized = true; + + // reseed is controlled by the random module, so we don't need to worry about it here + } + + cuda::runKernel( + "fill", + std::is_same_v ? "fillRandomHalf" : "fillRandom", + elements, + dst.storage().data().get(), + elements, + static_cast(lower), + static_cast(upper), + seeds.storage().data().get(), + numSeeds); + } + + template + LIBRAPID_ALWAYS_INLINE void + fillRandom(array::ArrayContainer> &dst, const Lower &lower, + const Upper &upper) { + ShapeType shape = dst.shape(); + int64_t elements = shape.size(); + + // Create a pseudo-random number generator + static curandGenerator_t prng; + static bool initialized = false; + + if (!initialized) { + curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); + initialized = true; + } + + if (global::reseed) { curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); } + + // Run the kernel + curandGenerateUniform(prng, dst.storage().data().get(), elements); + + // Scale the result to the desired range + dst = dst * (upper - lower) + lower; + } + + template + LIBRAPID_ALWAYS_INLINE void + fillRandom(array::ArrayContainer> &dst, const Lower &lower, + const Upper &upper) { + ShapeType shape = dst.shape(); + int64_t elements = shape.size(); + + // Create a pseudo-random number generator + static curandGenerator_t prng; + static bool initialized = false; + + if (!initialized) { + curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); + initialized = true; + } + + if (global::reseed) { curandSetPseudoRandomGeneratorSeed(prng, global::randomSeed); } + + // Run the kernel + curandGenerateUniformDouble(prng, dst.storage().data().get(), elements); + + // Scale the result to the desired range + dst = dst * 
(upper - lower) + lower; + } #endif // LIBRAPID_HAS_CUDA } // namespace librapid diff --git a/librapid/include/librapid/array/fourierTransform.hpp b/librapid/include/librapid/array/fourierTransform.hpp index 7e2d3c2a..bedde0a7 100644 --- a/librapid/include/librapid/array/fourierTransform.hpp +++ b/librapid/include/librapid/array/fourierTransform.hpp @@ -2,121 +2,122 @@ #define LIBRAPID_ARRAY_FOURIER_TRANFORM_HPP namespace librapid::fft { - namespace detail { - namespace cpu { - template - void rfft(Complex *output, T *input, size_t n) { - pocketfft::shape_t shape = {n}; - pocketfft::stride_t strideIn = {sizeof(T)}; - pocketfft::stride_t strideOut = {sizeof(Complex)}; - size_t axis = 0; - bool forward = true; - T fct = 1.0; - pocketfft::r2c(shape, - strideIn, - strideOut, - axis, - forward, - input, - reinterpret_cast *>(output), - fct, - global::numThreads); - } + namespace detail { + namespace cpu { + template + void rfft(Complex *output, T *input, size_t n) { + pocketfft::shape_t shape = {n}; + pocketfft::stride_t strideIn = {sizeof(T)}; + pocketfft::stride_t strideOut = {sizeof(Complex)}; + size_t axis = 0; + bool forward = true; + T fct = 1.0; + pocketfft::r2c(shape, + strideIn, + strideOut, + axis, + forward, + input, + reinterpret_cast *>(output), + fct, + global::numThreads); + } #if defined(LIBRAPID_HAS_CUDA) - LIBRAPID_INLINE void rfft(Complex *output, double *input, size_t n) { - unsigned int mode = FFTW_ESTIMATE; - fftw_plan plan = fftw_plan_dft_r2c_1d( - (int)n, input, reinterpret_cast(output), mode); - fftw_execute(plan); - fftw_destroy_plan(plan); - } + LIBRAPID_INLINE void rfft(Complex *output, double *input, size_t n) { + unsigned int mode = FFTW_ESTIMATE; + fftw_plan plan = fftw_plan_dft_r2c_1d( + (int)n, input, reinterpret_cast(output), mode); + fftw_execute(plan); + fftw_destroy_plan(plan); + } - LIBRAPID_INLINE void rfft(Complex *output, float *input, size_t n) { - unsigned int mode = FFTW_ESTIMATE; - fftwf_plan plan = fftwf_plan_dft_r2c_1d( - 
(int)n, input, reinterpret_cast(output), mode); - fftwf_execute(plan); - fftwf_destroy_plan(plan); - } + LIBRAPID_INLINE void rfft(Complex *output, float *input, size_t n) { + unsigned int mode = FFTW_ESTIMATE; + fftwf_plan plan = fftwf_plan_dft_r2c_1d( + (int)n, input, reinterpret_cast(output), mode); + fftwf_execute(plan); + fftwf_destroy_plan(plan); + } #elif defined(LIBRAPID_HAS_FFTW) - LIBRAPID_INLINE void rfft(Complex *output, double *input, size_t n) { - unsigned int mode = FFTW_ESTIMATE; - fftw_plan_with_nthreads((int)global::numThreads); - fftw_plan plan = fftw_plan_dft_r2c_1d( - (int)n, input, reinterpret_cast(output), mode); - fftw_execute(plan); - fftw_destroy_plan(plan); - } + LIBRAPID_INLINE void rfft(Complex *output, double *input, size_t n) { + unsigned int mode = FFTW_ESTIMATE; + fftw_plan_with_nthreads((int)global::numThreads); + fftw_plan plan = fftw_plan_dft_r2c_1d( + (int)n, input, reinterpret_cast(output), mode); + fftw_execute(plan); + fftw_destroy_plan(plan); + } - LIBRAPID_INLINE void rfft(Complex *output, float *input, size_t n) { - unsigned int mode = FFTW_ESTIMATE; - fftwf_plan_with_nthreads((int)global::numThreads); - fftwf_plan plan = fftwf_plan_dft_r2c_1d( - (int)n, input, reinterpret_cast(output), mode); - fftwf_execute(plan); - fftwf_destroy_plan(plan); - } + LIBRAPID_INLINE void rfft(Complex *output, float *input, size_t n) { + unsigned int mode = FFTW_ESTIMATE; + fftwf_plan_with_nthreads((int)global::numThreads); + fftwf_plan plan = fftwf_plan_dft_r2c_1d( + (int)n, input, reinterpret_cast(output), mode); + fftwf_execute(plan); + fftwf_destroy_plan(plan); + } #endif - } // namespace cpu + } // namespace cpu #if defined(LIBRAPID_HAS_CUDA) - namespace gpu { - LIBRAPID_INLINE void rfft(Complex *output, double *input, size_t n) { - cufftHandle plan; - cufftPlan1d(&plan, (int)n, CUFFT_D2Z, 1); - cufftSetStream(plan, global::cudaStream); - cufftExecD2Z(plan, input, reinterpret_cast(output)); - cufftDestroy(plan); - } + namespace gpu { + 
LIBRAPID_INLINE void rfft(Complex *output, double *input, size_t n) { + cufftHandle plan; + cufftPlan1d(&plan, (int)n, CUFFT_D2Z, 1); + cufftSetStream(plan, global::cudaStream); + cufftExecD2Z(plan, input, reinterpret_cast(output)); + cufftDestroy(plan); + } - LIBRAPID_INLINE void rfft(Complex *output, float *input, size_t n) { - cufftHandle plan; - cufftPlan1d(&plan, (int)n, CUFFT_R2C, 1); - cufftSetStream(plan, global::cudaStream); - cudaStreamSynchronize(global::cudaStream); - cufftExecR2C(plan, input, reinterpret_cast(output)); - cufftDestroy(plan); - } - } // namespace gpu -#endif // LIBRAPID_HAS_CUDA - } // namespace detail + LIBRAPID_INLINE void rfft(Complex *output, float *input, size_t n) { + cufftHandle plan; + cufftPlan1d(&plan, (int)n, CUFFT_R2C, 1); + cufftSetStream(plan, global::cudaStream); + cudaStreamSynchronize(global::cudaStream); + cufftExecR2C(plan, input, reinterpret_cast(output)); + cufftDestroy(plan); + } + } // namespace gpu +#endif // LIBRAPID_HAS_CUDA + } // namespace detail - - /// \brief Compute the real-valued discrete Fourier transform of a 1D array - /// - /// Given a 1D array of real numbers, compute the discrete Fourier transform of the array. This - /// returns an array of length \f$\frac{n}{2} + 1\f$ where \f$n\f$ is the length of the input - /// array. The returned array contains the non-redundant half of the resulting transform, since - /// the other half can be obtained by taking the complex conjugate of the first half. 
- /// - /// \tparam ShapeType The shape type of the input array - /// \tparam StorageScalar The scalar type of the input array - /// \param array The input array - /// \return The discrete Fourier transform of the input array - template - LIBRAPID_NODISCARD Array, backend::CPU> - rfft(array::ArrayContainer> &array) { - LIBRAPID_ASSERT(array.ndim() == 1, "RFFT only implemented for 1D arrays"); - int64_t outSize = array.shape()[0] / 2 + 1; - Array, backend::CPU> res(Shape({outSize})); - StorageScalar *input = array.storage().begin(); - Complex *output = res.storage().begin(); - detail::cpu::rfft(output, input, array.shape()[0]); - return res; - } + /// \brief Compute the real-valued discrete Fourier transform of a 1D array + /// + /// Given a 1D array of real numbers, compute the discrete Fourier transform of the array. This + /// returns an array of length \f$\frac{n}{2} + 1\f$ where \f$n\f$ is the length of the input + /// array. The returned array contains the non-redundant half of the resulting transform, since + /// the other half can be obtained by taking the complex conjugate of the first half. 
+ /// + /// \tparam ShapeType The shape type of the input array + /// \tparam StorageScalar The scalar type of the input array + /// \param array The input array + /// \return The discrete Fourier transform of the input array + template + LIBRAPID_NODISCARD auto rfft(array::ArrayContainer> &array) + -> Array, backend::CPU> { + LIBRAPID_ASSERT(array.ndim() == 1, "RFFT only implemented for 1D arrays"); + int64_t outSize = array.shape()[0] / 2 + 1; + Array, backend::CPU> res(Shape({outSize})); + StorageScalar *input = array.storage().begin(); + Complex *output = res.storage().begin(); + detail::cpu::rfft(output, input, array.shape()[0]); + return res; + } #if defined(LIBRAPID_HAS_CUDA) - template - LIBRAPID_NODISCARD Array, backend::CUDA> rfft(Array &array) { - LIBRAPID_ASSERT(array.ndim() == 1, "RFFT only implemented for 1D arrays"); - int64_t outSize = array.shape()[0] / 2 + 1; - Array, backend::CUDA> res(Shape({outSize})); - Scalar *input = array.storage().begin().get(); - Complex *output = res.storage().begin().get(); - detail::gpu::rfft(output, input, array.shape()[0]); - return res; - } + template + LIBRAPID_NODISCARD auto + rfft(array::ArrayContainer> &array) + -> Array, backend::CUDA> { + LIBRAPID_ASSERT(array.ndim() == 1, "RFFT only implemented for 1D arrays"); + int64_t outSize = array.shape()[0] / 2 + 1; + Array, backend::CUDA> res(Shape({outSize})); + StorageScalar *input = array.storage().begin().get(); + Complex *output = res.storage().begin().get(); + detail::gpu::rfft(output, input, array.shape()[0]); + return res; + } #endif // LIBRAPID_HAS_CUDA } // namespace librapid::fft diff --git a/librapid/include/librapid/array/function.hpp b/librapid/include/librapid/array/function.hpp index 698359fd..41463e50 100644 --- a/librapid/include/librapid/array/function.hpp +++ b/librapid/include/librapid/array/function.hpp @@ -2,280 +2,281 @@ #define LIBRAPID_ARRAY_FUNCTION_HPP namespace librapid { - namespace typetraits { - // Extract allowVectorisation from the 
input types - template - constexpr bool checkAllowVectorisation() { - if constexpr (sizeof...(T) == 0) { - return TypeInfo>::allowVectorisation; - } else { - using T1 = typename TypeInfo>::Scalar; - return TypeInfo>::allowVectorisation && - checkAllowVectorisation() && - (std::is_same_v>::Scalar> && ...); - } - } - - template - constexpr auto commonBackend() { - using FirstBackend = typename TypeInfo>::Backend; - if constexpr (sizeof...(Rest) == 0) { - return FirstBackend {}; - } else { - using RestBackend = decltype(commonBackend()); - if constexpr (std::is_same_v || - std::is_same_v) { - return backend::OpenCLIfAvailable {}; - } else if constexpr (std::is_same_v || - std::is_same_v) { - return backend::CUDAIfAvailable {}; - } else { - return backend::CPU {}; - } - } - } - - template - struct TypeInfo<::librapid::detail::Function> { - static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction; - using Scalar = decltype(std::declval()( - std::declval>::Scalar>()...)); - using Backend = decltype(commonBackend()); - - static constexpr bool allowVectorisation = checkAllowVectorisation(); - - static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; - static constexpr bool supportsLogical = TypeInfo::supportsLogical; - static constexpr bool supportsBinary = TypeInfo::supportsBinary; - }; - - LIBRAPID_DEFINE_AS_TYPE(typename desc COMMA typename Functor_ COMMA typename... 
Args, - ::librapid::detail::Function); - } // namespace typetraits - - namespace detail { - // Descriptor is defined in "forward.hpp" - - template< - typename Packet, typename T, - typename std::enable_if_t< - typetraits::TypeInfo::type != ::librapid::detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, - size_t index) { - static_assert(std::is_same_v, - "Packet types do not match"); - return obj.packet(index); - } - - template< - typename Packet, typename T, - typename std::enable_if_t< - typetraits::TypeInfo::type == ::librapid::detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, size_t) { - return Packet(obj); - } - - template::type != - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t index) { - return obj.scalar(index); - } - - template::type == - ::librapid::detail::LibRapidType::Scalar, - int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t) { - return obj; - } - - template - constexpr auto scalarTypesAreSame() { - if constexpr (sizeof...(Rest) == 0) { - using Scalar = typename typetraits::TypeInfo>::Scalar; - return Scalar {}; - } else { - using RestType = decltype(scalarTypesAreSame()); - if constexpr (std::is_same_v< - typename typetraits::TypeInfo>::Scalar, - RestType>) { - return RestType {}; - } else { - return std::false_type {}; - } - } - } - - template - class Function { - public: - using Type = Function; - using Functor = Functor_; - using ShapeType = Shape; - using StrideType = ShapeType; - using Scalar = typename typetraits::TypeInfo::Scalar; - using Backend = typename typetraits::TypeInfo::Backend; - using Packet = typename typetraits::TypeInfo::Packet; - using Iterator = detail::ArrayIterator; - - using Descriptor = desc; - static constexpr bool argsAreSameType = - 
!std::is_same_v()), std::false_type>; - - Function() = default; - - /// Constructs a Function from a functor and arguments. - /// \param functor The functor to use. - /// \param args The arguments to use. - LIBRAPID_ALWAYS_INLINE explicit Function(const Functor &functor, const Args &...args); - - /// Constructs a Function from another function. - /// \param other The Function to copy. - LIBRAPID_ALWAYS_INLINE Function(const Function &other) = default; - - /// Construct a Function from a temporary function. - /// \param other The Function to move. - LIBRAPID_ALWAYS_INLINE Function(Function &&other) noexcept = default; - - /// Assigns a Function to this function. - /// \param other The Function to copy. - /// \return A reference to this Function. - LIBRAPID_ALWAYS_INLINE Function &operator=(const Function &other) = default; - - /// Assigns a temporary Function to this Function. - /// \param other The Function to move. - /// \return A reference to this Function. - LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default; - - /// Return the shape of the Function's result - /// \return The shape of the Function's result - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const; - - /// Return the arguments in the Function - /// \return The arguments in the Function - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &args() const; - - /// Return an evaluated Array object - /// \return - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const; - - /// Evaluates the function at the given index, returning a Packet result. - /// \param index The index to evaluate at. - /// \return The result of the function (vectorized). - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const; - - /// Evaluates the function at the given index, returning a Scalar result. - /// \param index The index to evaluate at. 
- /// \return The result of the function (scalar). - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const; - - /// Return a string representation of the Function - /// \param format The format to use. - /// \return A string representation of the Function - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; - - private: - /// Implementation detail -- evaluates the function at the given index, - /// returning a Packet result. - /// \tparam I The index sequence. - /// \param index The index to evaluate at. - /// \return The result of the function (vectorized). - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetImpl(std::index_sequence, - size_t index) const; - - /// Implementation detail -- evaluates the function at the given index, - /// returning a Scalar result. - /// \tparam I The index sequence. - /// \param index The index to evaluate at. - /// \return The result of the function (scalar). - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalarImpl(std::index_sequence, - size_t index) const; - - Functor m_functor; - std::tuple m_args; - }; - - template - Function::Function(const Functor &functor, const Args &...args) : - m_functor(functor), m_args(args...) 
{} - - template - auto Function::shape() const { - return typetraits::TypeInfo::getShape(m_args); - } - - template - auto &Function::args() const { - return m_args; - } - - template - auto Function::operator[](int64_t index) const { - return array::ArrayView(*this)[index]; - } - - template - auto Function::eval() const { - auto res = Array(shape()); - res = *this; - return res; - } - - template - typename Function::Packet - Function::packet(size_t index) const { - return packetImpl(std::make_index_sequence(), index); - } - - template - template - auto Function::packetImpl(std::index_sequence, - size_t index) const -> Packet { - return m_functor.packet(packetExtractor(std::get(m_args), index)...); - } - - template - auto Function::scalar(size_t index) const -> Scalar { - return scalarImpl(std::make_index_sequence(), index); - } - - template - template - auto Function::scalarImpl(std::index_sequence, - size_t index) const -> Scalar { - return m_functor(scalarExtractor(std::get(m_args), index)...); - } - - template - auto Function::begin() const -> Iterator { - return Iterator(*this, 0); - } - - template - auto Function::end() const -> Iterator { - return Iterator(*this, shape()[0]); - } - - template - std::string Function::str(const std::string &format) const { - return eval().str(format); - } - } // namespace detail + namespace typetraits { + // Extract allowVectorisation from the input types + template + constexpr bool checkAllowVectorisation() { + if constexpr (sizeof...(T) == 0) { + return TypeInfo>::allowVectorisation; + } else { + using T1 = typename TypeInfo>::Scalar; + return TypeInfo>::allowVectorisation && + checkAllowVectorisation() && + (std::is_same_v>::Scalar> && ...); + } + } + + template + constexpr auto commonBackend() { + using FirstBackend = typename TypeInfo>::Backend; + if constexpr (sizeof...(Rest) == 0) { + return FirstBackend {}; + } else { + using RestBackend = decltype(commonBackend()); + if constexpr (std::is_same_v || + std::is_same_v) { + 
return backend::OpenCLIfAvailable {}; + } else if constexpr (std::is_same_v || + std::is_same_v) { + return backend::CUDAIfAvailable {}; + } else { + return backend::CPU {}; + } + } + } + + template + struct TypeInfo<::librapid::detail::Function> { + static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction; + using Scalar = decltype(std::declval()( + std::declval>::Scalar>()...)); + using Backend = decltype(commonBackend()); + + static constexpr bool allowVectorisation = checkAllowVectorisation(); + + static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; + static constexpr bool supportsLogical = TypeInfo::supportsLogical; + static constexpr bool supportsBinary = TypeInfo::supportsBinary; + }; + + LIBRAPID_DEFINE_AS_TYPE(typename desc COMMA typename Functor_ COMMA typename... Args, + ::librapid::detail::Function); + } // namespace typetraits + + namespace detail { + // Descriptor is defined in "forward.hpp" + + template< + typename Packet, typename T, + typename std::enable_if_t< + typetraits::TypeInfo::type != ::librapid::detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, + size_t index) { + static_assert(std::is_same_v, + "Packet types do not match"); + return obj.packet(index); + } + + template< + typename Packet, typename T, + typename std::enable_if_t< + typetraits::TypeInfo::type == ::librapid::detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetExtractor(const T &obj, size_t) { + return Packet(obj); + } + + template::type != + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t index) { + return obj.scalar(index); + } + + template::type == + ::librapid::detail::LibRapidType::Scalar, + int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalarExtractor(const T &obj, size_t) { + return obj; + } + + template + 
constexpr auto scalarTypesAreSame() { + if constexpr (sizeof...(Rest) == 0) { + using Scalar = typename typetraits::TypeInfo>::Scalar; + return Scalar {}; + } else { + using RestType = decltype(scalarTypesAreSame()); + if constexpr (std::is_same_v< + typename typetraits::TypeInfo>::Scalar, + RestType>) { + return RestType {}; + } else { + return std::false_type {}; + } + } + } + + template + class Function { + public: + using Type = Function; + using Functor = Functor_; + using ShapeType = Shape; + using StrideType = ShapeType; + using Scalar = typename typetraits::TypeInfo::Scalar; + using Backend = typename typetraits::TypeInfo::Backend; + using Packet = typename typetraits::TypeInfo::Packet; + using Iterator = detail::ArrayIterator; + + using Descriptor = desc; + static constexpr bool argsAreSameType = + !std::is_same_v()), std::false_type>; + + Function() = default; + + /// Constructs a Function from a functor and arguments. + /// \param functor The functor to use. + /// \param args The arguments to use. + LIBRAPID_ALWAYS_INLINE explicit Function(const Functor &functor, const Args &...args); + + /// Constructs a Function from another function. + /// \param other The Function to copy. + LIBRAPID_ALWAYS_INLINE Function(const Function &other) = default; + + /// Construct a Function from a temporary function. + /// \param other The Function to move. + LIBRAPID_ALWAYS_INLINE Function(Function &&other) noexcept = default; + + /// Assigns a Function to this function. + /// \param other The Function to copy. + /// \return A reference to this Function. + LIBRAPID_ALWAYS_INLINE Function &operator=(const Function &other) = default; + + /// Assigns a temporary Function to this Function. + /// \param other The Function to move. + /// \return A reference to this Function. 
+ LIBRAPID_ALWAYS_INLINE Function &operator=(Function &&other) noexcept = default; + + /// Return the shape of the Function's result + /// \return The shape of the Function's result + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const; + + /// Return the arguments in the Function + /// \return The arguments in the Function + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &args() const; + + /// Return an evaluated Array object + /// \return + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator[](int64_t index) const; + + /// Evaluates the function at the given index, returning a Packet result. + /// \param index The index to evaluate at. + /// \return The result of the function (vectorized). + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packet(size_t index) const; + + /// Evaluates the function at the given index, returning a Scalar result. + /// \param index The index to evaluate at. + /// \return The result of the function (scalar). + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalar(size_t index) const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const; + + /// Return a string representation of the Function + /// \param format The format to use. + /// \return A string representation of the Function + LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + + private: + /// Implementation detail -- evaluates the function at the given index, + /// returning a Packet result. + /// \tparam I The index sequence. + /// \param index The index to evaluate at. + /// \return The result of the function (vectorized). + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Packet packetImpl(std::index_sequence, + size_t index) const; + + /// Implementation detail -- evaluates the function at the given index, + /// returning a Scalar result. + /// \tparam I The index sequence. 
+ /// \param index The index to evaluate at. + /// \return The result of the function (scalar). + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar scalarImpl(std::index_sequence, + size_t index) const; + + Functor m_functor; + std::tuple m_args; + }; + + template + Function::Function(const Functor &functor, const Args &...args) : + m_functor(functor), m_args(args...) {} + + template + auto Function::shape() const { + return typetraits::TypeInfo::getShape(m_args); + } + + template + auto &Function::args() const { + return m_args; + } + + template + auto Function::operator[](int64_t index) const { + return array::ArrayView(*this)[index]; + } + + template + auto Function::eval() const { + auto res = Array(shape()); + res = *this; + return res; + } + + template + typename Function::Packet + Function::packet(size_t index) const { + return packetImpl(std::make_index_sequence(), index); + } + + template + template + auto Function::packetImpl(std::index_sequence, + size_t index) const -> Packet { + return m_functor.packet(packetExtractor(std::get(m_args), index)...); + } + + template + auto Function::scalar(size_t index) const -> Scalar { + return scalarImpl(std::make_index_sequence(), index); + } + + template + template + auto Function::scalarImpl(std::index_sequence, + size_t index) const -> Scalar { + return m_functor(scalarExtractor(std::get(m_args), index)...); + } + + template + auto Function::begin() const -> Iterator { + return Iterator(*this, 0); + } + + template + auto Function::end() const -> Iterator { + return Iterator(*this, shape()[0]); + } + + template + std::string Function::str(const std::string &format) const { + return eval().str(format); + } + } // namespace detail } // namespace librapid // Support FMT printing #ifdef FMT_API LIBRAPID_SIMPLE_IO_IMPL(typename desc COMMA typename Functor COMMA typename... 
Args, - librapid::detail::Function) + librapid::detail::Function) + LIBRAPID_SIMPLE_IO_NORANGE(typename desc COMMA typename Functor COMMA typename... Args, - librapid::detail::Function) + librapid::detail::Function) #endif // FMT_API #endif // LIBRAPID_ARRAY_FUNCTION_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/linalg/arrayMultiply.hpp b/librapid/include/librapid/array/linalg/arrayMultiply.hpp index 386658c9..80aeb4a4 100644 --- a/librapid/include/librapid/array/linalg/arrayMultiply.hpp +++ b/librapid/include/librapid/array/linalg/arrayMultiply.hpp @@ -2,707 +2,707 @@ #define LIBRAPID_ARRAY_LINALG_ARRAY_MULTIPLY_HPP namespace librapid { - namespace detail { - /// Extract the pointer from a given array type - /// \tparam T - /// \param ptr - /// \return - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto arrayPointerExtractor(T *ptr) { - return ptr; - } + namespace detail { + /// Extract the pointer from a given array type + /// \tparam T + /// \param ptr + /// \return + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto arrayPointerExtractor(T *ptr) { + return ptr; + } #if defined(LIBRAPID_HAS_OPENCL) - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto arrayPointerExtractor(cl::Buffer ptr) { - return ptr; - } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto arrayPointerExtractor(cl::Buffer ptr) { + return ptr; + } #endif // LIBRAPID_HAS_OPENCL - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - arrayPointerExtractor(std::shared_ptr ptr) { - return ptr.get(); - } - } // namespace detail - - namespace linalg { - enum class MatmulClass { - DOT, // Vector-vector dot product - GEMV, // Matrix-vector product - GEMM, // Matrix-matrix product - OUTER, // Outer product - }; - - /// Class to represent an array multiplication (vector-vector, matrix-vector, matrix-matrix) - /// \tparam ShapeTypeA Shape of the first array - /// \tparam StorageTypeA Storage type of the first array - /// \tparam ShapeTypeB Shape of the second 
array - /// \tparam StorageTypeB Storage type of the second array - /// \tparam Alpha Type of \f$ \alpha \f$ scaling factor - /// \tparam Beta Type of \f$ \beta \f$ scaling factor - template - class ArrayMultiply { - public: - using TypeA = array::ArrayContainer; - using TypeB = array::ArrayContainer; - using ScalarA = typename StorageTypeA::Scalar; - using ScalarB = typename StorageTypeB::Scalar; - using Scalar = decltype(std::declval() * std::declval()); - using ShapeType = ShapeTypeA; - using BackendA = typename typetraits::TypeInfo::Backend; - using BackendB = typename typetraits::TypeInfo::Backend; - using Backend = decltype(typetraits::commonBackend()); - - static_assert(std::is_same_v, "Backend of A and B must match"); - - /// Default constructor (deleted) - ArrayMultiply() = delete; - - /// Copy constructor - ArrayMultiply(const ArrayMultiply &) = default; - - /// Move constructor - ArrayMultiply(ArrayMultiply &&) noexcept = default; - - /// \brief Full set of parameters for an array multiplication - /// \param transA Determines \f$ \mathrm{OP}_A \f$ (true: transpose, false: no - /// transpose) \param transB Determines \f$ \mathrm{OP}_B \f$ (true: transpose, false: - /// no transpose) \param a First array \param alpha Scaling factor \f$ \alpha \f$ \param - /// b Second array \param beta Scaling factor \f$ \beta \f$ - ArrayMultiply(bool transA, bool transB, const TypeA &a, Alpha alpha, const TypeB &b, - Beta beta); - - /// \brief Full set of parameters, but with move semantics - /// \param transA - /// \param transB - /// \param a - /// \param alpha - /// \param b - /// \param beta - ArrayMultiply(bool transA, bool transB, TypeA &&a, Alpha alpha, TypeB &&b, Beta beta); - - /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$ - /// \param a - /// \param b - ArrayMultiply(const TypeA &a, const TypeB &b); - - /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$, but with - /// move semantics - /// \param a - /// 
\param b - ArrayMultiply(TypeA &&a, TypeB &&b); - - /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$ and - /// transpose options - /// \param transA - /// \param transB - /// \param a - /// \param b - ArrayMultiply(bool transA, bool transB, const TypeA &a, const TypeB &b); - - /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$ and - /// transpose options, but with move semantics - /// \param transA - /// \param transB - /// \param a - /// \param b - ArrayMultiply(bool transA, bool transB, TypeA &&a, TypeB &&b); - - /// \brief Copy assignment operator - /// \return Reference to this - ArrayMultiply &operator=(const ArrayMultiply &) = default; - - /// \brief Move assignment operator - /// \return Reference to this - ArrayMultiply &operator=(ArrayMultiply &&) noexcept = default; - - /// \brief Determine the class of the array multiplication - /// - /// The class of the array multiplication is determined by the shapes of the arrays. - /// There are three supported cases: - /// - Vector-vector dot product (both arrays are 1-dimensional vectors) - /// - Matrix-vector product (first array is a 2-dimensional matrix, second array is a - /// 1-dimensional vector) - /// - Matrix-matrix product (both arrays are 2-dimensional matrices) - /// \return Class of the array multiplication - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE MatmulClass matmulClass() const; - - /// \brief Determine the shape of the result - /// \return Shape of the result - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const; - - /// \brief Determine the number of dimensions of the result - /// \return Number of dimensions of the result - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const; - - /// \brief Force evaluation of the array multiplication, returning an Array object - /// \return Array object containing the result - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; - - /// \brief Get the scaling factor \f$ 
\alpha \f$ - /// \return \f$ \alpha \f$ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ScalarA alpha() const; - - /// \brief Get the scaling factor \f$ \beta \f$ - /// \return \f$ \beta \f$ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ScalarB beta() const; - - /// \brief Determine \f$ \mathrm{OP}_A \f$ - /// \return True: \f$ \mathrm{OP}_A(\mathbf{A}) = \mathbf{A}^T \f$, false: \f$ - /// \mathrm{OP}_A(\mathbf{A}) = \mathbf{A} \f$ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool transA() const; - - /// \brief Determine \f$ \mathrm{OP}_B \f$ - /// \return True: \f$ \mathrm{OP}_B(\mathbf{B}) = \mathbf{B}^T \f$, false: \f$ - /// \mathrm{OP}_B(\mathbf{B}) = \mathbf{B} \f$ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool transB() const; - - /// \brief Get the first array - /// \return First array - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const TypeA &a() const; - - /// \brief Get the second array - /// \return Second array - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const TypeB &b() const; - - /// \brief Get the first array - /// \return First array - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE TypeA &a(); - - /// \brief Get the second array - /// \return Second array - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE TypeB &b(); - - /// \brief Apply the array multiplication to an array container - /// - /// Apply this operation to the provided Array, assuming that the Array has the correct - /// shape. If the Array does not have the correct shape, an error is thrown. 
- /// - /// \tparam StorageType Storage type of the array container - /// \param out Array container to store the result in - template - void applyTo(array::ArrayContainer &out) const; - - /// \brief String representation of the array multiplication - /// \param format Format string for each element - /// \return String representation of the array multiplication - LIBRAPID_NODISCARD std::string str(const std::string &format) const; - - private: - bool m_transA; // Transpose state of A - bool m_transB; // Transpose state of B - TypeA m_a; // First array - ScalarA m_alpha; // Scaling factor for A - TypeB m_b; // Second array - ScalarB m_beta; // Scaling factor for B - }; - - template - ArrayMultiply::ArrayMultiply(bool transA, bool transB, const TypeA &a, Alpha alpha, - const TypeB &b, Beta beta) : - m_transA(transA), - m_transB(transB), m_a(a), m_alpha(static_cast(alpha)), m_b(b), - m_beta(static_cast(beta)) {} - - template - ArrayMultiply::ArrayMultiply(bool transA, bool transB, TypeA &&a, Alpha alpha, - TypeB &&b, Beta beta) : - m_transA(transA), - m_transB(transB), m_a(std::forward(a)), m_alpha(static_cast(alpha)), - m_b(std::forward(b)), m_beta(static_cast(beta)) {} - - template - ArrayMultiply::ArrayMultiply(const TypeA &a, const TypeB &b) : - m_transA(false), - m_transB(false), m_a(a), m_alpha(1), m_b(b), m_beta(0) {} - - template - ArrayMultiply::ArrayMultiply(TypeA &&a, TypeB &&b) : - m_transA(false), - m_transB(false), m_a(std::forward(a)), m_alpha(1), - m_b(std::forward(b)), m_beta(0) {} - - template - ArrayMultiply::ArrayMultiply(bool transA, bool transB, const TypeA &a, - const TypeB &b) : - m_transA(transA), - m_transB(transB), m_a(a), m_alpha(1), m_b(b), m_beta(0) {} - - template - ArrayMultiply::ArrayMultiply(bool transA, bool transB, TypeA &&a, TypeB &&b) : - m_transA(transA), - m_transB(transB), m_a(std::forward(a)), m_alpha(1), - m_b(std::forward(b)), m_beta(0) {} - - template - auto ArrayMultiply::matmulClass() const -> MatmulClass { - const auto 
&shapeA = m_a.shape(); - const auto &shapeB = m_b.shape(); - - if (shapeA.ndim() == 1 && shapeB.ndim() == 1) { - LIBRAPID_ASSERT(shapeA[0] == shapeB[0], - "Vector dimensions must. Expected: {} -- Got: {}", - shapeA[0], - shapeB[0]); - - return MatmulClass::DOT; - } else if (shapeA.ndim() == 1 && shapeB.ndim() == 2) { - LIBRAPID_ASSERT( - shapeA[0] == shapeB[int(!m_transB)], - "Columns of OP(B) must match elements of A. Expected: {} -- Got: {}", - shapeA[0], - shapeB[int(!m_transB)]); - - return MatmulClass::GEMV; - } else if (shapeA.ndim() == 2 && shapeB.ndim() == 1) { - LIBRAPID_ASSERT(shapeA[int(!m_transA)] == shapeB[0], - "Rows of OP(A) must match elements of B. Expected: {} -- Got: {}", - shapeA[int(m_transA)], - shapeB[0]); - - return MatmulClass::GEMV; - } else if (shapeA.ndim() == 2 && shapeB.ndim() == 2) { - // // Check for GEMV variations - // // 1. A is a matrix, B is a 1xn vector - // // 2. A is a matrix, B is a nx1 vector - - // if (shapeB[0] == 1) { // Case 1 - // LIBRAPID_ASSERT( - // shapeA[int(!m_transA)] == shapeB[1], - // "Columns of {} must match columns of B. Expected: {} -- Got: {}", - // (m_transA ? "A" : "A^T"), - // shapeA[int(!m_transA)], - // shapeB[1]); - - // return MatmulClass::GEMV; - // } else if (shapeB[1] == 1) { // Case 2 - // LIBRAPID_ASSERT(shapeA[int(!m_transA)] == shapeB[0], - // "Columns of {} must match rows of B. Expected: {} -- Got: {}", - // (m_transA ? "A" : "A^T"), - // shapeA[int(!m_transA)], - // shapeB[0]); - - // return MatmulClass::GEMV; - // } - - LIBRAPID_ASSERT(m_a.shape()[int(!m_transA)] == m_b.shape()[int(m_transB)], - "Inner dimensions of matrices must match. 
Expected: {} -- Got: {}", - m_a.shape()[int(!m_transA)], - m_b.shape()[int(m_transB)]); - - return MatmulClass::GEMM; - } else { - LIBRAPID_NOT_IMPLEMENTED; - - return MatmulClass::OUTER; - } - } - - template - auto ArrayMultiply::shape() - const -> ShapeType { - const auto &shapeA = m_a.shape(); - const auto &shapeB = m_b.shape(); - MatmulClass matmulClass = this->matmulClass(); - - switch (matmulClass) { - case MatmulClass::DOT: { - return {1}; - } - case MatmulClass::GEMV: { - return {shapeA[int(m_transA)]}; - } - case MatmulClass::GEMM: { - return {shapeA[int(m_transA)], shapeB[int(!m_transB)]}; - } - case MatmulClass::OUTER: { - LIBRAPID_NOT_IMPLEMENTED; - return {1}; - } - } - - LIBRAPID_NOT_IMPLEMENTED; - return {1}; - } - - template - auto - ArrayMultiply::ndim() const - -> int64_t { - return shape().ndim(); - } - - template - auto ArrayMultiply::eval() - const { - Array result(shape()); - applyTo(result); - return result; - } - - template - auto ArrayMultiply::alpha() - const -> ScalarA { - return m_alpha; - } - - template - auto - ArrayMultiply::beta() const - -> ScalarB { - return m_beta; - } - - template - bool - ArrayMultiply::transA() - const { - return m_transA; - } - - template - bool - ArrayMultiply::transB() - const { - return m_transB; - } - - template - auto - ArrayMultiply::a() const - -> const TypeA & { - return m_a; - } - - template - auto - ArrayMultiply::b() const - -> const TypeB & { - return m_b; - } - - template - auto ArrayMultiply::a() - -> TypeA & { - return m_a; - } - - template - auto ArrayMultiply::b() - -> TypeB & { - return m_b; - } - - template - template - void - ArrayMultiply::applyTo( - array::ArrayContainer &out) const { - LIBRAPID_ASSERT(out.shape() == shape(), - "Shape of output array must match shape of array multiply operation. 
" - "Expected: {} -- Got: {}", - shape(), - out.shape()); - MatmulClass matmulClass = this->matmulClass(); - - auto a = detail::arrayPointerExtractor(m_a.storage().data()); - auto b = detail::arrayPointerExtractor(m_b.storage().data()); - auto c = detail::arrayPointerExtractor(out.storage().data()); - - switch (matmulClass) { - case MatmulClass::DOT: { - LIBRAPID_NOT_IMPLEMENTED; - } - case MatmulClass::GEMV: { - auto m = int64_t(m_a.shape()[m_transA]); - auto n = int64_t(m_a.shape()[1 - m_transA]); - - auto lda = int64_t(m_a.shape()[1]); - auto incB = int64_t(1); - auto incC = int64_t(1); - - gemv(m_transA, - m, - n, - static_cast(m_alpha), - a, - lda, - b, - incB, - static_cast(m_beta), - c, - incC, - Backend()); - - break; - } - case MatmulClass::GEMM: { - auto m = int64_t(m_a.shape()[m_transA]); - auto n = int64_t(m_b.shape()[1 - m_transB]); - auto k = int64_t(m_a.shape()[1 - m_transA]); - - auto lda = int64_t(m_a.shape()[1]); - auto ldb = int64_t(m_b.shape()[1]); - auto ldc = int64_t(out.shape()[1]); - - gemm(m_transA, - m_transB, - m, - n, - k, - static_cast(m_alpha), - a, - lda, - b, - ldb, - static_cast(m_beta), - c, - ldc, - Backend()); - - break; - } - default: { - LIBRAPID_NOT_IMPLEMENTED; - } - } - } - - template - std::string - ArrayMultiply::str( - const std::string &format) const { - return eval().str(format); - } - } // namespace linalg - - // /// \brief Computes the dot product of two arrays. - // /// - // /// This function calculates the dot product of two arrays. - // /// - // /// If the input arrays are 1-dimensional vectors, this function computes the vector-dot - // /// product \f$ \mathbf{a} \cdot \mathbf{b} = a_1b_1 + a_2b_2 + \ldots + a_nb_n \f$. - // /// Note that the return value will be a 1x1 array (i.e. a scalar) since we cannot return a - // /// scalar directly. 
- // /// - // /// If the left input is a 2-dimensional matrix and the right input is a 1-dimensional - // vector, - // /// this function computes the matrix-vector product \f$ y_i = \sum_{j=1}^{n} a_{ij} x_j \f$ - // /// for \f$ i = 1, \ldots, m \f$. - // /// - // /// If both inputs are 2-dimensional matrices, this function computes the matrix-matrix - // product - // /// \f$ c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj} \f$ for \f$ i = 1, \ldots, m \f$ and \f$ j = - // 1, - // /// \ldots, p \f$. \tparam StorageTypeA The storage type of the left input array. \tparam - // /// StorageTypeB The storage type of the right input array. \param a The left input array. - // /// \param b The right input array. - // /// \return The dot product of the two input arrays. - // template - // auto dot(const ArrayRef &a, const ArrayRef &b) { - // return linalg::ArrayMultiply(a, b); - // } - - namespace detail { - template - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer &destination, - const linalg::ArrayMultiply &op) { - op.applyTo(destination); - } - - /// Evaluates to true if the type is a transpose type. - /// \tparam T - template - struct IsTransposeType : std::false_type {}; - - template - struct IsTransposeType> : std::true_type {}; - - /// Returns a tuple of the form (transpose, raw array) where transpose is true if the array - /// is transposed and false otherwise, and raw array is the raw array data. - /// \tparam T - /// \param val - /// \return - template::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) { - using Scalar = typename typetraits::TypeInfo>::Scalar; - return std::make_tuple(false, Scalar(1), std::forward(val)); - } - - template::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) { - using Type = decltype(val.array()); - return std::make_tuple(true, val.alpha(), std::forward(val.array())); - } - - /// Evaluates to true if the type is a multiply type. 
- /// \tparam T - template - struct IsMultiplyType : std::false_type {}; - - template - struct IsMultiplyType> - : std::true_type {}; - - /// Returns a tuple of the form (scalar, raw array) where scalar is the multiplication - /// factor and raw array is the raw array data. - /// \tparam T - /// \param val - /// \return - template::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto multiplyExtractor(T &&val) { - using Scalar = typename typetraits::TypeInfo>::Scalar; - return std::make_tuple(Scalar(1), std::forward(val)); - } - - template::type == detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - multiplyExtractor(detail::Function &&val) { - using Type = decltype(std::get<0>(val.args())); - return std::make_tuple(std::get<1>(val.args()), - std::forward(std::get<0>(val.args()))); - } - - template::type == detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - multiplyExtractor(detail::Function &&val) { - using Type = decltype(std::get<1>(val.args())); - return std::make_tuple(std::get<0>(val.args()), - std::forward(std::get<1>(val.args()))); - } - - /// Return a tuple of the form (transpose, scalar, raw array) depending on the input type. - /// All scalar values are extracted and combined, and successive transpose operations are - /// combined. 
- /// \tparam T - /// \param val - /// \return - template - auto dotHelper(T &&val) { - if constexpr (IsTransposeType::value) { - auto [transpose, alpha, array] = transposeExtractor(std::forward(val)); - auto [transpose2, alpha2, array2] = dotHelper(std::forward(array)); - return std::make_tuple( - transpose ^ transpose2, alpha * alpha2, std::forward(array2)); - } else if constexpr (IsMultiplyType::value) { - auto [alpha, array] = multiplyExtractor(std::forward(val)); - auto [transpose, alpha2, array2] = dotHelper(std::forward(array)); - return std::make_tuple(transpose, alpha * alpha2, std::forward(array2)); - } else { - using Scalar = typename typetraits::TypeInfo>::Scalar; - return std::make_tuple(false, Scalar(1), std::forward(val)); - } - } - } // namespace detail - - /// \brief Computes the dot product of two arrays. - /// - /// This function calculates the dot product of two arrays. - /// - /// If the input arrays are 1-dimensional vectors, this function computes the vector-dot - /// product \f$ \mathbf{a} \cdot \mathbf{b} = a_1b_1 + a_2b_2 + \ldots + a_nb_n \f$. - /// Note that the return value will be a 1x1 array (i.e. a scalar) since we cannot return a - /// scalar directly. - /// - /// If the left input is a 2-dimensional matrix and the right input is a 1-dimensional vector, - /// this function computes the matrix-vector product \f$ y_i = \sum_{j=1}^{n} a_{ij} x_j \f$ - /// for \f$ i = 1, \ldots, m \f$. - /// - /// If both inputs are 2-dimensional matrices, this function computes the matrix-matrix product - /// \f$ c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj} \f$ for \f$ i = 1, \ldots, m \f$ and \f$ j = 1, - /// \ldots, p \f$. \tparam StorageTypeA The storage type of the left input array. \tparam - /// StorageTypeB The storage type of the right input array. \param a The left input array. - /// \param b The right input array. - /// \return The dot product of the two input arrays. 
- template< - typename First, typename Second, - typename std::enable_if_t::value && IsArrayType::value, int> = 0> - auto dot(First &&a, Second &&b) { - using ScalarA = typename typetraits::TypeInfo>::Scalar; - using ScalarB = typename typetraits::TypeInfo>::Scalar; - using BackendA = typename typetraits::TypeInfo>::Backend; - using BackendB = typename typetraits::TypeInfo>::Backend; - using ArrayA = Array; - using ArrayB = Array; - - auto [transA, alpha, arrA] = detail::dotHelper(a); - auto [transB, beta, arrB] = detail::dotHelper(b); - return linalg::ArrayMultiply(transA, - transB, - std::forward(arrA), - alpha * beta, - std::forward(arrB), - 0); // .eval(); - } - - namespace typetraits { - template - struct TypeInfo< - linalg::ArrayMultiply> { - detail::LibRapidType type = detail::LibRapidType::ArrayFunction; - using Type = linalg::ArrayMultiply; - using Scalar = typename Type::Scalar; - using Backend = typename Type::Backend; - static constexpr bool allowVectorisation = false; - }; - - LIBRAPID_DEFINE_AS_TYPE(typename ShapeTypeA COMMA typename StorageTypeA COMMA - typename ShapeTypeB COMMA typename StorageTypeB COMMA - typename Alpha COMMA typename Beta, - linalg::ArrayMultiply); - } // namespace typetraits + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + arrayPointerExtractor(std::shared_ptr ptr) { + return ptr.get(); + } + } // namespace detail + + namespace linalg { + enum class MatmulClass { + DOT, // Vector-vector dot product + GEMV, // Matrix-vector product + GEMM, // Matrix-matrix product + OUTER, // Outer product + }; + + /// Class to represent an array multiplication (vector-vector, matrix-vector, matrix-matrix) + /// \tparam ShapeTypeA Shape of the first array + /// \tparam StorageTypeA Storage type of the first array + /// \tparam ShapeTypeB Shape of the second array + /// \tparam StorageTypeB Storage type of the second array + /// \tparam Alpha Type of \f$ \alpha \f$ scaling factor + /// \tparam Beta Type of \f$ \beta \f$ scaling factor 
+ template + class ArrayMultiply { + public: + using TypeA = array::ArrayContainer; + using TypeB = array::ArrayContainer; + using ScalarA = typename StorageTypeA::Scalar; + using ScalarB = typename StorageTypeB::Scalar; + using Scalar = decltype(std::declval() * std::declval()); + using ShapeType = ShapeTypeA; + using BackendA = typename typetraits::TypeInfo::Backend; + using BackendB = typename typetraits::TypeInfo::Backend; + using Backend = decltype(typetraits::commonBackend()); + + static_assert(std::is_same_v, "Backend of A and B must match"); + + /// Default constructor (deleted) + ArrayMultiply() = delete; + + /// Copy constructor + ArrayMultiply(const ArrayMultiply &) = default; + + /// Move constructor + ArrayMultiply(ArrayMultiply &&) noexcept = default; + + /// \brief Full set of parameters for an array multiplication + /// \param transA Determines \f$ \mathrm{OP}_A \f$ (true: transpose, false: no + /// transpose) \param transB Determines \f$ \mathrm{OP}_B \f$ (true: transpose, false: + /// no transpose) \param a First array \param alpha Scaling factor \f$ \alpha \f$ \param + /// b Second array \param beta Scaling factor \f$ \beta \f$ + ArrayMultiply(bool transA, bool transB, const TypeA &a, Alpha alpha, const TypeB &b, + Beta beta); + + /// \brief Full set of parameters, but with move semantics + /// \param transA + /// \param transB + /// \param a + /// \param alpha + /// \param b + /// \param beta + ArrayMultiply(bool transA, bool transB, TypeA &&a, Alpha alpha, TypeB &&b, Beta beta); + + /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$ + /// \param a + /// \param b + ArrayMultiply(const TypeA &a, const TypeB &b); + + /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$, but with + /// move semantics + /// \param a + /// \param b + ArrayMultiply(TypeA &&a, TypeB &&b); + + /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$ and + /// transpose options + /// \param transA + 
/// \param transB + /// \param a + /// \param b + ArrayMultiply(bool transA, bool transB, const TypeA &a, const TypeB &b); + + /// \brief Array multiplication with \f$ \alpha = 1 \f$ and \f$ \beta = 0 \f$ and + /// transpose options, but with move semantics + /// \param transA + /// \param transB + /// \param a + /// \param b + ArrayMultiply(bool transA, bool transB, TypeA &&a, TypeB &&b); + + /// \brief Copy assignment operator + /// \return Reference to this + ArrayMultiply &operator=(const ArrayMultiply &) = default; + + /// \brief Move assignment operator + /// \return Reference to this + ArrayMultiply &operator=(ArrayMultiply &&) noexcept = default; + + /// \brief Determine the class of the array multiplication + /// + /// The class of the array multiplication is determined by the shapes of the arrays. + /// There are three supported cases: + /// - Vector-vector dot product (both arrays are 1-dimensional vectors) + /// - Matrix-vector product (first array is a 2-dimensional matrix, second array is a + /// 1-dimensional vector) + /// - Matrix-matrix product (both arrays are 2-dimensional matrices) + /// \return Class of the array multiplication + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE MatmulClass matmulClass() const; + + /// \brief Determine the shape of the result + /// \return Shape of the result + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const; + + /// \brief Determine the number of dimensions of the result + /// \return Number of dimensions of the result + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const; + + /// \brief Force evaluation of the array multiplication, returning an Array object + /// \return Array object containing the result + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; + + /// \brief Get the scaling factor \f$ \alpha \f$ + /// \return \f$ \alpha \f$ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ScalarA alpha() const; + + /// \brief Get the scaling factor \f$ \beta \f$ + /// \return \f$ \beta 
\f$ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ScalarB beta() const; + + /// \brief Determine \f$ \mathrm{OP}_A \f$ + /// \return True: \f$ \mathrm{OP}_A(\mathbf{A}) = \mathbf{A}^T \f$, false: \f$ + /// \mathrm{OP}_A(\mathbf{A}) = \mathbf{A} \f$ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool transA() const; + + /// \brief Determine \f$ \mathrm{OP}_B \f$ + /// \return True: \f$ \mathrm{OP}_B(\mathbf{B}) = \mathbf{B}^T \f$, false: \f$ + /// \mathrm{OP}_B(\mathbf{B}) = \mathbf{B} \f$ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool transB() const; + + /// \brief Get the first array + /// \return First array + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const TypeA &a() const; + + /// \brief Get the second array + /// \return Second array + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const TypeB &b() const; + + /// \brief Get the first array + /// \return First array + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE TypeA &a(); + + /// \brief Get the second array + /// \return Second array + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE TypeB &b(); + + /// \brief Apply the array multiplication to an array container + /// + /// Apply this operation to the provided Array, assuming that the Array has the correct + /// shape. If the Array does not have the correct shape, an error is thrown. 
+ /// + /// \tparam StorageType Storage type of the array container + /// \param out Array container to store the result in + template + void applyTo(array::ArrayContainer &out) const; + + /// \brief String representation of the array multiplication + /// \param format Format string for each element + /// \return String representation of the array multiplication + LIBRAPID_NODISCARD std::string str(const std::string &format) const; + + private: + bool m_transA; // Transpose state of A + bool m_transB; // Transpose state of B + TypeA m_a; // First array + ScalarA m_alpha; // Scaling factor for A + TypeB m_b; // Second array + ScalarB m_beta; // Scaling factor for B + }; + + template + ArrayMultiply::ArrayMultiply(bool transA, bool transB, const TypeA &a, Alpha alpha, + const TypeB &b, Beta beta) : + m_transA(transA), + m_transB(transB), m_a(a), m_alpha(static_cast(alpha)), m_b(b), + m_beta(static_cast(beta)) {} + + template + ArrayMultiply::ArrayMultiply(bool transA, bool transB, TypeA &&a, Alpha alpha, + TypeB &&b, Beta beta) : + m_transA(transA), + m_transB(transB), m_a(std::forward(a)), m_alpha(static_cast(alpha)), + m_b(std::forward(b)), m_beta(static_cast(beta)) {} + + template + ArrayMultiply::ArrayMultiply(const TypeA &a, const TypeB &b) : + m_transA(false), + m_transB(false), m_a(a), m_alpha(1), m_b(b), m_beta(0) {} + + template + ArrayMultiply::ArrayMultiply(TypeA &&a, TypeB &&b) : + m_transA(false), + m_transB(false), m_a(std::forward(a)), m_alpha(1), + m_b(std::forward(b)), m_beta(0) {} + + template + ArrayMultiply::ArrayMultiply(bool transA, bool transB, const TypeA &a, + const TypeB &b) : + m_transA(transA), + m_transB(transB), m_a(a), m_alpha(1), m_b(b), m_beta(0) {} + + template + ArrayMultiply::ArrayMultiply(bool transA, bool transB, TypeA &&a, TypeB &&b) : + m_transA(transA), + m_transB(transB), m_a(std::forward(a)), m_alpha(1), + m_b(std::forward(b)), m_beta(0) {} + + template + auto ArrayMultiply::matmulClass() const -> MatmulClass { + const auto 
&shapeA = m_a.shape(); + const auto &shapeB = m_b.shape(); + + if (shapeA.ndim() == 1 && shapeB.ndim() == 1) { + LIBRAPID_ASSERT(shapeA[0] == shapeB[0], + "Vector dimensions must. Expected: {} -- Got: {}", + shapeA[0], + shapeB[0]); + + return MatmulClass::DOT; + } else if (shapeA.ndim() == 1 && shapeB.ndim() == 2) { + LIBRAPID_ASSERT( + shapeA[0] == shapeB[int(!m_transB)], + "Columns of OP(B) must match elements of A. Expected: {} -- Got: {}", + shapeA[0], + shapeB[int(!m_transB)]); + + return MatmulClass::GEMV; + } else if (shapeA.ndim() == 2 && shapeB.ndim() == 1) { + LIBRAPID_ASSERT(shapeA[int(!m_transA)] == shapeB[0], + "Rows of OP(A) must match elements of B. Expected: {} -- Got: {}", + shapeA[int(m_transA)], + shapeB[0]); + + return MatmulClass::GEMV; + } else if (shapeA.ndim() == 2 && shapeB.ndim() == 2) { + // // Check for GEMV variations + // // 1. A is a matrix, B is a 1xn vector + // // 2. A is a matrix, B is a nx1 vector + + // if (shapeB[0] == 1) { // Case 1 + // LIBRAPID_ASSERT( + // shapeA[int(!m_transA)] == shapeB[1], + // "Columns of {} must match columns of B. Expected: {} -- Got: {}", + // (m_transA ? "A" : "A^T"), + // shapeA[int(!m_transA)], + // shapeB[1]); + + // return MatmulClass::GEMV; + // } else if (shapeB[1] == 1) { // Case 2 + // LIBRAPID_ASSERT(shapeA[int(!m_transA)] == shapeB[0], + // "Columns of {} must match rows of B. Expected: {} -- Got: {}", + // (m_transA ? "A" : "A^T"), + // shapeA[int(!m_transA)], + // shapeB[0]); + + // return MatmulClass::GEMV; + // } + + LIBRAPID_ASSERT(m_a.shape()[int(!m_transA)] == m_b.shape()[int(m_transB)], + "Inner dimensions of matrices must match. 
Expected: {} -- Got: {}", + m_a.shape()[int(!m_transA)], + m_b.shape()[int(m_transB)]); + + return MatmulClass::GEMM; + } else { + LIBRAPID_NOT_IMPLEMENTED; + + return MatmulClass::OUTER; + } + } + + template + auto ArrayMultiply::shape() + const -> ShapeType { + const auto &shapeA = m_a.shape(); + const auto &shapeB = m_b.shape(); + MatmulClass matmulClass = this->matmulClass(); + + switch (matmulClass) { + case MatmulClass::DOT: { + return {1}; + } + case MatmulClass::GEMV: { + return {shapeA[int(m_transA)]}; + } + case MatmulClass::GEMM: { + return {shapeA[int(m_transA)], shapeB[int(!m_transB)]}; + } + case MatmulClass::OUTER: { + LIBRAPID_NOT_IMPLEMENTED; + return {1}; + } + } + + LIBRAPID_NOT_IMPLEMENTED; + return {1}; + } + + template + auto + ArrayMultiply::ndim() const + -> int64_t { + return shape().ndim(); + } + + template + auto ArrayMultiply::eval() + const { + Array result(shape()); + applyTo(result); + return result; + } + + template + auto ArrayMultiply::alpha() + const -> ScalarA { + return m_alpha; + } + + template + auto + ArrayMultiply::beta() const + -> ScalarB { + return m_beta; + } + + template + bool + ArrayMultiply::transA() + const { + return m_transA; + } + + template + bool + ArrayMultiply::transB() + const { + return m_transB; + } + + template + auto + ArrayMultiply::a() const + -> const TypeA & { + return m_a; + } + + template + auto + ArrayMultiply::b() const + -> const TypeB & { + return m_b; + } + + template + auto ArrayMultiply::a() + -> TypeA & { + return m_a; + } + + template + auto ArrayMultiply::b() + -> TypeB & { + return m_b; + } + + template + template + void + ArrayMultiply::applyTo( + array::ArrayContainer &out) const { + LIBRAPID_ASSERT(out.shape() == shape(), + "Shape of output array must match shape of array multiply operation. 
" + "Expected: {} -- Got: {}", + shape(), + out.shape()); + MatmulClass matmulClass = this->matmulClass(); + + auto a = detail::arrayPointerExtractor(m_a.storage().data()); + auto b = detail::arrayPointerExtractor(m_b.storage().data()); + auto c = detail::arrayPointerExtractor(out.storage().data()); + + switch (matmulClass) { + case MatmulClass::DOT: { + LIBRAPID_NOT_IMPLEMENTED; + } + case MatmulClass::GEMV: { + auto m = int64_t(m_a.shape()[m_transA]); + auto n = int64_t(m_a.shape()[1 - m_transA]); + + auto lda = int64_t(m_a.shape()[1]); + auto incB = int64_t(1); + auto incC = int64_t(1); + + gemv(m_transA, + m, + n, + static_cast(m_alpha), + a, + lda, + b, + incB, + static_cast(m_beta), + c, + incC, + Backend()); + + break; + } + case MatmulClass::GEMM: { + auto m = int64_t(m_a.shape()[m_transA]); + auto n = int64_t(m_b.shape()[1 - m_transB]); + auto k = int64_t(m_a.shape()[1 - m_transA]); + + auto lda = int64_t(m_a.shape()[1]); + auto ldb = int64_t(m_b.shape()[1]); + auto ldc = int64_t(out.shape()[1]); + + gemm(m_transA, + m_transB, + m, + n, + k, + static_cast(m_alpha), + a, + lda, + b, + ldb, + static_cast(m_beta), + c, + ldc, + Backend()); + + break; + } + default: { + LIBRAPID_NOT_IMPLEMENTED; + } + } + } + + template + std::string + ArrayMultiply::str( + const std::string &format) const { + return eval().str(format); + } + } // namespace linalg + + // /// \brief Computes the dot product of two arrays. + // /// + // /// This function calculates the dot product of two arrays. + // /// + // /// If the input arrays are 1-dimensional vectors, this function computes the vector-dot + // /// product \f$ \mathbf{a} \cdot \mathbf{b} = a_1b_1 + a_2b_2 + \ldots + a_nb_n \f$. + // /// Note that the return value will be a 1x1 array (i.e. a scalar) since we cannot return a + // /// scalar directly. 
+ // /// + // /// If the left input is a 2-dimensional matrix and the right input is a 1-dimensional + // vector, + // /// this function computes the matrix-vector product \f$ y_i = \sum_{j=1}^{n} a_{ij} x_j \f$ + // /// for \f$ i = 1, \ldots, m \f$. + // /// + // /// If both inputs are 2-dimensional matrices, this function computes the matrix-matrix + // product + // /// \f$ c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj} \f$ for \f$ i = 1, \ldots, m \f$ and \f$ j = + // 1, + // /// \ldots, p \f$. \tparam StorageTypeA The storage type of the left input array. \tparam + // /// StorageTypeB The storage type of the right input array. \param a The left input array. + // /// \param b The right input array. + // /// \return The dot product of the two input arrays. + // template + // auto dot(const ArrayRef &a, const ArrayRef &b) { + // return linalg::ArrayMultiply(a, b); + // } + + namespace detail { + template + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer &destination, + const linalg::ArrayMultiply &op) { + op.applyTo(destination); + } + + /// Evaluates to true if the type is a transpose type. + /// \tparam T + template + struct IsTransposeType : std::false_type {}; + + template + struct IsTransposeType> : std::true_type {}; + + /// Returns a tuple of the form (transpose, raw array) where transpose is true if the array + /// is transposed and false otherwise, and raw array is the raw array data. + /// \tparam T + /// \param val + /// \return + template::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) { + using Scalar = typename typetraits::TypeInfo>::Scalar; + return std::make_tuple(false, Scalar(1), std::forward(val)); + } + + template::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto transposeExtractor(T &&val) { + using Type = decltype(val.array()); + return std::make_tuple(true, val.alpha(), std::forward(val.array())); + } + + /// Evaluates to true if the type is a multiply type. 
+ /// \tparam T + template + struct IsMultiplyType : std::false_type {}; + + template + struct IsMultiplyType> + : std::true_type {}; + + /// Returns a tuple of the form (scalar, raw array) where scalar is the multiplication + /// factor and raw array is the raw array data. + /// \tparam T + /// \param val + /// \return + template::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto multiplyExtractor(T &&val) { + using Scalar = typename typetraits::TypeInfo>::Scalar; + return std::make_tuple(Scalar(1), std::forward(val)); + } + + template::type == detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + multiplyExtractor(detail::Function &&val) { + using Type = decltype(std::get<0>(val.args())); + return std::make_tuple(std::get<1>(val.args()), + std::forward(std::get<0>(val.args()))); + } + + template::type == detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + multiplyExtractor(detail::Function &&val) { + using Type = decltype(std::get<1>(val.args())); + return std::make_tuple(std::get<0>(val.args()), + std::forward(std::get<1>(val.args()))); + } + + /// Return a tuple of the form (transpose, scalar, raw array) depending on the input type. + /// All scalar values are extracted and combined, and successive transpose operations are + /// combined. 
+ /// \tparam T + /// \param val + /// \return + template + auto dotHelper(T &&val) { + if constexpr (IsTransposeType::value) { + auto [transpose, alpha, array] = transposeExtractor(std::forward(val)); + auto [transpose2, alpha2, array2] = dotHelper(std::forward(array)); + return std::make_tuple( + transpose ^ transpose2, alpha * alpha2, std::forward(array2)); + } else if constexpr (IsMultiplyType::value) { + auto [alpha, array] = multiplyExtractor(std::forward(val)); + auto [transpose, alpha2, array2] = dotHelper(std::forward(array)); + return std::make_tuple(transpose, alpha * alpha2, std::forward(array2)); + } else { + using Scalar = typename typetraits::TypeInfo>::Scalar; + return std::make_tuple(false, Scalar(1), std::forward(val)); + } + } + } // namespace detail + + /// \brief Computes the dot product of two arrays. + /// + /// This function calculates the dot product of two arrays. + /// + /// If the input arrays are 1-dimensional vectors, this function computes the vector-dot + /// product \f$ \mathbf{a} \cdot \mathbf{b} = a_1b_1 + a_2b_2 + \ldots + a_nb_n \f$. + /// Note that the return value will be a 1x1 array (i.e. a scalar) since we cannot return a + /// scalar directly. + /// + /// If the left input is a 2-dimensional matrix and the right input is a 1-dimensional vector, + /// this function computes the matrix-vector product \f$ y_i = \sum_{j=1}^{n} a_{ij} x_j \f$ + /// for \f$ i = 1, \ldots, m \f$. + /// + /// If both inputs are 2-dimensional matrices, this function computes the matrix-matrix product + /// \f$ c_{ij} = \sum_{k=1}^{n} a_{ik} b_{kj} \f$ for \f$ i = 1, \ldots, m \f$ and \f$ j = 1, + /// \ldots, p \f$. \tparam StorageTypeA The storage type of the left input array. \tparam + /// StorageTypeB The storage type of the right input array. \param a The left input array. + /// \param b The right input array. + /// \return The dot product of the two input arrays. 
+ template< + typename First, typename Second, + typename std::enable_if_t::value && IsArrayType::value, int> = 0> + auto dot(First &&a, Second &&b) { + using ScalarA = typename typetraits::TypeInfo>::Scalar; + using ScalarB = typename typetraits::TypeInfo>::Scalar; + using BackendA = typename typetraits::TypeInfo>::Backend; + using BackendB = typename typetraits::TypeInfo>::Backend; + using ArrayA = Array; + using ArrayB = Array; + + auto [transA, alpha, arrA] = detail::dotHelper(a); + auto [transB, beta, arrB] = detail::dotHelper(b); + return linalg::ArrayMultiply(transA, + transB, + std::forward(arrA), + alpha * beta, + std::forward(arrB), + 0); // .eval(); + } + + namespace typetraits { + template + struct TypeInfo< + linalg::ArrayMultiply> { + detail::LibRapidType type = detail::LibRapidType::ArrayFunction; + using Type = linalg::ArrayMultiply; + using Scalar = typename Type::Scalar; + using Backend = typename Type::Backend; + static constexpr bool allowVectorisation = false; + }; + + LIBRAPID_DEFINE_AS_TYPE(typename ShapeTypeA COMMA typename StorageTypeA COMMA + typename ShapeTypeB COMMA typename StorageTypeB COMMA + typename Alpha COMMA typename Beta, + linalg::ArrayMultiply); + } // namespace typetraits } // namespace librapid LIBRAPID_SIMPLE_IO_IMPL( typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA typename StorageTypeB COMMA typename Alpha COMMA typename Beta, librapid::linalg::ArrayMultiply< - ShapeTypeA COMMA StorageTypeA COMMA ShapeTypeB COMMA StorageTypeB COMMA Alpha COMMA Beta>) + ShapeTypeA COMMA StorageTypeA COMMA ShapeTypeB COMMA StorageTypeB COMMA Alpha COMMA Beta>) LIBRAPID_SIMPLE_IO_NORANGE( typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA typename StorageTypeB COMMA typename Alpha COMMA typename Beta, librapid::linalg::ArrayMultiply< - ShapeTypeA COMMA StorageTypeA COMMA ShapeTypeB COMMA StorageTypeB COMMA Alpha COMMA Beta>) + ShapeTypeA COMMA StorageTypeA COMMA ShapeTypeB COMMA 
StorageTypeB COMMA Alpha COMMA Beta>) #endif // LIBRAPID_ARRAY_LINALG_ARRAY_MULTIPLY_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/linalg/compat.hpp b/librapid/include/librapid/array/linalg/compat.hpp index 54671c58..00a1b2b4 100644 --- a/librapid/include/librapid/array/linalg/compat.hpp +++ b/librapid/include/librapid/array/linalg/compat.hpp @@ -9,34 +9,34 @@ // types. namespace clblast { - // template<> - // StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, - // const Transpose b_transpose, const size_t m, const size_t n, - // const size_t k, const librapid::Complex alpha, - // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, - // const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, - // const librapid::Complex beta, cl_mem c_buffer, - // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, - // cl_event *event, cl_mem temp_buffer); + // template<> + // StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, + // const Transpose b_transpose, const size_t m, const size_t n, + // const size_t k, const librapid::Complex alpha, + // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + // const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + // const librapid::Complex beta, cl_mem c_buffer, + // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, + // cl_event *event, cl_mem temp_buffer); - // template<> - // StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, - // const Transpose b_transpose, const size_t m, const size_t n, - // const size_t k, const librapid::Complex alpha, - // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, - // const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, - // const librapid::Complex beta, cl_mem c_buffer, - // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, - // cl_event *event, cl_mem temp_buffer); + // 
template<> + // StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, + // const Transpose b_transpose, const size_t m, const size_t n, + // const size_t k, const librapid::Complex alpha, + // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + // const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + // const librapid::Complex beta, cl_mem c_buffer, + // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, + // cl_event *event, cl_mem temp_buffer); - template<> - StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, - const Transpose b_transpose, const size_t m, const size_t n, - const size_t k, const librapid::half alpha, const cl_mem a_buffer, - const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, - const size_t b_offset, const size_t b_ld, const librapid::half beta, - cl_mem c_buffer, const size_t c_offset, const size_t c_ld, - cl_command_queue *queue, cl_event *event, cl_mem temp_buffer); + template<> + StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, + const Transpose b_transpose, const size_t m, const size_t n, + const size_t k, const librapid::half alpha, const cl_mem a_buffer, + const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, + const size_t b_offset, const size_t b_ld, const librapid::half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue *queue, cl_event *event, cl_mem temp_buffer); } // namespace clblast #endif // LIBRAPID_HAS_OPENCL diff --git a/librapid/include/librapid/array/linalg/level2/gemv.cl b/librapid/include/librapid/array/linalg/level2/gemv.cl index 4ce62c67..37346d82 100644 --- a/librapid/include/librapid/array/linalg/level2/gemv.cl +++ b/librapid/include/librapid/array/linalg/level2/gemv.cl @@ -1,33 +1,33 @@ #define GEMV_IMPL(TYPE) \ - __kernel void gemv_##TYPE(const int trans, \ - const int32_t M, \ - const int32_t N, \ - const TYPE alpha, \ - __global const TYPE 
*A, \ - const int32_t lda, \ - __global const TYPE *x, \ - const int32_t incX, \ - const TYPE beta, \ - __global TYPE *y, \ - const int32_t incy) { \ - /* Get global thread ID */ \ - int idx = get_global_id(0); \ + __kernel void gemv_##TYPE(const int trans, \ + const int32_t M, \ + const int32_t N, \ + const TYPE alpha, \ + __global const TYPE *A, \ + const int32_t lda, \ + __global const TYPE *x, \ + const int32_t incX, \ + const TYPE beta, \ + __global TYPE *y, \ + const int32_t incy) { \ + /* Get global thread ID */ \ + int idx = get_global_id(0); \ \ - /* Only valid threads perform computation */ \ - if (idx < M) { \ - /* Compute dot product for this thread's row of matrix A */ \ - TYPE acc = 0; \ - if (trans == 0) { \ - /* Non-transposed matrix */ \ - for (int j = 0; j < N; ++j) { acc += A[idx * lda + j] * x[j * incX]; } \ - } else { \ - /* Transposed matrix */ \ - for (int j = 0; j < N; ++j) { acc += A[j * lda + idx] * x[j * incX]; } \ - } \ - /* Apply alpha scaling to acc and beta scaling to y[idx] then sum */ \ - y[idx * incy] = alpha * acc + beta * y[idx * incy]; \ - } \ - } + /* Only valid threads perform computation */ \ + if (idx < M) { \ + /* Compute dot product for this thread's row of matrix A */ \ + TYPE acc = 0; \ + if (trans == 0) { \ + /* Non-transposed matrix */ \ + for (int j = 0; j < N; ++j) { acc += A[idx * lda + j] * x[j * incX]; } \ + } else { \ + /* Transposed matrix */ \ + for (int j = 0; j < N; ++j) { acc += A[j * lda + idx] * x[j * incX]; } \ + } \ + /* Apply alpha scaling to acc and beta scaling to y[idx] then sum */ \ + y[idx * incy] = alpha * acc + beta * y[idx * incy]; \ + } \ + } GEMV_IMPL(int8_t) GEMV_IMPL(int16_t) diff --git a/librapid/include/librapid/array/linalg/level2/gemv.cu b/librapid/include/librapid/array/linalg/level2/gemv.cu index 302018ce..06e8c43d 100644 --- a/librapid/include/librapid/array/linalg/level2/gemv.cu +++ b/librapid/include/librapid/array/linalg/level2/gemv.cu @@ -1,19 +1,19 @@ #define TS 32 // Tile size 
template + typename TypeY> __global__ void gemv(bool trans, Int m, Int n, Alpha alpha, TypeA *a, Int lda, TypeX *x, Int incx, - Beta beta, TypeY *y, Int incy) { - const Int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < m) { - TypeY acc = 0; + Beta beta, TypeY *y, Int incy) { + const Int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < m) { + TypeY acc = 0; - if (trans) { - for (Int i = 0; i < n; i++) { acc += a[idx + i * lda] * x[i * incx]; } - } else { - for (Int i = 0; i < n; i++) { acc += a[i + idx * lda] * x[i * incx]; } - } + if (trans) { + for (Int i = 0; i < n; i++) { acc += a[idx + i * lda] * x[i * incx]; } + } else { + for (Int i = 0; i < n; i++) { acc += a[i + idx * lda] * x[i * incx]; } + } - y[idx * incy] = alpha * acc + beta * y[idx * incy]; - } + y[idx * incy] = alpha * acc + beta * y[idx * incy]; + } } diff --git a/librapid/include/librapid/array/linalg/level2/gemv.hpp b/librapid/include/librapid/array/linalg/level2/gemv.hpp index 24293bc6..57a506a2 100644 --- a/librapid/include/librapid/array/linalg/level2/gemv.hpp +++ b/librapid/include/librapid/array/linalg/level2/gemv.hpp @@ -2,231 +2,231 @@ #define LIBRAPID_ARRAY_LINALG_LEVEL2_GEMV_HPP namespace librapid::linalg { - /// \brief General matrix-vector multiplication. 
- /// - /// Computes \f$ y = \alpha \mathrm{op}(\mathbf{A}) \mathbf{x} + \beta \mathbf{y} \f$ for - /// matrix \f$ \mathbf{A} \f$ and vectors \f$ \mathbf{x} \f$ and \f$ \mathbf{y} \f$ - /// \tparam Int Integer type - /// \tparam Alpha Alpha scaling factor - /// \tparam A Matrix type - /// \tparam X First vector type - /// \tparam Beta Beta scaling factor - /// \tparam Y Second vector type - /// \param trans If true, \f$ \mathrm{op}(\mathbf{A}) = \mathbf{A}^T \f$, otherwise \f$ \mathrm{op}(\mathbf{A}) = \mathbf{A} \f$ - /// \param m Number of rows in \f$ \mathbf{A} \f$ - /// \param n Number of columns in \f$ \mathbf{A} \f$ - /// \param alpha Scaling factor for \f$ \mathrm{op}(\mathbf{A}) \mathbf{x} \f$ - /// \param a Pointer to matrix \f$ \mathbf{A} \f$ - /// \param lda Leading dimension of \f$ \mathbf{A} \f$ - /// \param x Pointer to vector \f$ \mathbf{x} \f$ - /// \param incX Increment of \f$ \mathbf{x} \f$ - /// \param beta Scaling factor for \f$ \mathbf{y} \f$ - /// \param y Pointer to vector \f$ \mathbf{y} \f$ - /// \param incY Increment of \f$ \mathbf{y} \f$ - /// \param backend Backend to use for computation - template - void gemv(bool trans, Int m, Int n, Alpha alpha, A *a, Int lda, X *x, Int incX, Beta beta, Y *y, - Int incY, backend::CPU backend = backend::CPU()) { - // On the CPU, cxxblas provides a generic implementation for all types along with BLAS - // implementations where available - - cxxblas::gemv(cxxblas::StorageOrder::RowMajor, - (trans ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans), - static_cast(m), - static_cast(n), - alpha, - a, - static_cast(lda), - x, - static_cast(incX), - beta, - y, - static_cast(incY)); - } + /// \brief General matrix-vector multiplication. 
+ /// + /// Computes \f$ y = \alpha \mathrm{op}(\mathbf{A}) \mathbf{x} + \beta \mathbf{y} \f$ for + /// matrix \f$ \mathbf{A} \f$ and vectors \f$ \mathbf{x} \f$ and \f$ \mathbf{y} \f$ + /// \tparam Int Integer type + /// \tparam Alpha Alpha scaling factor + /// \tparam A Matrix type + /// \tparam X First vector type + /// \tparam Beta Beta scaling factor + /// \tparam Y Second vector type + /// \param trans If true, \f$ \mathrm{op}(\mathbf{A}) = \mathbf{A}^T \f$, otherwise \f$ + /// \mathrm{op}(\mathbf{A}) = \mathbf{A} \f$ \param m Number of rows in \f$ \mathbf{A} \f$ + /// \param n Number of columns in \f$ \mathbf{A} \f$ + /// \param alpha Scaling factor for \f$ \mathrm{op}(\mathbf{A}) \mathbf{x} \f$ + /// \param a Pointer to matrix \f$ \mathbf{A} \f$ + /// \param lda Leading dimension of \f$ \mathbf{A} \f$ + /// \param x Pointer to vector \f$ \mathbf{x} \f$ + /// \param incX Increment of \f$ \mathbf{x} \f$ + /// \param beta Scaling factor for \f$ \mathbf{y} \f$ + /// \param y Pointer to vector \f$ \mathbf{y} \f$ + /// \param incY Increment of \f$ \mathbf{y} \f$ + /// \param backend Backend to use for computation + template + void gemv(bool trans, Int m, Int n, Alpha alpha, A *a, Int lda, X *x, Int incX, Beta beta, Y *y, + Int incY, backend::CPU backend = backend::CPU()) { + // On the CPU, cxxblas provides a generic implementation for all types along with BLAS + // implementations where available + + cxxblas::gemv(cxxblas::StorageOrder::RowMajor, + (trans ? 
cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans), + static_cast(m), + static_cast(n), + alpha, + a, + static_cast(lda), + x, + static_cast(incX), + beta, + y, + static_cast(incY)); + } #if defined(LIBRAPID_HAS_OPENCL) - template - void gemv(bool trans, Int m, Int n, Alpha alpha, cl::Buffer a, Int lda, cl::Buffer x, Int incX, - Beta beta, cl::Buffer y, Int incY, backend::OpenCL) { - // We have no other type information, so this is the best we can do without introducing - // another template parameter - using GemvScalar = decltype(alpha * beta); - - if constexpr (typetraits::IsBlasType::value) { - // clblast only provides a BLAS implementation for supported types - auto status = - clblast::Gemv(clblast::Layout::kRowMajor, - (trans ? clblast::Transpose::kYes : clblast::Transpose::kNo), - m, - n, - alpha, - a(), - 0, - lda, - x(), - 0, - incX, - beta, - y(), - 0, - incY, - &global::openCLQueue()); - - LIBRAPID_ASSERT(status == clblast::StatusCode::kSuccess, - "clblast::Gemv failed: {}", - opencl::getCLBlastErrorString(status)); - } else { - // We have no BLAS implementation for this type, so we need to use our own kernel. - // Luckily, this is almost as fast as the clblast implementation, so we don't lose - // much performance. 
- - std::string kernelNameFull = - std::string("gemv_") + typetraits::TypeInfo::name; - cl::Kernel kernel(global::openCLProgram, kernelNameFull.c_str()); - kernel.setArg(0, (int)trans); - kernel.setArg(1, static_cast(m)); - kernel.setArg(2, static_cast(n)); - kernel.setArg(3, static_cast(alpha)); - kernel.setArg(4, a); - kernel.setArg(5, static_cast(lda)); - kernel.setArg(6, x); - kernel.setArg(7, static_cast(incX)); - kernel.setArg(8, static_cast(beta)); - kernel.setArg(9, y); - kernel.setArg(10, static_cast(incY)); - - cl::NDRange globalWorkSize = cl::NDRange(m * n); - - auto status = global::openCLQueue.enqueueNDRangeKernel( - kernel, cl::NullRange, globalWorkSize, cl::NullRange); - - LIBRAPID_ASSERT(status == CL_SUCCESS, - "cl::CommandQueue::enqueueNDRangeKernel GEMV call failed: {}", - opencl::getOpenCLErrorString(status)); - } - } + template + void gemv(bool trans, Int m, Int n, Alpha alpha, cl::Buffer a, Int lda, cl::Buffer x, Int incX, + Beta beta, cl::Buffer y, Int incY, backend::OpenCL) { + // We have no other type information, so this is the best we can do without introducing + // another template parameter + using GemvScalar = decltype(alpha * beta); + + if constexpr (typetraits::IsBlasType::value) { + // clblast only provides a BLAS implementation for supported types + auto status = + clblast::Gemv(clblast::Layout::kRowMajor, + (trans ? clblast::Transpose::kYes : clblast::Transpose::kNo), + m, + n, + alpha, + a(), + 0, + lda, + x(), + 0, + incX, + beta, + y(), + 0, + incY, + &global::openCLQueue()); + + LIBRAPID_ASSERT(status == clblast::StatusCode::kSuccess, + "clblast::Gemv failed: {}", + opencl::getCLBlastErrorString(status)); + } else { + // We have no BLAS implementation for this type, so we need to use our own kernel. + // Luckily, this is almost as fast as the clblast implementation, so we don't lose + // much performance. 
+ + std::string kernelNameFull = + std::string("gemv_") + typetraits::TypeInfo::name; + cl::Kernel kernel(global::openCLProgram, kernelNameFull.c_str()); + kernel.setArg(0, (int)trans); + kernel.setArg(1, static_cast(m)); + kernel.setArg(2, static_cast(n)); + kernel.setArg(3, static_cast(alpha)); + kernel.setArg(4, a); + kernel.setArg(5, static_cast(lda)); + kernel.setArg(6, x); + kernel.setArg(7, static_cast(incX)); + kernel.setArg(8, static_cast(beta)); + kernel.setArg(9, y); + kernel.setArg(10, static_cast(incY)); + + cl::NDRange globalWorkSize = cl::NDRange(m * n); + + auto status = global::openCLQueue.enqueueNDRangeKernel( + kernel, cl::NullRange, globalWorkSize, cl::NullRange); + + LIBRAPID_ASSERT(status == CL_SUCCESS, + "cl::CommandQueue::enqueueNDRangeKernel GEMV call failed: {}", + opencl::getOpenCLErrorString(status)); + } + } #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - /* - template - void gemv(bool trans, Int m, Int n, Alpha alpha, float *a, Int lda, float *x, Int incX, - Beta beta, float *y, Int incY, backend::CUDA) { - cublasSafeCall(cublasSgemv(global::cublasHandle, - (trans ? CUBLAS_OP_N : CUBLAS_OP_T), - n, - m, - &alpha, - a, - lda, - x, - incX, - &beta, - y, - incY)); - } - - template - void gemv(bool trans, Int m, Int n, Alpha alpha, double *a, Int lda, double *x, Int incX, - Beta beta, double *y, Int incY, backend::CUDA) { - cublasSafeCall(cublasDgemv(global::cublasHandle, - (trans ? CUBLAS_OP_N : CUBLAS_OP_T), - n, - m, - &alpha, - a, - lda, - x, - incX, - &beta, - y, - incY)); - } - - template - void gemv(bool trans, Int m, Int n, Alpha alpha, Complex *a, Int lda, Complex *x, - Int incX, Beta beta, Complex *y, Int incY, backend::CUDA) { - cublasSafeCall(cublasCgemv(global::cublasHandle, - (trans ? 
CUBLAS_OP_N : CUBLAS_OP_T), - n, - m, - &alpha, - reinterpret_cast(a), - lda, - reinterpret_cast(x), - incX, - &beta, - reinterpret_cast(y), - incY)); - } - - template - void gemv(bool trans, Int m, Int n, Alpha alpha, Complex *a, Int lda, - Complex *x, Int incX, Beta beta, Complex *y, Int incY, - backend::CUDA) { - cublasSafeCall(cublasZgemv(global::cublasHandle, - (trans ? CUBLAS_OP_N : CUBLAS_OP_T), - n, - m, - &alpha, - reinterpret_cast(a), - lda, - reinterpret_cast(x), - incX, - &beta, - reinterpret_cast(y), - incY)); - } - - template - void gemv(bool trans, Int m, Int n, Alpha alpha, A *a, Int lda, X *x, Int incX, Beta beta, Y *y, - Int incY, backend::CUDA) { - jitify::Program program = global::jitCache.program(cuda::loadKernel( - fmt::format("{}/include/librapid/array/linalg/level2/gemv", LIBRAPID_SOURCE), false)); - - Int elements = m * n; - Int threadsPerBlock, blocksPerGrid; - - // Use 1 to 512 threads per block - if (elements < 512) { - threadsPerBlock = static_cast(elements); - blocksPerGrid = 1; - } else { - threadsPerBlock = 512; - blocksPerGrid = static_cast( - ceil(static_cast(elements) / static_cast(threadsPerBlock))); - } - - dim3 grid(blocksPerGrid); - dim3 block(threadsPerBlock); - - jitifyCall(program.kernel("gemv") - .instantiate(jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type()) - .configure(grid, block, 0, global::cudaStream) - .launch(trans, m, n, alpha, a, lda, x, incX, beta, y, incY)); - } - */ - - template - void gemv(bool trans, Int m, Int n, Alpha alpha, A *a, Int lda, X *x, Int incX, Beta beta, Y *y, - Int incY, backend::CUDA) { - // With CUDA, it's actually faster to use cuBLAS LT MatMul than to use cuBLAS GEMV, so - // we can just pass the call through to the gemm function instead. 
Additionally, this works with - // half precision types as well, so we don't need to implement a separate gemv function for - // those. This is likely because cuBLAS LT MatMul is optimized for small matrix sizes, which - // is what we're dealing with here. - - gemm( - trans, false, m, int64_t(1), n, alpha, a, lda, x, incX, beta, y, incY, backend::CUDA()); - } + /* + template + void gemv(bool trans, Int m, Int n, Alpha alpha, float *a, Int lda, float *x, Int incX, + Beta beta, float *y, Int incY, backend::CUDA) { + cublasSafeCall(cublasSgemv(global::cublasHandle, + (trans ? CUBLAS_OP_N : CUBLAS_OP_T), + n, + m, + &alpha, + a, + lda, + x, + incX, + &beta, + y, + incY)); + } + + template + void gemv(bool trans, Int m, Int n, Alpha alpha, double *a, Int lda, double *x, Int incX, + Beta beta, double *y, Int incY, backend::CUDA) { + cublasSafeCall(cublasDgemv(global::cublasHandle, + (trans ? CUBLAS_OP_N : CUBLAS_OP_T), + n, + m, + &alpha, + a, + lda, + x, + incX, + &beta, + y, + incY)); + } + + template + void gemv(bool trans, Int m, Int n, Alpha alpha, Complex *a, Int lda, Complex *x, + Int incX, Beta beta, Complex *y, Int incY, backend::CUDA) { + cublasSafeCall(cublasCgemv(global::cublasHandle, + (trans ? CUBLAS_OP_N : CUBLAS_OP_T), + n, + m, + &alpha, + reinterpret_cast(a), + lda, + reinterpret_cast(x), + incX, + &beta, + reinterpret_cast(y), + incY)); + } + + template + void gemv(bool trans, Int m, Int n, Alpha alpha, Complex *a, Int lda, + Complex *x, Int incX, Beta beta, Complex *y, Int incY, + backend::CUDA) { + cublasSafeCall(cublasZgemv(global::cublasHandle, + (trans ? 
CUBLAS_OP_N : CUBLAS_OP_T), + n, + m, + &alpha, + reinterpret_cast(a), + lda, + reinterpret_cast(x), + incX, + &beta, + reinterpret_cast(y), + incY)); + } + + template + void gemv(bool trans, Int m, Int n, Alpha alpha, A *a, Int lda, X *x, Int incX, Beta beta, Y *y, + Int incY, backend::CUDA) { + jitify::Program program = global::jitCache.program(cuda::loadKernel( + fmt::format("{}/include/librapid/array/linalg/level2/gemv", LIBRAPID_SOURCE), false)); + + Int elements = m * n; + Int threadsPerBlock, blocksPerGrid; + + // Use 1 to 512 threads per block + if (elements < 512) { + threadsPerBlock = static_cast(elements); + blocksPerGrid = 1; + } else { + threadsPerBlock = 512; + blocksPerGrid = static_cast( + ceil(static_cast(elements) / static_cast(threadsPerBlock))); + } + + dim3 grid(blocksPerGrid); + dim3 block(threadsPerBlock); + + jitifyCall(program.kernel("gemv") + .instantiate(jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type()) + .configure(grid, block, 0, global::cudaStream) + .launch(trans, m, n, alpha, a, lda, x, incX, beta, y, incY)); + } + */ + + template + void gemv(bool trans, Int m, Int n, Alpha alpha, A *a, Int lda, X *x, Int incX, Beta beta, Y *y, + Int incY, backend::CUDA) { + // With CUDA, it's actually faster to use cuBLAS LT MatMul than to use cuBLAS GEMV, so + // we can just pass the call through to the gemm function instead. Additionally, this works + // with half precision types as well, so we don't need to implement a separate gemv function + // for those. This is likely because cuBLAS LT MatMul is optimized for small matrix sizes, + // which is what we're dealing with here. 
+ + gemm( + trans, false, m, int64_t(1), n, alpha, a, lda, x, incX, beta, y, incY, backend::CUDA()); + } #endif // LIBRAPID_HAS_CUDA diff --git a/librapid/include/librapid/array/linalg/level3/geam.hpp b/librapid/include/librapid/array/linalg/level3/geam.hpp index cce26292..3ab4570e 100644 --- a/librapid/include/librapid/array/linalg/level3/geam.hpp +++ b/librapid/include/librapid/array/linalg/level3/geam.hpp @@ -2,94 +2,90 @@ #define LIBRAPID_ARRAY_LINALG_LEVEL3_GEAM_HPP namespace librapid { - namespace linalg { + namespace linalg { #define GEAM_VALIDATION \ - LIBRAPID_ASSERT(a.shape() == b.shape(), "Input shapes must match"); \ - LIBRAPID_ASSERT(a.ndim() == 2, "Input array must be a Matrix (2D)"); \ - LIBRAPID_ASSERT(a.shape() == c.shape(), "Output shape must match input shapes"); \ - LIBRAPID_ASSERT((void *)&a != (void *)&c, "Input and output arrays must be different"); \ - LIBRAPID_ASSERT((void *)&b != (void *)&c, "Input and output arrays must be different") - - /// \brief General matrix-matrix addition. - /// - /// Computes \f$ \mathbf{C} = \alpha \mathrm{op}_A(\mathbf{A}) + \beta \mathrm{op}_B(\mathbf{B}) \f$, - /// for matrices \f$ \mathbf{A} \f$ and \f$ \mathbf{B} \f$ and scalars \f$ \alpha \f$ and \f$ \beta \f$. - /// \tparam StorageScalar Storage type of the input and output arrays. - /// \tparam ShapeTypeA Shape type of the first input array. - /// \tparam ShapeTypeB Shape type of the second input array. - /// \tparam ShapeTypeC Shape type of the output array. - /// \tparam Alpha Scalar type of the \f$ \alpha \f$ parameter. - /// \tparam Beta Scalar type of the \f$ \beta \f$ parameter. - /// \param a First input array. - /// \param alpha Scalar \f$ \alpha \f$. - /// \param b Second input array. - /// \param beta Scalar \f$ \beta \f$. - /// \param c Output array. 
- template - void geam(const array::ArrayContainer> &a, Alpha alpha, - const array::ArrayContainer> &b, Beta beta, - array::ArrayContainer> &c) { - GEAM_VALIDATION; - - c = a * static_cast(alpha) + b * static_cast(beta); - } - - template - void - geam(const array::Transpose>> &a, - Alpha alpha, const array::ArrayContainer> &b, - Beta beta, array::ArrayContainer> &c) { - GEAM_VALIDATION; - - // Eval before returning to avoid slow evaluation - // see https://librapid.readthedocs.io/en/latest/performance/performance.html - - const auto &dataA = a.array(); - - c = array::Transpose(dataA, {1, 0}, static_cast(alpha)).eval() + - b * static_cast(beta); - } - - template - void - geam(const array::ArrayContainer> &a, Alpha alpha, - const array::Transpose>> &b, - Beta beta, array::ArrayContainer> &c) { - GEAM_VALIDATION; - - // Eval before returning to avoid slow evaluation - // see https://librapid.readthedocs.io/en/latest/performance/performance.html - - const auto &dataB = b.array(); - - c = a * static_cast(alpha) + - array::Transpose(dataB, {1, 0}, static_cast(beta)).eval(); - } - - template - void - geam(const array::Transpose>> &a, - Alpha alpha, - const array::Transpose>> &b, - Beta beta, array::ArrayContainer> &c) { - GEAM_VALIDATION; - - // Eval before returning to avoid slow evaluation - // see https://librapid.readthedocs.io/en/latest/performance/performance.html - - // alpha * a^T + beta * b^T = (alpha * a + beta * b)^T - - const auto &dataA = a.array(); - const auto &dataB = b.array(); - - c = transpose( - (dataA * static_cast(alpha) + dataB * static_cast(beta)) - .eval()); - } + LIBRAPID_ASSERT(a.shape() == b.shape(), "Input shapes must match"); \ + LIBRAPID_ASSERT(a.ndim() == 2, "Input array must be a Matrix (2D)"); \ + LIBRAPID_ASSERT(a.shape() == c.shape(), "Output shape must match input shapes"); \ + LIBRAPID_ASSERT((void *)&a != (void *)&c, "Input and output arrays must be different"); \ + LIBRAPID_ASSERT((void *)&b != (void *)&c, "Input and output arrays 
must be different") + + /// \brief General matrix-matrix addition. + /// + /// Computes \f$ \mathbf{C} = \alpha \mathrm{op}_A(\mathbf{A}) + \beta + /// \mathrm{op}_B(\mathbf{B}) \f$, for matrices \f$ \mathbf{A} \f$ and \f$ \mathbf{B} \f$ + /// and scalars \f$ \alpha \f$ and \f$ \beta \f$. \tparam StorageScalar Storage type of the + /// input and output arrays. \tparam ShapeTypeA Shape type of the first input array. \tparam + /// ShapeTypeB Shape type of the second input array. \tparam ShapeTypeC Shape type of the + /// output array. \tparam Alpha Scalar type of the \f$ \alpha \f$ parameter. \tparam Beta + /// Scalar type of the \f$ \beta \f$ parameter. \param a First input array. \param alpha + /// Scalar \f$ \alpha \f$. \param b Second input array. \param beta Scalar \f$ \beta \f$. + /// \param c Output array. + template + void geam(const array::ArrayContainer> &a, Alpha alpha, + const array::ArrayContainer> &b, Beta beta, + array::ArrayContainer> &c) { + GEAM_VALIDATION; + + c = a * static_cast(alpha) + b * static_cast(beta); + } + + template + void + geam(const array::Transpose>> &a, + Alpha alpha, const array::ArrayContainer> &b, + Beta beta, array::ArrayContainer> &c) { + GEAM_VALIDATION; + + // Eval before returning to avoid slow evaluation + // see https://librapid.readthedocs.io/en/latest/performance/performance.html + + const auto &dataA = a.array(); + + c = array::Transpose(dataA, {1, 0}, static_cast(alpha)).eval() + + b * static_cast(beta); + } + + template + void + geam(const array::ArrayContainer> &a, Alpha alpha, + const array::Transpose>> &b, + Beta beta, array::ArrayContainer> &c) { + GEAM_VALIDATION; + + // Eval before returning to avoid slow evaluation + // see https://librapid.readthedocs.io/en/latest/performance/performance.html + + const auto &dataB = b.array(); + + c = a * static_cast(alpha) + + array::Transpose(dataB, {1, 0}, static_cast(beta)).eval(); + } + + template + void + geam(const array::Transpose>> &a, + Alpha alpha, + const 
array::Transpose>> &b, + Beta beta, array::ArrayContainer> &c) { + GEAM_VALIDATION; + + // Eval before returning to avoid slow evaluation + // see https://librapid.readthedocs.io/en/latest/performance/performance.html + + // alpha * a^T + beta * b^T = (alpha * a + beta * b)^T + + const auto &dataA = a.array(); + const auto &dataB = b.array(); + + c = transpose( + (dataA * static_cast(alpha) + dataB * static_cast(beta)) + .eval()); + } #if defined(LIBRAPID_HAS_OPENCL) @@ -97,285 +93,285 @@ namespace librapid { #if defined(LIBRAPID_HAS_CUDA) - template - void geam(const array::ArrayContainer> &a, - Alpha alpha, - const array::ArrayContainer> &b, Beta beta, - array::ArrayContainer> &c) { - GEAM_VALIDATION; - - c = a * static_cast(alpha) + b * static_cast(beta); - } - - template - void geam( - const array::Transpose>> &a, - Alpha alpha, const array::ArrayContainer> &b, - Beta beta, array::ArrayContainer> &c) { - GEAM_VALIDATION; - - const auto &dataA = a.array(); - - c = array::Transpose(dataA, {1, 0}, static_cast(alpha)).eval() + - b * static_cast(beta); - } - - template - void geam( - const array::ArrayContainer> &a, Alpha alpha, - const array::Transpose>> &b, - Beta beta, array::ArrayContainer> &c) { - GEAM_VALIDATION; - - const auto &dataB = b.array(); - - c = a * static_cast(alpha) + - array::Transpose(dataB, {1, 0}, static_cast(beta)).eval(); - } - - template - void geam( - const array::Transpose>> &a, - Alpha alpha, - const array::Transpose>> &b, - Beta beta, array::ArrayContainer> &c) { - GEAM_VALIDATION; - - const auto &dataA = a.array(); - const auto &dataB = b.array(); - - c = transpose( - (dataA * static_cast(alpha) + dataB * static_cast(beta)) - .eval()); - } - -# define LIBRAPID_CUDA_GEAM_IMPL(SCALAR, PREFIX) \ - template \ - void geam(const array::ArrayContainer> &a, \ - Alpha alpha, \ - const array::ArrayContainer> &b, \ - Beta beta, \ - array::ArrayContainer> &c) { \ - GEAM_VALIDATION; \ + template + void geam(const array::ArrayContainer> &a, + Alpha 
alpha, + const array::ArrayContainer> &b, Beta beta, + array::ArrayContainer> &c) { + GEAM_VALIDATION; + + c = a * static_cast(alpha) + b * static_cast(beta); + } + + template + void geam( + const array::Transpose>> &a, + Alpha alpha, const array::ArrayContainer> &b, + Beta beta, array::ArrayContainer> &c) { + GEAM_VALIDATION; + + const auto &dataA = a.array(); + + c = array::Transpose(dataA, {1, 0}, static_cast(alpha)).eval() + + b * static_cast(beta); + } + + template + void geam( + const array::ArrayContainer> &a, Alpha alpha, + const array::Transpose>> &b, + Beta beta, array::ArrayContainer> &c) { + GEAM_VALIDATION; + + const auto &dataB = b.array(); + + c = a * static_cast(alpha) + + array::Transpose(dataB, {1, 0}, static_cast(beta)).eval(); + } + + template + void geam( + const array::Transpose>> &a, + Alpha alpha, + const array::Transpose>> &b, + Beta beta, array::ArrayContainer> &c) { + GEAM_VALIDATION; + + const auto &dataA = a.array(); + const auto &dataB = b.array(); + + c = transpose( + (dataA * static_cast(alpha) + dataB * static_cast(beta)) + .eval()); + } + +# define LIBRAPID_CUDA_GEAM_IMPL(SCALAR, PREFIX) \ + template \ + void geam(const array::ArrayContainer> &a, \ + Alpha alpha, \ + const array::ArrayContainer> &b, \ + Beta beta, \ + array::ArrayContainer> &c) { \ + GEAM_VALIDATION; \ \ - auto *__restrict dataA = a.storage().begin().get(); \ - auto *__restrict dataB = b.storage().begin().get(); \ - auto *__restrict dataC = c.storage().begin().get(); \ + auto *__restrict dataA = a.storage().begin().get(); \ + auto *__restrict dataB = b.storage().begin().get(); \ + auto *__restrict dataC = c.storage().begin().get(); \ \ - auto alphaTmp = static_cast(alpha); \ - auto betaTmp = static_cast(beta); \ + auto alphaTmp = static_cast(alpha); \ + auto betaTmp = static_cast(beta); \ \ - cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ - CUBLAS_OP_N, \ - CUBLAS_OP_N, \ - a.shape()[0], \ - a.shape()[1], \ - &alphaTmp, \ - dataA, \ - a.shape()[0], \ - 
&betaTmp, \ - dataB, \ - b.shape()[0], \ - dataC, \ - c.shape()[0])); \ - } \ + cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ + CUBLAS_OP_N, \ + CUBLAS_OP_N, \ + a.shape()[0], \ + a.shape()[1], \ + &alphaTmp, \ + dataA, \ + a.shape()[0], \ + &betaTmp, \ + dataB, \ + b.shape()[0], \ + dataC, \ + c.shape()[0])); \ + } \ \ - template \ - void geam( \ - const array::Transpose>> &a, \ - Alpha alpha, \ - const array::ArrayContainer> &b, \ - Beta beta, \ - array::ArrayContainer> &c) { \ - GEAM_VALIDATION; \ + template \ + void geam( \ + const array::Transpose>> &a, \ + Alpha alpha, \ + const array::ArrayContainer> &b, \ + Beta beta, \ + array::ArrayContainer> &c) { \ + GEAM_VALIDATION; \ \ - auto *__restrict dataA = a.array().storage().begin().get(); \ - auto *__restrict dataB = b.storage().begin().get(); \ - auto *__restrict dataC = c.storage().begin().get(); \ + auto *__restrict dataA = a.array().storage().begin().get(); \ + auto *__restrict dataB = b.storage().begin().get(); \ + auto *__restrict dataC = c.storage().begin().get(); \ \ - auto alphaTmp = static_cast(alpha); \ - auto betaTmp = static_cast(beta); \ + auto alphaTmp = static_cast(alpha); \ + auto betaTmp = static_cast(beta); \ \ - cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ - CUBLAS_OP_T, \ - CUBLAS_OP_N, \ - a.shape()[1], \ - a.shape()[0], \ - &alphaTmp, \ - dataA, \ - a.shape()[0], \ - &betaTmp, \ - dataB, \ - b.shape()[0], \ - dataC, \ - c.shape()[0])); \ - } \ + cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ + CUBLAS_OP_T, \ + CUBLAS_OP_N, \ + a.shape()[1], \ + a.shape()[0], \ + &alphaTmp, \ + dataA, \ + a.shape()[0], \ + &betaTmp, \ + dataB, \ + b.shape()[0], \ + dataC, \ + c.shape()[0])); \ + } \ \ - template \ - void geam( \ - const array::ArrayContainer> &a, \ - Alpha alpha, \ - const array::Transpose>> &b, \ - Beta beta, \ - array::ArrayContainer> &c) { \ - GEAM_VALIDATION; \ + template \ + void geam( \ + const array::ArrayContainer> &a, \ + Alpha alpha, \ + 
const array::Transpose>> &b, \ + Beta beta, \ + array::ArrayContainer> &c) { \ + GEAM_VALIDATION; \ \ - auto *__restrict dataA = a.storage().begin().get(); \ - auto *__restrict dataB = b.array().storage().begin().get(); \ - auto *__restrict dataC = c.storage().begin().get(); \ + auto *__restrict dataA = a.storage().begin().get(); \ + auto *__restrict dataB = b.array().storage().begin().get(); \ + auto *__restrict dataC = c.storage().begin().get(); \ \ - auto alphaTmp = static_cast(alpha); \ - auto betaTmp = static_cast(beta); \ + auto alphaTmp = static_cast(alpha); \ + auto betaTmp = static_cast(beta); \ \ - cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ - CUBLAS_OP_N, \ - CUBLAS_OP_T, \ - a.shape()[0], \ - a.shape()[1], \ - &alphaTmp, \ - dataA, \ - a.shape()[0], \ - &betaTmp, \ - dataB, \ - b.shape()[0], \ - dataC, \ - c.shape()[0])); \ - } \ + cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ + CUBLAS_OP_N, \ + CUBLAS_OP_T, \ + a.shape()[0], \ + a.shape()[1], \ + &alphaTmp, \ + dataA, \ + a.shape()[0], \ + &betaTmp, \ + dataB, \ + b.shape()[0], \ + dataC, \ + c.shape()[0])); \ + } \ \ - template \ - void geam( \ - const array::Transpose>> &a, \ - Alpha alpha, \ - const array::Transpose>> &b, \ - Beta beta, \ - array::ArrayContainer> &c) { \ - GEAM_VALIDATION; \ + template \ + void geam( \ + const array::Transpose>> &a, \ + Alpha alpha, \ + const array::Transpose>> &b, \ + Beta beta, \ + array::ArrayContainer> &c) { \ + GEAM_VALIDATION; \ \ - auto *__restrict dataA = a.array().storage().begin().get(); \ - auto *__restrict dataB = b.array().storage().begin().get(); \ - auto *__restrict dataC = c.storage().begin().get(); \ + auto *__restrict dataA = a.array().storage().begin().get(); \ + auto *__restrict dataB = b.array().storage().begin().get(); \ + auto *__restrict dataC = c.storage().begin().get(); \ \ - auto alphaTmp = static_cast(alpha); \ - auto betaTmp = static_cast(beta); \ + auto alphaTmp = static_cast(alpha); \ + auto betaTmp = 
static_cast(beta); \ \ - cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ - CUBLAS_OP_T, \ - CUBLAS_OP_T, \ - a.shape()[1], \ - a.shape()[0], \ - &alphaTmp, \ - dataA, \ - a.shape()[0], \ - &betaTmp, \ - dataB, \ - b.shape()[0], \ - dataC, \ - c.shape()[0])); \ - } - - LIBRAPID_CUDA_GEAM_IMPL(float, S) - LIBRAPID_CUDA_GEAM_IMPL(double, D) - LIBRAPID_CUDA_GEAM_IMPL(Complex, C) - LIBRAPID_CUDA_GEAM_IMPL(Complex, Z) + cublasSafeCall(cublas##PREFIX##geam(global::cublasHandle, \ + CUBLAS_OP_T, \ + CUBLAS_OP_T, \ + a.shape()[1], \ + a.shape()[0], \ + &alphaTmp, \ + dataA, \ + a.shape()[0], \ + &betaTmp, \ + dataB, \ + b.shape()[0], \ + dataC, \ + c.shape()[0])); \ + } + + LIBRAPID_CUDA_GEAM_IMPL(float, S) + LIBRAPID_CUDA_GEAM_IMPL(double, D) + LIBRAPID_CUDA_GEAM_IMPL(Complex, C) + LIBRAPID_CUDA_GEAM_IMPL(Complex, Z) #endif // LIBRAPID_HAS_CUDA - } // namespace linalg - - namespace typetraits { - template - struct HasCustomEval< - detail::Function, ScalarType1>, - detail::Function, ScalarType2>>> - : std::true_type {}; - }; // namespace typetraits - - namespace detail { - // aT * b + cT * d - template - LIBRAPID_ALWAYS_INLINE void assign( - array::ArrayContainer &destination, - const Function< - Descriptor1, detail::Plus, - Function, ScalarType1>, - Function, ScalarType2>> - &function) { - // Since GEAM only applies to matrices, we must check that we can actually use it given - // the input matrices. If we can't, we fall back to the default implementation. 
- - using Scalar = typename DestinationStorageType::Scalar; - - bool canUseGeam = true; - auto left = std::get<0>(function.args()); - auto leftMat = std::get<0>(left.args()); - auto leftScalar = std::get<1>(left.args()); - auto right = std::get<1>(function.args()); - auto rightMat = std::get<0>(right.args()); - auto rightScalar = std::get<1>(right.args()); - - if (leftMat.ndim() != 2 || rightMat.ndim() != 2 || destination.ndim() != 2) { - canUseGeam = false; - } - - if (leftMat.shape() != rightMat.shape() || leftMat.shape() != destination.shape()) { - canUseGeam = false; - } - - if (canUseGeam) { - linalg::geam(leftMat, - static_cast(leftScalar), - rightMat, - static_cast(rightScalar), - destination); - } else { - auto axes1 = leftMat.axes(); - auto alpha = leftMat.alpha() * static_cast(leftScalar); - auto axes2 = rightMat.axes(); - auto beta = rightMat.alpha() * static_cast(rightScalar); - destination = array::Transpose(leftMat.array(), axes1, alpha).eval() + - array::Transpose(rightMat.array(), axes2, beta).eval(); - } - } - - template - LIBRAPID_ALWAYS_INLINE void assignParallel( - array::ArrayContainer &destination, - const Function< - Descriptor1, detail::Plus, - Function, ScalarType1>, - Function, ScalarType2>> - &function) { - assign(destination, function); - } - } // namespace detail + } // namespace linalg + + namespace typetraits { + template + struct HasCustomEval< + detail::Function, ScalarType1>, + detail::Function, ScalarType2>>> + : std::true_type {}; + }; // namespace typetraits + + namespace detail { + // aT * b + cT * d + template + LIBRAPID_ALWAYS_INLINE void assign( + array::ArrayContainer &destination, + const Function< + Descriptor1, detail::Plus, + Function, ScalarType1>, + Function, ScalarType2>> + &function) { + // Since GEAM only applies to matrices, we must check that we can actually use it given + // the input matrices. If we can't, we fall back to the default implementation. 
+ + using Scalar = typename DestinationStorageType::Scalar; + + bool canUseGeam = true; + auto left = std::get<0>(function.args()); + auto leftMat = std::get<0>(left.args()); + auto leftScalar = std::get<1>(left.args()); + auto right = std::get<1>(function.args()); + auto rightMat = std::get<0>(right.args()); + auto rightScalar = std::get<1>(right.args()); + + if (leftMat.ndim() != 2 || rightMat.ndim() != 2 || destination.ndim() != 2) { + canUseGeam = false; + } + + if (leftMat.shape() != rightMat.shape() || leftMat.shape() != destination.shape()) { + canUseGeam = false; + } + + if (canUseGeam) { + linalg::geam(leftMat, + static_cast(leftScalar), + rightMat, + static_cast(rightScalar), + destination); + } else { + auto axes1 = leftMat.axes(); + auto alpha = leftMat.alpha() * static_cast(leftScalar); + auto axes2 = rightMat.axes(); + auto beta = rightMat.alpha() * static_cast(rightScalar); + destination = array::Transpose(leftMat.array(), axes1, alpha).eval() + + array::Transpose(rightMat.array(), axes2, beta).eval(); + } + } + + template + LIBRAPID_ALWAYS_INLINE void assignParallel( + array::ArrayContainer &destination, + const Function< + Descriptor1, detail::Plus, + Function, ScalarType1>, + Function, ScalarType2>> + &function) { + assign(destination, function); + } + } // namespace detail } // namespace librapid #endif // LIBRAPID_ARRAY_LINALG_LEVEL3_GEAM_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/linalg/level3/gemm.cl b/librapid/include/librapid/array/linalg/level3/gemm.cl index 62a7cc7a..de0e94fe 100644 --- a/librapid/include/librapid/array/linalg/level3/gemm.cl +++ b/librapid/include/librapid/array/linalg/level3/gemm.cl @@ -1,56 +1,54 @@ #define TS 32 // Tile size #define GEMM_IMPL(TYPE) \ - __kernel void gemm_##TYPE(const int transA, \ - const int transB, \ - const int32_t M, \ - const int32_t N, \ - const int32_t K, \ - const TYPE alpha, \ - __global TYPE *A, \ - const int32_t lda, \ - __global const TYPE *B, \ - const 
int32_t ldb, \ - const TYPE beta, \ - __global TYPE *C, \ - const int32_t ldc) { \ - const int32_t inx = get_global_id(0); \ - const int32_t iny = get_global_id(1); \ - const int32_t ibx = get_local_id(0); \ - const int32_t iby = get_local_id(1); \ - \ - __local TYPE Asub[TS][TS]; \ - __local TYPE Bsub[TS][TS]; \ - \ - TYPE acc = 0; \ - \ - const int32_t numTiles = K / TS + 1; \ - \ - for (int32_t t = 0; t < numTiles; t++) { \ - const int32_t tiledIndex = t * TS + ibx; \ - \ - Asub[iby][ibx] = (tiledIndex < K && iny < M) \ - ? (transA ? A[tiledIndex + lda * iny] : A[iny * lda + tiledIndex]) \ - : 0.0f; \ - Bsub[iby][ibx] = (tiledIndex < K && inx < N) \ - ? (transB ? B[tiledIndex + ldb * inx] : B[iny * ldb + tiledIndex]) \ - : 0.0f; \ - \ - barrier(CLK_LOCAL_MEM_FENCE); \ - \ - for (int32_t k = 0; k < TS; k++) { \ - if (t * TS + k < K) { \ - acc += Asub[iby][k] * Bsub[k][ibx]; \ - } \ - } \ - \ - barrier(CLK_LOCAL_MEM_FENCE); \ - } \ - \ - if (iny < M && inx < N) { \ - C[(iny * ldc) + inx] = alpha * acc + beta * C[(iny * ldc) + inx]; \ - } \ - } + __kernel void gemm_##TYPE(const int transA, \ + const int transB, \ + const int32_t M, \ + const int32_t N, \ + const int32_t K, \ + const TYPE alpha, \ + __global TYPE *A, \ + const int32_t lda, \ + __global const TYPE *B, \ + const int32_t ldb, \ + const TYPE beta, \ + __global TYPE *C, \ + const int32_t ldc) { \ + const int32_t inx = get_global_id(0); \ + const int32_t iny = get_global_id(1); \ + const int32_t ibx = get_local_id(0); \ + const int32_t iby = get_local_id(1); \ + \ + __local TYPE Asub[TS][TS]; \ + __local TYPE Bsub[TS][TS]; \ + \ + TYPE acc = 0; \ + \ + const int32_t numTiles = K / TS + 1; \ + \ + for (int32_t t = 0; t < numTiles; t++) { \ + const int32_t tiledIndex = t * TS + ibx; \ + \ + Asub[iby][ibx] = (tiledIndex < K && iny < M) \ + ? (transA ? A[tiledIndex + lda * iny] : A[iny * lda + tiledIndex]) \ + : 0.0f; \ + Bsub[iby][ibx] = (tiledIndex < K && inx < N) \ + ? (transB ? 
B[tiledIndex + ldb * inx] : B[iny * ldb + tiledIndex]) \ + : 0.0f; \ + \ + barrier(CLK_LOCAL_MEM_FENCE); \ + \ + for (int32_t k = 0; k < TS; k++) { \ + if (t * TS + k < K) { acc += Asub[iby][k] * Bsub[k][ibx]; } \ + } \ + \ + barrier(CLK_LOCAL_MEM_FENCE); \ + } \ + \ + if (iny < M && inx < N) { \ + C[(iny * ldc) + inx] = alpha * acc + beta * C[(iny * ldc) + inx]; \ + } \ + } GEMM_IMPL(int8_t) GEMM_IMPL(int16_t) diff --git a/librapid/include/librapid/array/linalg/level3/gemm.cu b/librapid/include/librapid/array/linalg/level3/gemm.cu index 8c312052..fafd2659 100644 --- a/librapid/include/librapid/array/linalg/level3/gemm.cu +++ b/librapid/include/librapid/array/linalg/level3/gemm.cu @@ -1,39 +1,39 @@ #define TS 32 // Tile size template + typename TypeC> __global__ void gemm(bool transA, bool transB, Int m, Int n, Int k, Alpha alpha, TypeA *a, Int lda, - TypeB *b, Int ldb, Beta beta, TypeC *c, Int ldc) { - const Int inx = blockIdx.x * blockDim.x + threadIdx.x; - const Int iny = blockIdx.y * blockDim.y + threadIdx.y; - const Int ibx = threadIdx.x; - const Int iby = threadIdx.y; + TypeB *b, Int ldb, Beta beta, TypeC *c, Int ldc) { + const Int inx = blockIdx.x * blockDim.x + threadIdx.x; + const Int iny = blockIdx.y * blockDim.y + threadIdx.y; + const Int ibx = threadIdx.x; + const Int iby = threadIdx.y; - __shared__ TypeA Asub[TS][TS]; - __shared__ TypeB Bsub[TS][TS]; + __shared__ TypeA Asub[TS][TS]; + __shared__ TypeB Bsub[TS][TS]; - TypeC acc = 0; + TypeC acc = 0; - const Int numTiles = (k + TS - 1) / TS; + const Int numTiles = (k + TS - 1) / TS; - for (Int t = 0; t < numTiles; t++) { - const Int tiledIndex = t * TS + ibx; + for (Int t = 0; t < numTiles; t++) { + const Int tiledIndex = t * TS + ibx; - Asub[iby][ibx] = (tiledIndex < k && iny < m) - ? (transA ? a[tiledIndex + lda * iny] : a[iny * lda + tiledIndex]) - : 0.0f; - Bsub[iby][ibx] = (tiledIndex < k && inx < n) - ? (transB ? 
b[tiledIndex + ldb * inx] : b[iny * ldb + tiledIndex]) - : 0.0f; + Asub[iby][ibx] = (tiledIndex < k && iny < m) + ? (transA ? a[tiledIndex + lda * iny] : a[iny * lda + tiledIndex]) + : 0.0f; + Bsub[iby][ibx] = (tiledIndex < k && inx < n) + ? (transB ? b[tiledIndex + ldb * inx] : b[iny * ldb + tiledIndex]) + : 0.0f; - __syncthreads(); + __syncthreads(); - for (Int j = 0; j < TS; j++) { - if (t * TS + j < k) { acc += Asub[iby][j] * Bsub[j][ibx]; } - } + for (Int j = 0; j < TS; j++) { + if (t * TS + j < k) { acc += Asub[iby][j] * Bsub[j][ibx]; } + } - __syncthreads(); - } + __syncthreads(); + } - if (iny < m && inx < n) { c[(iny * ldc) + inx] = alpha * acc + beta * c[(iny * ldc) + inx]; } + if (iny < m && inx < n) { c[(iny * ldc) + inx] = alpha * acc + beta * c[(iny * ldc) + inx]; } } diff --git a/librapid/include/librapid/array/linalg/level3/gemm.hpp b/librapid/include/librapid/array/linalg/level3/gemm.hpp index 784b2029..59ceb740 100644 --- a/librapid/include/librapid/array/linalg/level3/gemm.hpp +++ b/librapid/include/librapid/array/linalg/level3/gemm.hpp @@ -2,300 +2,300 @@ #define LIBRAPID_ARRAY_LINALG_LEVEL3_GEMM_HPP namespace librapid::linalg { - /// \brief General matrix-matrix multiplication - /// - /// Computes \f$ \mathbf{C} = \alpha \mathrm{OP}_A(\mathbf{A}) \mathrm{OP}_B(\mathbf{B}) + - /// \beta \mathbf{C} \f$ - /// for matrices \f$ \mathbf{A} \f$, \f$ \mathbf{B} \f$ and \f$ \mathbf{C} \f$. - /// \f$ \mathrm{OP}_A \f$ and \f$ \mathrm{OP}_B \f$ are - /// either the identity or the transpose operation. 
- /// \tparam Int Integer type for matrix dimensions - /// \tparam Alpha Type of \f$ \alpha \f$ - /// \tparam A Type of \f$ \mathbf{A} \f$ - /// \tparam B Type of \f$ \mathbf{B} \f$ - /// \tparam Beta Type of \f$ \beta \f$ - /// \tparam C Type of \f$ \mathbf{C} \f$ - /// \param transA Whether to transpose \f$ \mathbf{A} \f$ (determines \f$ \mathrm{OP}_A \f$) - /// \param transB Whether to transpose \f$ \mathbf{B} \f$ (determines \f$ \mathrm{OP}_B \f$) - /// \param m Rows of \f$ \mathbf{A} \f$ and \f$ \mathbf{C} \f$ - /// \param n Columns of \f$ \mathbf{B} \f$ and \f$ \mathbf{C} \f$ - /// \param k Columns of \f$ \mathbf{A} \f$ and rows of \f$ \mathbf{B} \f$ - /// \param alpha Scalar \f$ \alpha \f$ - /// \param a Pointer to \f$ \mathbf{A} \f$ - /// \param lda Leading dimension of \f$ \mathbf{A} \f$ - /// \param b Pointer to \f$ \mathbf{B} \f$ - /// \param ldb Leading dimension of \f$ \mathbf{B} \f$ - /// \param beta Scalar \f$ \beta \f$ - /// \param c Pointer to \f$ \mathbf{C} \f$ - /// \param ldc Leading dimension of \f$ \mathbf{C} \f$ - /// \param backend Backend to use for computation - template - void gemm(bool transA, bool transB, Int m, Int n, Int k, Alpha alpha, A *a, Int lda, B *b, - Int ldb, Beta beta, C *c, Int ldc, backend::CPU backend = backend::CPU()) { - cxxblas::gemm(cxxblas::StorageOrder::RowMajor, - (transA ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans), - (transB ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans), - m, - n, - k, - alpha, - a, - lda, - b, - ldb, - beta, - c, - ldc); - } + /// \brief General matrix-matrix multiplication + /// + /// Computes \f$ \mathbf{C} = \alpha \mathrm{OP}_A(\mathbf{A}) \mathrm{OP}_B(\mathbf{B}) + + /// \beta \mathbf{C} \f$ + /// for matrices \f$ \mathbf{A} \f$, \f$ \mathbf{B} \f$ and \f$ \mathbf{C} \f$. + /// \f$ \mathrm{OP}_A \f$ and \f$ \mathrm{OP}_B \f$ are + /// either the identity or the transpose operation. 
+ /// \tparam Int Integer type for matrix dimensions + /// \tparam Alpha Type of \f$ \alpha \f$ + /// \tparam A Type of \f$ \mathbf{A} \f$ + /// \tparam B Type of \f$ \mathbf{B} \f$ + /// \tparam Beta Type of \f$ \beta \f$ + /// \tparam C Type of \f$ \mathbf{C} \f$ + /// \param transA Whether to transpose \f$ \mathbf{A} \f$ (determines \f$ \mathrm{OP}_A \f$) + /// \param transB Whether to transpose \f$ \mathbf{B} \f$ (determines \f$ \mathrm{OP}_B \f$) + /// \param m Rows of \f$ \mathbf{A} \f$ and \f$ \mathbf{C} \f$ + /// \param n Columns of \f$ \mathbf{B} \f$ and \f$ \mathbf{C} \f$ + /// \param k Columns of \f$ \mathbf{A} \f$ and rows of \f$ \mathbf{B} \f$ + /// \param alpha Scalar \f$ \alpha \f$ + /// \param a Pointer to \f$ \mathbf{A} \f$ + /// \param lda Leading dimension of \f$ \mathbf{A} \f$ + /// \param b Pointer to \f$ \mathbf{B} \f$ + /// \param ldb Leading dimension of \f$ \mathbf{B} \f$ + /// \param beta Scalar \f$ \beta \f$ + /// \param c Pointer to \f$ \mathbf{C} \f$ + /// \param ldc Leading dimension of \f$ \mathbf{C} \f$ + /// \param backend Backend to use for computation + template + void gemm(bool transA, bool transB, Int m, Int n, Int k, Alpha alpha, A *a, Int lda, B *b, + Int ldb, Beta beta, C *c, Int ldc, backend::CPU backend = backend::CPU()) { + cxxblas::gemm(cxxblas::StorageOrder::RowMajor, + (transA ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans), + (transB ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans), + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc); + } #if defined(LIBRAPID_HAS_OPENCL) - template - void gemm(bool transA, bool transB, Int m, Int n, Int k, Alpha alpha, cl::Buffer a, Int lda, - cl::Buffer b, Int ldb, Beta beta, cl::Buffer c, Int ldc, backend::OpenCL) { - using GemmScalar = decltype(alpha * beta); - - if constexpr (typetraits::IsBlasType::value) { - auto status = clblast::Gemm( - clblast::Layout::kRowMajor, - (transA ? clblast::Transpose::kYes : clblast::Transpose::kNo), - (transB ? 
clblast::Transpose::kYes : clblast::Transpose::kNo), - m, - n, - k, - alpha, - a(), - 0, - lda, - b(), - 0, - ldb, - beta, - c(), - 0, - ldc, - &global::openCLQueue()); - - LIBRAPID_ASSERT(status == clblast::StatusCode::kSuccess, - "clblast::Gemm failed: {}", - opencl::getCLBlastErrorString(status)); - } else { - std::string kernelNameFull = - std::string("gemm_") + typetraits::TypeInfo::name; - cl::Kernel kernel(global::openCLProgram, kernelNameFull.c_str()); - kernel.setArg(0, (int)transA); - kernel.setArg(1, (int)transB); - kernel.setArg(2, (int32_t)m); - kernel.setArg(3, (int32_t)n); - kernel.setArg(4, (int32_t)k); - kernel.setArg(5, (GemmScalar)alpha); - kernel.setArg(6, a); - kernel.setArg(7, (int32_t)lda); - kernel.setArg(8, b); - kernel.setArg(9, (int32_t)ldb); - kernel.setArg(10, (GemmScalar)beta); - kernel.setArg(11, c); - kernel.setArg(12, (int32_t)ldc); - - size_t TS = 32; // Must be the same as in the kernel (line 1 of gemm.cu) - - cl::NDRange globalWorkSize = - cl::NDRange(((n - 1) / TS + 1) * TS, ((m - 1) / TS + 1) * TS); - cl::NDRange localWorkSize = cl::NDRange(TS, TS); - - auto status = global::openCLQueue.enqueueNDRangeKernel( - kernel, cl::NullRange, globalWorkSize, localWorkSize); - - LIBRAPID_ASSERT(status == CL_SUCCESS, - "cl::CommandQueue::enqueueNDRangeKernel GEMM call failed: {}", - opencl::getOpenCLErrorString(status)); - } - } + template + void gemm(bool transA, bool transB, Int m, Int n, Int k, Alpha alpha, cl::Buffer a, Int lda, + cl::Buffer b, Int ldb, Beta beta, cl::Buffer c, Int ldc, backend::OpenCL) { + using GemmScalar = decltype(alpha * beta); + + if constexpr (typetraits::IsBlasType::value) { + auto status = clblast::Gemm( + clblast::Layout::kRowMajor, + (transA ? clblast::Transpose::kYes : clblast::Transpose::kNo), + (transB ? 
clblast::Transpose::kYes : clblast::Transpose::kNo), + m, + n, + k, + alpha, + a(), + 0, + lda, + b(), + 0, + ldb, + beta, + c(), + 0, + ldc, + &global::openCLQueue()); + + LIBRAPID_ASSERT(status == clblast::StatusCode::kSuccess, + "clblast::Gemm failed: {}", + opencl::getCLBlastErrorString(status)); + } else { + std::string kernelNameFull = + std::string("gemm_") + typetraits::TypeInfo::name; + cl::Kernel kernel(global::openCLProgram, kernelNameFull.c_str()); + kernel.setArg(0, (int)transA); + kernel.setArg(1, (int)transB); + kernel.setArg(2, (int32_t)m); + kernel.setArg(3, (int32_t)n); + kernel.setArg(4, (int32_t)k); + kernel.setArg(5, (GemmScalar)alpha); + kernel.setArg(6, a); + kernel.setArg(7, (int32_t)lda); + kernel.setArg(8, b); + kernel.setArg(9, (int32_t)ldb); + kernel.setArg(10, (GemmScalar)beta); + kernel.setArg(11, c); + kernel.setArg(12, (int32_t)ldc); + + size_t TS = 32; // Must be the same as in the kernel (line 1 of gemm.cu) + + cl::NDRange globalWorkSize = + cl::NDRange(((n - 1) / TS + 1) * TS, ((m - 1) / TS + 1) * TS); + cl::NDRange localWorkSize = cl::NDRange(TS, TS); + + auto status = global::openCLQueue.enqueueNDRangeKernel( + kernel, cl::NullRange, globalWorkSize, localWorkSize); + + LIBRAPID_ASSERT(status == CL_SUCCESS, + "cl::CommandQueue::enqueueNDRangeKernel GEMM call failed: {}", + opencl::getOpenCLErrorString(status)); + } + } #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - struct CuBLASGemmComputeType { - cublasComputeType_t computeType; - cublasDataType_t scaleType; - }; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE CuBLASGemmComputeType - cublasGemmComputeType(cublasDataType_t a, cublasDataType_t b, cublasDataType_t c) { - // A simple lambda to select the correct compute type from the two options -# if defined(LIBRAPID_FAST_MATH) - constexpr auto selector = [](CuBLASGemmComputeType fast, CuBLASGemmComputeType) { - return fast; - }; -# else - constexpr auto selector = [](CuBLASGemmComputeType, CuBLASGemmComputeType 
precise) { - return precise; - }; -# endif - - LIBRAPID_ASSERT(a == b, "Types of A and B must be the same"); - LIBRAPID_ASSERT(a == c, "Output type must be the same as input types"); - - // If provided with different types, work off of the "minimum" type (i.e. the lowest - // precision) - switch (::librapid::min(a, b, c)) { - case CUDA_R_16F: - case CUDA_C_16F: // 16-bit -> 16-bit - return selector({CUBLAS_COMPUTE_16F, CUDA_R_16F}, - {CUBLAS_COMPUTE_16F_PEDANTIC, CUDA_R_16F}); - case CUDA_R_32F: - case CUDA_C_32F: // 32-bit -> [ fast: 16-bit, precise: 32-bit ] - return selector({CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F}, - {CUBLAS_COMPUTE_32F_PEDANTIC, CUDA_R_32F}); - case CUDA_R_64F: - case CUDA_C_64F: // 64-bit -> 64-bit - return selector({CUBLAS_COMPUTE_64F, CUDA_R_64F}, - {CUBLAS_COMPUTE_64F_PEDANTIC, CUDA_R_64F}); - case CUDA_R_32I: - case CUDA_C_32I: // 32-bit -> 32-bit - return selector({CUBLAS_COMPUTE_32I, CUDA_R_32I}, - {CUBLAS_COMPUTE_32I_PEDANTIC, CUDA_R_32I}); - default: { - LIBRAPID_ASSERT(false, "Invalid input types to CuBLAS gemm"); - return {CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F}; - } - } - } - - template - void gemm(bool transA, bool transB, Int m, Int n, Int k, Alpha alpha, A *a, Int lda, B *b, - Int ldb, Beta beta, C *c, Int ldc, backend::CUDA) { - if constexpr (typetraits::IsBlasType::value && typetraits::IsBlasType::value && - typetraits::IsBlasType::value) { - // Using the cuBLAS LT API - - cublasLtMatmulDesc_t operationDescriptor = nullptr; - cublasLtMatrixLayout_t descriptorA = nullptr, descriptorB = nullptr, - descriptorC = nullptr; - cublasLtMatmulPreference_t preference = NULL; - - // Configure the maximum number of algorithms to try - constexpr int maxHeuristicResults = 32; - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResults[maxHeuristicResults] = {}; - - cublasOperation_t cublasTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cublasTransB = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; - - // Get the CUDA types for the input and output matrices - cudaDataType_t cudaTypeA = typetraits::TypeInfo::CudaType; - cudaDataType_t cudaTypeB = typetraits::TypeInfo::CudaType; - cudaDataType_t cudaTypeC = typetraits::TypeInfo::CudaType; - - // Create operation descriptors - auto [computeType, scaleType] = cublasGemmComputeType(cudaTypeA, cudaTypeB, cudaTypeC); - cublasSafeCall(cublasLtMatmulDescCreate(&operationDescriptor, computeType, scaleType)); - cublasSafeCall(cublasLtMatmulDescSetAttribute(operationDescriptor, - CUBLASLT_MATMUL_DESC_TRANSA, - &cublasTransA, - sizeof(cublasTransA))); - cublasSafeCall(cublasLtMatmulDescSetAttribute(operationDescriptor, - CUBLASLT_MATMUL_DESC_TRANSB, - &cublasTransB, - sizeof(cublasTransB))); - - // Create matrix descriptors - cublasSafeCall(cublasLtMatrixLayoutCreate( - &descriptorA, cudaTypeA, !transA ? m : k, !transA ? k : m, lda)); - cublasSafeCall(cublasLtMatrixLayoutCreate( - &descriptorB, cudaTypeB, !transB ? k : n, !transB ? 
n : k, ldb)); - cublasSafeCall(cublasLtMatrixLayoutCreate(&descriptorC, cudaTypeC, m, n, ldc)); - - // Set layout attributes - const cublasLtOrder_t order = CUBLASLT_ORDER_ROW; - cublasSafeCall(cublasLtMatrixLayoutSetAttribute( - descriptorA, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); - cublasSafeCall(cublasLtMatrixLayoutSetAttribute( - descriptorB, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); - cublasSafeCall(cublasLtMatrixLayoutSetAttribute( - descriptorC, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); - - // Create preference handle - cublasSafeCall(cublasLtMatmulPreferenceCreate(&preference)); - cublasSafeCall( - cublasLtMatmulPreferenceSetAttribute(preference, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &global::cublasLtWorkspaceSize, - sizeof(global::cublasLtWorkspaceSize))); - - // Find the best algorithm to use for the given problem - cublasSafeCall(cublasLtMatmulAlgoGetHeuristic(global::cublasLtHandle, - operationDescriptor, - descriptorA, - descriptorB, - descriptorC, - descriptorC, - preference, - maxHeuristicResults, - &heuristicResults[0], - &returnedResults)); - - LIBRAPID_ASSERT(returnedResults != 0, "Invalid matrices for GEMM. No algorithm found."); - - // Execute the first valid algorithm - size_t i = 0; - for (; i < returnedResults; ++i) { - if (heuristicResults[i].state == CUBLAS_STATUS_SUCCESS) { - cublasSafeCall(cublasLtMatmul(global::cublasLtHandle, - operationDescriptor, - &alpha, - a, - descriptorA, - b, - descriptorB, - &beta, - c, - descriptorC, - c, - descriptorC, - &heuristicResults[i].algo, - global::cublasLtWorkspace, - global::cublasLtWorkspaceSize, - global::cudaStream)); - break; - } - } - - LIBRAPID_ASSERT(i != returnedResults, "Invalid matrices for GEMM. 
No algorithm found."); - - // Cleanup - cublasSafeCall(cublasLtMatmulPreferenceDestroy(preference)); - cublasSafeCall(cublasLtMatrixLayoutDestroy(descriptorA)); - cublasSafeCall(cublasLtMatrixLayoutDestroy(descriptorB)); - cublasSafeCall(cublasLtMatrixLayoutDestroy(descriptorC)); - cublasSafeCall(cublasLtMatmulDescDestroy(operationDescriptor)); - } else { - // If the provided types are not supported by cuBLAS, use the custom fallback kernel - - jitify::Program program = global::jitCache.program( - cuda::loadKernel( - fmt::format("{}/include/librapid/array/linalg/level3/gemm", LIBRAPID_SOURCE), - false), - {}, - {fmt::format("-I{}", CUDA_INCLUDE_DIRS)}); - - size_t TS = 32; - - dim3 threadsPerBlock(TS, TS); - dim3 numBlocks((n + TS - 1) / TS, (m + TS - 1) / TS); - - jitifyCall(program.kernel("gemm") - .instantiate(jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type(), - jitify::reflection::Type()) - .configure(numBlocks, threadsPerBlock, 0, global::cudaStream) - .launch(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)); - } - } + struct CuBLASGemmComputeType { + cublasComputeType_t computeType; + cublasDataType_t scaleType; + }; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE CuBLASGemmComputeType + cublasGemmComputeType(cublasDataType_t a, cublasDataType_t b, cublasDataType_t c) { + // A simple lambda to select the correct compute type from the two options +# if defined(LIBRAPID_FAST_MATH) + constexpr auto selector = [](CuBLASGemmComputeType fast, CuBLASGemmComputeType) { + return fast; + }; +# else + constexpr auto selector = [](CuBLASGemmComputeType, CuBLASGemmComputeType precise) { + return precise; + }; +# endif + + LIBRAPID_ASSERT(a == b, "Types of A and B must be the same"); + LIBRAPID_ASSERT(a == c, "Output type must be the same as input types"); + + // If provided with different types, work off of the "minimum" type (i.e. 
the lowest + // precision) + switch (::librapid::min(a, b, c)) { + case CUDA_R_16F: + case CUDA_C_16F: // 16-bit -> 16-bit + return selector({CUBLAS_COMPUTE_16F, CUDA_R_16F}, + {CUBLAS_COMPUTE_16F_PEDANTIC, CUDA_R_16F}); + case CUDA_R_32F: + case CUDA_C_32F: // 32-bit -> [ fast: 16-bit, precise: 32-bit ] + return selector({CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F}, + {CUBLAS_COMPUTE_32F_PEDANTIC, CUDA_R_32F}); + case CUDA_R_64F: + case CUDA_C_64F: // 64-bit -> 64-bit + return selector({CUBLAS_COMPUTE_64F, CUDA_R_64F}, + {CUBLAS_COMPUTE_64F_PEDANTIC, CUDA_R_64F}); + case CUDA_R_32I: + case CUDA_C_32I: // 32-bit -> 32-bit + return selector({CUBLAS_COMPUTE_32I, CUDA_R_32I}, + {CUBLAS_COMPUTE_32I_PEDANTIC, CUDA_R_32I}); + default: { + LIBRAPID_ASSERT(false, "Invalid input types to CuBLAS gemm"); + return {CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F}; + } + } + } + + template + void gemm(bool transA, bool transB, Int m, Int n, Int k, Alpha alpha, A *a, Int lda, B *b, + Int ldb, Beta beta, C *c, Int ldc, backend::CUDA) { + if constexpr (typetraits::IsBlasType::value && typetraits::IsBlasType::value && + typetraits::IsBlasType::value) { + // Using the cuBLAS LT API + + cublasLtMatmulDesc_t operationDescriptor = nullptr; + cublasLtMatrixLayout_t descriptorA = nullptr, descriptorB = nullptr, + descriptorC = nullptr; + cublasLtMatmulPreference_t preference = NULL; + + // Configure the maximum number of algorithms to try + constexpr int maxHeuristicResults = 32; + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResults[maxHeuristicResults] = {}; + + cublasOperation_t cublasTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cublasTransB = transB ? 
CUBLAS_OP_T : CUBLAS_OP_N; + + // Get the CUDA types for the input and output matrices + cudaDataType_t cudaTypeA = typetraits::TypeInfo::CudaType; + cudaDataType_t cudaTypeB = typetraits::TypeInfo::CudaType; + cudaDataType_t cudaTypeC = typetraits::TypeInfo::CudaType; + + // Create operation descriptors + auto [computeType, scaleType] = cublasGemmComputeType(cudaTypeA, cudaTypeB, cudaTypeC); + cublasSafeCall(cublasLtMatmulDescCreate(&operationDescriptor, computeType, scaleType)); + cublasSafeCall(cublasLtMatmulDescSetAttribute(operationDescriptor, + CUBLASLT_MATMUL_DESC_TRANSA, + &cublasTransA, + sizeof(cublasTransA))); + cublasSafeCall(cublasLtMatmulDescSetAttribute(operationDescriptor, + CUBLASLT_MATMUL_DESC_TRANSB, + &cublasTransB, + sizeof(cublasTransB))); + + // Create matrix descriptors + cublasSafeCall(cublasLtMatrixLayoutCreate( + &descriptorA, cudaTypeA, !transA ? m : k, !transA ? k : m, lda)); + cublasSafeCall(cublasLtMatrixLayoutCreate( + &descriptorB, cudaTypeB, !transB ? k : n, !transB ? 
n : k, ldb)); + cublasSafeCall(cublasLtMatrixLayoutCreate(&descriptorC, cudaTypeC, m, n, ldc)); + + // Set layout attributes + const cublasLtOrder_t order = CUBLASLT_ORDER_ROW; + cublasSafeCall(cublasLtMatrixLayoutSetAttribute( + descriptorA, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); + cublasSafeCall(cublasLtMatrixLayoutSetAttribute( + descriptorB, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); + cublasSafeCall(cublasLtMatrixLayoutSetAttribute( + descriptorC, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(order))); + + // Create preference handle + cublasSafeCall(cublasLtMatmulPreferenceCreate(&preference)); + cublasSafeCall( + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &global::cublasLtWorkspaceSize, + sizeof(global::cublasLtWorkspaceSize))); + + // Find the best algorithm to use for the given problem + cublasSafeCall(cublasLtMatmulAlgoGetHeuristic(global::cublasLtHandle, + operationDescriptor, + descriptorA, + descriptorB, + descriptorC, + descriptorC, + preference, + maxHeuristicResults, + &heuristicResults[0], + &returnedResults)); + + LIBRAPID_ASSERT(returnedResults != 0, "Invalid matrices for GEMM. No algorithm found."); + + // Execute the first valid algorithm + size_t i = 0; + for (; i < returnedResults; ++i) { + if (heuristicResults[i].state == CUBLAS_STATUS_SUCCESS) { + cublasSafeCall(cublasLtMatmul(global::cublasLtHandle, + operationDescriptor, + &alpha, + a, + descriptorA, + b, + descriptorB, + &beta, + c, + descriptorC, + c, + descriptorC, + &heuristicResults[i].algo, + global::cublasLtWorkspace, + global::cublasLtWorkspaceSize, + global::cudaStream)); + break; + } + } + + LIBRAPID_ASSERT(i != returnedResults, "Invalid matrices for GEMM. 
No algorithm found."); + + // Cleanup + cublasSafeCall(cublasLtMatmulPreferenceDestroy(preference)); + cublasSafeCall(cublasLtMatrixLayoutDestroy(descriptorA)); + cublasSafeCall(cublasLtMatrixLayoutDestroy(descriptorB)); + cublasSafeCall(cublasLtMatrixLayoutDestroy(descriptorC)); + cublasSafeCall(cublasLtMatmulDescDestroy(operationDescriptor)); + } else { + // If the provided types are not supported by cuBLAS, use the custom fallback kernel + + jitify::Program program = global::jitCache.program( + cuda::loadKernel( + fmt::format("{}/include/librapid/array/linalg/level3/gemm", LIBRAPID_SOURCE), + false), + {}, + {fmt::format("-I{}", CUDA_INCLUDE_DIRS)}); + + size_t TS = 32; + + dim3 threadsPerBlock(TS, TS); + dim3 numBlocks((n + TS - 1) / TS, (m + TS - 1) / TS); + + jitifyCall(program.kernel("gemm") + .instantiate(jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type(), + jitify::reflection::Type()) + .configure(numBlocks, threadsPerBlock, 0, global::cudaStream) + .launch(transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)); + } + } #endif // LIBRAPID_HAS_CUDA } // namespace librapid::linalg diff --git a/librapid/include/librapid/array/linalg/linalg.hpp b/librapid/include/librapid/array/linalg/linalg.hpp index 52aa7f1b..b57f2fd6 100644 --- a/librapid/include/librapid/array/linalg/linalg.hpp +++ b/librapid/include/librapid/array/linalg/linalg.hpp @@ -2,24 +2,24 @@ #define LIBRAPID_ARRAY_LINALG namespace librapid::typetraits { - template - struct IsBlasType : std::false_type {}; + template + struct IsBlasType : std::false_type {}; - template<> - struct IsBlasType : std::true_type {}; + template<> + struct IsBlasType : std::true_type {}; - template<> - struct IsBlasType : std::true_type {}; + template<> + struct IsBlasType : std::true_type {}; - template<> - struct IsBlasType : std::true_type {}; + template<> + struct IsBlasType : std::true_type {}; - template<> - 
struct IsBlasType> : std::true_type {}; + template<> + struct IsBlasType> : std::true_type {}; - template<> - struct IsBlasType> : std::true_type {}; -} + template<> + struct IsBlasType> : std::true_type {}; +} // namespace librapid::typetraits #include "transpose.hpp" diff --git a/librapid/include/librapid/array/linalg/transpose.hpp b/librapid/include/librapid/array/linalg/transpose.hpp index c1bccfd8..4d2dd7c4 100644 --- a/librapid/include/librapid/array/linalg/transpose.hpp +++ b/librapid/include/librapid/array/linalg/transpose.hpp @@ -2,720 +2,716 @@ #define LIBRAPID_ARRAY_TRANSPOSE_HPP namespace librapid { - namespace typetraits { - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction; - using Scalar = typename TypeInfo>::Scalar; - using Backend = typename TypeInfo>::Backend; - static constexpr bool allowVectorisation = false; - }; - - LIBRAPID_DEFINE_AS_TYPE(typename T, array::Transpose); - } // namespace typetraits - - namespace kernels { + namespace typetraits { + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction; + using Scalar = typename TypeInfo>::Scalar; + using Backend = typename TypeInfo>::Backend; + static constexpr bool allowVectorisation = false; + }; + + LIBRAPID_DEFINE_AS_TYPE(typename T, array::Transpose); + } // namespace typetraits + + namespace kernels { #if defined(LIBRAPID_NATIVE_ARCH) -# if !defined(LIBRAPID_APPLE) && LIBRAPID_ARCH >= AVX2 -# define LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE 4 -# define LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE 8 - - template - LIBRAPID_ALWAYS_INLINE void transposeFloatKernel(float *__restrict out, - float *__restrict in, Alpha alpha, - int64_t cols) { - __m256 r0, r1, r2, r3, r4, r5, r6, r7; - __m256 t0, t1, t2, t3, t4, t5, t6, t7; - -# define LOAD256_IMPL(LEFT_, RIGHT_) \ - _mm256_insertf128_ps( \ - _mm256_castps128_ps256(_mm_loadu_ps(&(LEFT_))), _mm_loadu_ps(&(RIGHT_)), 1) - - r0 = 
LOAD256_IMPL(in[0 * cols + 0], in[4 * cols + 0]); - r1 = LOAD256_IMPL(in[1 * cols + 0], in[5 * cols + 0]); - r2 = LOAD256_IMPL(in[2 * cols + 0], in[6 * cols + 0]); - r3 = LOAD256_IMPL(in[3 * cols + 0], in[7 * cols + 0]); - r4 = LOAD256_IMPL(in[0 * cols + 4], in[4 * cols + 4]); - r5 = LOAD256_IMPL(in[1 * cols + 4], in[5 * cols + 4]); - r6 = LOAD256_IMPL(in[2 * cols + 4], in[6 * cols + 4]); - r7 = LOAD256_IMPL(in[3 * cols + 4], in[7 * cols + 4]); - -# undef LOAD256_IMPL - - t0 = _mm256_unpacklo_ps(r0, r1); - t1 = _mm256_unpackhi_ps(r0, r1); - t2 = _mm256_unpacklo_ps(r2, r3); - t3 = _mm256_unpackhi_ps(r2, r3); - t4 = _mm256_unpacklo_ps(r4, r5); - t5 = _mm256_unpackhi_ps(r4, r5); - t6 = _mm256_unpacklo_ps(r6, r7); - t7 = _mm256_unpackhi_ps(r6, r7); - - __m256 v; - - v = _mm256_shuffle_ps(t0, t2, 0x4E); - r0 = _mm256_blend_ps(t0, v, 0xCC); - r1 = _mm256_blend_ps(t2, v, 0x33); - - v = _mm256_shuffle_ps(t1, t3, 0x4E); - r2 = _mm256_blend_ps(t1, v, 0xCC); - r3 = _mm256_blend_ps(t3, v, 0x33); - - v = _mm256_shuffle_ps(t4, t6, 0x4E); - r4 = _mm256_blend_ps(t4, v, 0xCC); - r5 = _mm256_blend_ps(t6, v, 0x33); - - v = _mm256_shuffle_ps(t5, t7, 0x4E); - r6 = _mm256_blend_ps(t5, v, 0xCC); - r7 = _mm256_blend_ps(t7, v, 0x33); - - __m256 alphaVec = _mm256_set1_ps(alpha); - - _mm256_store_ps(&out[0 * cols], _mm256_mul_ps(r0, alphaVec)); - _mm256_store_ps(&out[1 * cols], _mm256_mul_ps(r1, alphaVec)); - _mm256_store_ps(&out[2 * cols], _mm256_mul_ps(r2, alphaVec)); - _mm256_store_ps(&out[3 * cols], _mm256_mul_ps(r3, alphaVec)); - _mm256_store_ps(&out[4 * cols], _mm256_mul_ps(r4, alphaVec)); - _mm256_store_ps(&out[5 * cols], _mm256_mul_ps(r5, alphaVec)); - _mm256_store_ps(&out[6 * cols], _mm256_mul_ps(r6, alphaVec)); - _mm256_store_ps(&out[7 * cols], _mm256_mul_ps(r7, alphaVec)); - } - - template - LIBRAPID_ALWAYS_INLINE void transposeDoubleKernel(double *__restrict out, - double *__restrict in, Alpha alpha, - int64_t cols) { - __m256d r0, r1, r2, r3; - __m256d t0, t1, t2, t3; - -# 
define LOAD256_IMPL(LEFT_, RIGHT_) \ - _mm256_insertf128_pd( \ - _mm256_castpd128_pd256(_mm_loadu_pd(&(LEFT_))), _mm_loadu_pd(&(RIGHT_)), 1) - - r0 = LOAD256_IMPL(in[0 * cols + 0], in[2 * cols + 0]); - r1 = LOAD256_IMPL(in[1 * cols + 0], in[3 * cols + 0]); - r2 = LOAD256_IMPL(in[0 * cols + 2], in[2 * cols + 2]); - r3 = LOAD256_IMPL(in[1 * cols + 2], in[3 * cols + 2]); - -# undef LOAD256_IMPL - - t0 = _mm256_unpacklo_pd(r0, r1); - t1 = _mm256_unpackhi_pd(r0, r1); - t2 = _mm256_unpacklo_pd(r2, r3); - t3 = _mm256_unpackhi_pd(r2, r3); - - __m256d v; - - v = _mm256_shuffle_pd(t0, t2, 0x0); - r0 = _mm256_blend_pd(t0, v, 0xC); - r1 = _mm256_blend_pd(t2, v, 0x3); - - v = _mm256_shuffle_pd(t1, t3, 0x0); - r2 = _mm256_blend_pd(t1, v, 0xC); - r3 = _mm256_blend_pd(t3, v, 0x3); - - __m256d alphaVec = _mm256_set1_pd(alpha); - - _mm256_store_pd(&out[0 * cols], _mm256_mul_pd(r0, alphaVec)); - _mm256_store_pd(&out[1 * cols], _mm256_mul_pd(r1, alphaVec)); - _mm256_store_pd(&out[2 * cols], _mm256_mul_pd(r2, alphaVec)); - _mm256_store_pd(&out[3 * cols], _mm256_mul_pd(r3, alphaVec)); - } -# elif !defined(LIBRAPID_APPLE) && LIBRAPID_ARCH >= SSE2 - -# define LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE 2 -# define LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE 4 - - template - LIBRAPID_ALWAYS_INLINE void transposeFloatKernel(float *__restrict out, - float *__restrict in, Alpha alpha, - int64_t cols) { - __m128 tmp3, tmp2, tmp1, tmp0; - - tmp0 = _mm_shuffle_ps(_mm_load_ps(in + 0 * cols), _mm_load_ps(in + 1 * cols), 0x44); - tmp2 = _mm_shuffle_ps(_mm_load_ps(in + 0 * cols), _mm_load_ps(in + 1 * cols), 0xEE); - tmp1 = _mm_shuffle_ps(_mm_load_ps(in + 2 * cols), _mm_load_ps(in + 3 * cols), 0x44); - tmp3 = _mm_shuffle_ps(_mm_load_ps(in + 2 * cols), _mm_load_ps(in + 3 * cols), 0xEE); - - __m128 alphaVec = _mm_set1_ps(alpha); - - _mm_store_ps(out + 0 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0x88), alphaVec)); - _mm_store_ps(out + 1 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0xDD), alphaVec)); - 
_mm_store_ps(out + 2 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0x88), alphaVec)); - _mm_store_ps(out + 3 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0xDD), alphaVec)); - } - - template - LIBRAPID_ALWAYS_INLINE void transposeDoubleKernel(double *__restrict out, - double *__restrict in, Alpha alpha, - int64_t cols) { - __m128d tmp0, tmp1; - - // Load the values from input matrix - tmp0 = _mm_load_pd(in + 0 * cols); - tmp1 = _mm_load_pd(in + 1 * cols); - - // Transpose the 2x2 matrix - __m128d tmp0Unpck = _mm_unpacklo_pd(tmp0, tmp1); - __m128d tmp1Unpck = _mm_unpackhi_pd(tmp0, tmp1); - - // Store the transposed values in the output matrix - __m128d alphaVec = _mm_set1_pd(alpha); - _mm_store_pd(out + 0 * cols, _mm_mul_pd(tmp0Unpck, alphaVec)); - _mm_store_pd(out + 1 * cols, _mm_mul_pd(tmp1Unpck, alphaVec)); - } - -# endif // LIBRAPID_MSVC -#endif // LIBRAPID_NATIVE_ARCH - } // namespace kernels - - namespace detail { - namespace cpu { - template - LIBRAPID_ALWAYS_INLINE void - transposeImpl(Scalar *__restrict out, const Scalar *__restrict in, int64_t rows, - int64_t cols, Alpha alpha, int64_t blockSize) { +# if !defined(LIBRAPID_APPLE) && LIBRAPID_ARCH >= AVX2 +# define LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE 4 +# define LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE 8 + + template + LIBRAPID_ALWAYS_INLINE void transposeFloatKernel(float *__restrict out, + float *__restrict in, Alpha alpha, + int64_t cols) { + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + +# define LOAD256_IMPL(LEFT_, RIGHT_) \ + _mm256_insertf128_ps( \ + _mm256_castps128_ps256(_mm_loadu_ps(&(LEFT_))), _mm_loadu_ps(&(RIGHT_)), 1) + + r0 = LOAD256_IMPL(in[0 * cols + 0], in[4 * cols + 0]); + r1 = LOAD256_IMPL(in[1 * cols + 0], in[5 * cols + 0]); + r2 = LOAD256_IMPL(in[2 * cols + 0], in[6 * cols + 0]); + r3 = LOAD256_IMPL(in[3 * cols + 0], in[7 * cols + 0]); + r4 = LOAD256_IMPL(in[0 * cols + 4], in[4 * cols + 4]); + r5 = LOAD256_IMPL(in[1 * cols + 4], in[5 * cols + 4]); + r6 = 
LOAD256_IMPL(in[2 * cols + 4], in[6 * cols + 4]); + r7 = LOAD256_IMPL(in[3 * cols + 4], in[7 * cols + 4]); + +# undef LOAD256_IMPL + + t0 = _mm256_unpacklo_ps(r0, r1); + t1 = _mm256_unpackhi_ps(r0, r1); + t2 = _mm256_unpacklo_ps(r2, r3); + t3 = _mm256_unpackhi_ps(r2, r3); + t4 = _mm256_unpacklo_ps(r4, r5); + t5 = _mm256_unpackhi_ps(r4, r5); + t6 = _mm256_unpacklo_ps(r6, r7); + t7 = _mm256_unpackhi_ps(r6, r7); + + __m256 v; + + v = _mm256_shuffle_ps(t0, t2, 0x4E); + r0 = _mm256_blend_ps(t0, v, 0xCC); + r1 = _mm256_blend_ps(t2, v, 0x33); + + v = _mm256_shuffle_ps(t1, t3, 0x4E); + r2 = _mm256_blend_ps(t1, v, 0xCC); + r3 = _mm256_blend_ps(t3, v, 0x33); + + v = _mm256_shuffle_ps(t4, t6, 0x4E); + r4 = _mm256_blend_ps(t4, v, 0xCC); + r5 = _mm256_blend_ps(t6, v, 0x33); + + v = _mm256_shuffle_ps(t5, t7, 0x4E); + r6 = _mm256_blend_ps(t5, v, 0xCC); + r7 = _mm256_blend_ps(t7, v, 0x33); + + __m256 alphaVec = _mm256_set1_ps(alpha); + + _mm256_store_ps(&out[0 * cols], _mm256_mul_ps(r0, alphaVec)); + _mm256_store_ps(&out[1 * cols], _mm256_mul_ps(r1, alphaVec)); + _mm256_store_ps(&out[2 * cols], _mm256_mul_ps(r2, alphaVec)); + _mm256_store_ps(&out[3 * cols], _mm256_mul_ps(r3, alphaVec)); + _mm256_store_ps(&out[4 * cols], _mm256_mul_ps(r4, alphaVec)); + _mm256_store_ps(&out[5 * cols], _mm256_mul_ps(r5, alphaVec)); + _mm256_store_ps(&out[6 * cols], _mm256_mul_ps(r6, alphaVec)); + _mm256_store_ps(&out[7 * cols], _mm256_mul_ps(r7, alphaVec)); + } + + template + LIBRAPID_ALWAYS_INLINE void transposeDoubleKernel(double *__restrict out, + double *__restrict in, Alpha alpha, + int64_t cols) { + __m256d r0, r1, r2, r3; + __m256d t0, t1, t2, t3; + +# define LOAD256_IMPL(LEFT_, RIGHT_) \ + _mm256_insertf128_pd( \ + _mm256_castpd128_pd256(_mm_loadu_pd(&(LEFT_))), _mm_loadu_pd(&(RIGHT_)), 1) + + r0 = LOAD256_IMPL(in[0 * cols + 0], in[2 * cols + 0]); + r1 = LOAD256_IMPL(in[1 * cols + 0], in[3 * cols + 0]); + r2 = LOAD256_IMPL(in[0 * cols + 2], in[2 * cols + 2]); + r3 = LOAD256_IMPL(in[1 * cols + 
2], in[3 * cols + 2]); + +# undef LOAD256_IMPL + + t0 = _mm256_unpacklo_pd(r0, r1); + t1 = _mm256_unpackhi_pd(r0, r1); + t2 = _mm256_unpacklo_pd(r2, r3); + t3 = _mm256_unpackhi_pd(r2, r3); + + __m256d v; + + v = _mm256_shuffle_pd(t0, t2, 0x0); + r0 = _mm256_blend_pd(t0, v, 0xC); + r1 = _mm256_blend_pd(t2, v, 0x3); + + v = _mm256_shuffle_pd(t1, t3, 0x0); + r2 = _mm256_blend_pd(t1, v, 0xC); + r3 = _mm256_blend_pd(t3, v, 0x3); + + __m256d alphaVec = _mm256_set1_pd(alpha); + + _mm256_store_pd(&out[0 * cols], _mm256_mul_pd(r0, alphaVec)); + _mm256_store_pd(&out[1 * cols], _mm256_mul_pd(r1, alphaVec)); + _mm256_store_pd(&out[2 * cols], _mm256_mul_pd(r2, alphaVec)); + _mm256_store_pd(&out[3 * cols], _mm256_mul_pd(r3, alphaVec)); + } +# elif !defined(LIBRAPID_APPLE) && LIBRAPID_ARCH >= SSE2 + +# define LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE 2 +# define LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE 4 + + template + LIBRAPID_ALWAYS_INLINE void transposeFloatKernel(float *__restrict out, + float *__restrict in, Alpha alpha, + int64_t cols) { + __m128 tmp3, tmp2, tmp1, tmp0; + + tmp0 = _mm_shuffle_ps(_mm_load_ps(in + 0 * cols), _mm_load_ps(in + 1 * cols), 0x44); + tmp2 = _mm_shuffle_ps(_mm_load_ps(in + 0 * cols), _mm_load_ps(in + 1 * cols), 0xEE); + tmp1 = _mm_shuffle_ps(_mm_load_ps(in + 2 * cols), _mm_load_ps(in + 3 * cols), 0x44); + tmp3 = _mm_shuffle_ps(_mm_load_ps(in + 2 * cols), _mm_load_ps(in + 3 * cols), 0xEE); + + __m128 alphaVec = _mm_set1_ps(alpha); + + _mm_store_ps(out + 0 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0x88), alphaVec)); + _mm_store_ps(out + 1 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0xDD), alphaVec)); + _mm_store_ps(out + 2 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0x88), alphaVec)); + _mm_store_ps(out + 3 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0xDD), alphaVec)); + } + + template + LIBRAPID_ALWAYS_INLINE void transposeDoubleKernel(double *__restrict out, + double *__restrict in, Alpha alpha, + int64_t cols) { + __m128d tmp0, tmp1; + + // Load the 
values from input matrix + tmp0 = _mm_load_pd(in + 0 * cols); + tmp1 = _mm_load_pd(in + 1 * cols); + + // Transpose the 2x2 matrix + __m128d tmp0Unpck = _mm_unpacklo_pd(tmp0, tmp1); + __m128d tmp1Unpck = _mm_unpackhi_pd(tmp0, tmp1); + + // Store the transposed values in the output matrix + __m128d alphaVec = _mm_set1_pd(alpha); + _mm_store_pd(out + 0 * cols, _mm_mul_pd(tmp0Unpck, alphaVec)); + _mm_store_pd(out + 1 * cols, _mm_mul_pd(tmp1Unpck, alphaVec)); + } + +# endif // LIBRAPID_MSVC +#endif // LIBRAPID_NATIVE_ARCH + } // namespace kernels + + namespace detail { + namespace cpu { + template + LIBRAPID_ALWAYS_INLINE void + transposeImpl(Scalar *__restrict out, const Scalar *__restrict in, int64_t rows, + int64_t cols, Alpha alpha, int64_t blockSize) { #if !defined(LIBRAPID_OPTIMISE_SMALL_ARRAYS) - if (rows * cols > global::multithreadThreshold) { -# pragma omp parallel for shared(rows, cols, blockSize, in, out, alpha) default(none) \ - num_threads((int)global::numThreads) - for (int64_t i = 0; i < rows; i += blockSize) { - for (int64_t j = 0; j < cols; j += blockSize) { - for (int64_t row = i; row < i + blockSize && row < rows; ++row) { - for (int64_t col = j; col < j + blockSize && col < cols; ++col) { - out[col * rows + row] = in[row * cols + col] * alpha; - } - } - } - } - } else + if (rows * cols > global::multithreadThreshold) { +# pragma omp parallel for shared(rows, cols, blockSize, in, out, alpha) default(none) \ + num_threads((int)global::numThreads) + for (int64_t i = 0; i < rows; i += blockSize) { + for (int64_t j = 0; j < cols; j += blockSize) { + for (int64_t row = i; row < i + blockSize && row < rows; ++row) { + for (int64_t col = j; col < j + blockSize && col < cols; ++col) { + out[col * rows + row] = in[row * cols + col] * alpha; + } + } + } + } + } else #endif // LIBRAPID_OPTIMISE_SMALL_ARRAYS - { - for (int64_t i = 0; i < rows; i += blockSize) { - for (int64_t j = 0; j < cols; j += blockSize) { - for (int64_t row = i; row < i + blockSize && row 
< rows; ++row) { - for (int64_t col = j; col < j + blockSize && col < cols; ++col) { - out[col * rows + row] = in[row * cols + col] * alpha; - } - } - } - } - } - } + { + for (int64_t i = 0; i < rows; i += blockSize) { + for (int64_t j = 0; j < cols; j += blockSize) { + for (int64_t row = i; row < i + blockSize && row < rows; ++row) { + for (int64_t col = j; col < j + blockSize && col < cols; ++col) { + out[col * rows + row] = in[row * cols + col] * alpha; + } + } + } + } + } + } #if LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE > 0 - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(float *__restrict out, float *__restrict in, - int64_t rows, int64_t cols, Alpha alpha, - int64_t) { - constexpr int64_t blockSize = LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE; - -# if !defined(LIBRAPID_OPTIMISE_SMALL_ARRAYS) - if (rows * cols > global::multithreadThreshold) { -# pragma omp parallel for shared(rows, cols, in, out, alpha) default(none) \ - num_threads((int)global::numThreads) - for (int64_t i = 0; i < rows; i += blockSize) { - for (int64_t j = 0; j < cols; j += blockSize) { - if (i + blockSize <= rows && j + blockSize <= cols) { - kernels::transposeFloatKernel( - &out[j * rows + i], &in[i * cols + j], alpha, rows); - } else { - for (int64_t row = i; row < i + blockSize && row < rows; ++row) { - for (int64_t col = j; col < j + blockSize && col < cols; - ++col) { - out[col * rows + row] = in[row * cols + col]; - } - } - } - } - } - } else -# endif - { - for (int64_t i = 0; i < rows; i += blockSize) { - for (int64_t j = 0; j < cols; j += blockSize) { - if (i + blockSize <= rows && j + blockSize <= cols) { - kernels::transposeFloatKernel( - &out[j * rows + i], &in[i * cols + j], alpha, rows); - } else { - for (int64_t row = i; row < i + blockSize && row < rows; ++row) { - for (int64_t col = j; col < j + blockSize && col < cols; - ++col) { - out[col * rows + row] = in[row * cols + col]; - } - } - } - } - } - } - } + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(float *__restrict out, 
float *__restrict in, + int64_t rows, int64_t cols, Alpha alpha, + int64_t) { + constexpr int64_t blockSize = LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE; + +# if !defined(LIBRAPID_OPTIMISE_SMALL_ARRAYS) + if (rows * cols > global::multithreadThreshold) { +# pragma omp parallel for shared(rows, cols, in, out, alpha) default(none) \ + num_threads((int)global::numThreads) + for (int64_t i = 0; i < rows; i += blockSize) { + for (int64_t j = 0; j < cols; j += blockSize) { + if (i + blockSize <= rows && j + blockSize <= cols) { + kernels::transposeFloatKernel( + &out[j * rows + i], &in[i * cols + j], alpha, rows); + } else { + for (int64_t row = i; row < i + blockSize && row < rows; ++row) { + for (int64_t col = j; col < j + blockSize && col < cols; + ++col) { + out[col * rows + row] = in[row * cols + col]; + } + } + } + } + } + } else +# endif + { + for (int64_t i = 0; i < rows; i += blockSize) { + for (int64_t j = 0; j < cols; j += blockSize) { + if (i + blockSize <= rows && j + blockSize <= cols) { + kernels::transposeFloatKernel( + &out[j * rows + i], &in[i * cols + j], alpha, rows); + } else { + for (int64_t row = i; row < i + blockSize && row < rows; ++row) { + for (int64_t col = j; col < j + blockSize && col < cols; + ++col) { + out[col * rows + row] = in[row * cols + col]; + } + } + } + } + } + } + } #endif // LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE > 0 #if LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE > 0 - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(double *__restrict out, double *__restrict in, - int64_t rows, int64_t cols, Alpha alpha, - int64_t) { - constexpr int64_t blockSize = LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE; - -# if !defined(LIBRAPID_OPTIMISE_SMALL_ARRAYS) - if (rows * cols > global::multithreadThreshold) { -# pragma omp parallel for shared(rows, cols, in, out, alpha) default(none) \ - num_threads((int)global::numThreads) - for (int64_t i = 0; i < rows; i += blockSize) { - for (int64_t j = 0; j < cols; j += blockSize) { - if (i + blockSize <= rows && j + blockSize <= 
cols) { - kernels::transposeDoubleKernel( - &out[j * rows + i], &in[i * cols + j], alpha, rows); - } else { - for (int64_t row = i; row < i + blockSize && row < rows; ++row) { - for (int64_t col = j; col < j + blockSize && col < cols; - ++col) { - out[col * rows + row] = in[row * cols + col] * alpha; - } - } - } - } - } - } else -# endif // LIBRAPID_OPTIMISE_SMALL_ARRAYS - { - for (int64_t i = 0; i < rows; i += blockSize) { - for (int64_t j = 0; j < cols; j += blockSize) { - if (i + blockSize <= rows && j + blockSize <= cols) { - kernels::transposeDoubleKernel( - &out[j * rows + i], &in[i * cols + j], alpha, rows); - } else { - for (int64_t row = i; row < i + blockSize && row < rows; ++row) { - for (int64_t col = j; col < j + blockSize && col < cols; - ++col) { - out[col * rows + row] = in[row * cols + col] * alpha; - } - } - } - } - } - } - } -#endif // LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE > 0 - } // namespace cpu + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(double *__restrict out, double *__restrict in, + int64_t rows, int64_t cols, Alpha alpha, + int64_t) { + constexpr int64_t blockSize = LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE; + +# if !defined(LIBRAPID_OPTIMISE_SMALL_ARRAYS) + if (rows * cols > global::multithreadThreshold) { +# pragma omp parallel for shared(rows, cols, in, out, alpha) default(none) \ + num_threads((int)global::numThreads) + for (int64_t i = 0; i < rows; i += blockSize) { + for (int64_t j = 0; j < cols; j += blockSize) { + if (i + blockSize <= rows && j + blockSize <= cols) { + kernels::transposeDoubleKernel( + &out[j * rows + i], &in[i * cols + j], alpha, rows); + } else { + for (int64_t row = i; row < i + blockSize && row < rows; ++row) { + for (int64_t col = j; col < j + blockSize && col < cols; + ++col) { + out[col * rows + row] = in[row * cols + col] * alpha; + } + } + } + } + } + } else +# endif // LIBRAPID_OPTIMISE_SMALL_ARRAYS + { + for (int64_t i = 0; i < rows; i += blockSize) { + for (int64_t j = 0; j < cols; j += blockSize) { + 
if (i + blockSize <= rows && j + blockSize <= cols) { + kernels::transposeDoubleKernel( + &out[j * rows + i], &in[i * cols + j], alpha, rows); + } else { + for (int64_t row = i; row < i + blockSize && row < rows; ++row) { + for (int64_t col = j; col < j + blockSize && col < cols; + ++col) { + out[col * rows + row] = in[row * cols + col] * alpha; + } + } + } + } + } + } + } +#endif // LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE > 0 + } // namespace cpu #if defined(LIBRAPID_HAS_OPENCL) - namespace opencl { - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(cl::Buffer &out, const cl::Buffer &in, - int64_t rows, int64_t cols, Alpha alpha, - int64_t) { - std::string kernelName = - fmt::format("transpose_{}", typetraits::TypeInfo::name); - cl::Kernel kernel(global::openCLProgram, kernelName.c_str()); - kernel.setArg(0, out); - kernel.setArg(1, in); - kernel.setArg(2, int(rows)); - kernel.setArg(3, int(cols)); - kernel.setArg(4, Scalar(alpha)); - int TILE_DIM = 16; - cl::NDRange global((cols + TILE_DIM - 1) / TILE_DIM * TILE_DIM, - (rows + TILE_DIM - 1) / TILE_DIM * TILE_DIM); - cl::NDRange local(TILE_DIM, TILE_DIM); - auto ret = - global::openCLQueue.enqueueNDRangeKernel(kernel, cl::NullRange, global, local); - LIBRAPID_ASSERT(ret == CL_SUCCESS, "OpenCL kernel failed"); - } - } // namespace opencl + namespace opencl { + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(cl::Buffer &out, const cl::Buffer &in, + int64_t rows, int64_t cols, Alpha alpha, + int64_t) { + std::string kernelName = + fmt::format("transpose_{}", typetraits::TypeInfo::name); + cl::Kernel kernel(global::openCLProgram, kernelName.c_str()); + kernel.setArg(0, out); + kernel.setArg(1, in); + kernel.setArg(2, int(rows)); + kernel.setArg(3, int(cols)); + kernel.setArg(4, Scalar(alpha)); + int TILE_DIM = 16; + cl::NDRange global((cols + TILE_DIM - 1) / TILE_DIM * TILE_DIM, + (rows + TILE_DIM - 1) / TILE_DIM * TILE_DIM); + cl::NDRange local(TILE_DIM, TILE_DIM); + auto ret = + 
global::openCLQueue.enqueueNDRangeKernel(kernel, cl::NullRange, global, local); + LIBRAPID_ASSERT(ret == CL_SUCCESS, "OpenCL kernel failed"); + } + } // namespace opencl #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - namespace cuda { - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(Scalar *__restrict out, Scalar *__restrict in, - int64_t rows, int64_t cols, Alpha alpha, - int64_t blockSize) { - LIBRAPID_NOT_IMPLEMENTED - } - - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(float *__restrict out, float *__restrict in, - int64_t rows, int64_t cols, Alpha alpha, - int64_t) { - float zero = 0.0f; - cublasSafeCall(cublasSgeam(global::cublasHandle, - CUBLAS_OP_T, - CUBLAS_OP_N, - (int)rows, - (int)cols, - &alpha, - in, - (int)cols, - &zero, - in, - (int)cols, - out, - rows)); - } - - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(double *__restrict out, double *__restrict in, - int64_t rows, int64_t cols, Alpha alpha, - int64_t) { - double zero = 0.0; - cublasSafeCall(cublasDgeam(global::cublasHandle, - CUBLAS_OP_T, - CUBLAS_OP_N, - rows, - cols, - &alpha, - in, - cols, - &zero, - in, - cols, - out, - rows)); - } - - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(Complex *__restrict out, - Complex *__restrict in, int64_t rows, - int64_t cols, Complex alpha, int64_t) { - cuComplex alphaCu {alpha.real(), alpha.imag()}; - cuComplex zero {0.0f, 0.0f}; - cublasSafeCall(cublasCgeam(global::cublasHandle, - CUBLAS_OP_T, - CUBLAS_OP_N, - rows, - cols, - &alphaCu, - reinterpret_cast(in), - cols, - &zero, - reinterpret_cast(in), - cols, - reinterpret_cast(out), - rows)); - } - - template - LIBRAPID_ALWAYS_INLINE void transposeImpl(Complex *__restrict out, - Complex *__restrict in, int64_t rows, - int64_t cols, Complex alpha, int64_t) { - cuDoubleComplex alphaCu {alpha.real(), alpha.imag()}; - cuDoubleComplex zero {0.0, 0.0}; - cublasSafeCall(cublasZgeam(global::cublasHandle, - CUBLAS_OP_T, - CUBLAS_OP_N, - rows, - cols, - &alphaCu, - 
reinterpret_cast(in), - cols, - &zero, - reinterpret_cast(in), - cols, - reinterpret_cast(out), - rows)); - } - } // namespace cuda -#endif // LIBRAPID_HAS_CUDA - } // namespace detail - - namespace array { - template - class Transpose { - public: - using ArrayType = T; - using BaseType = typename std::decay_t; - using Scalar = typename typetraits::TypeInfo::Scalar; - using Reference = BaseType &; - using ConstReference = const BaseType &; - using ShapeType = typename BaseType::ShapeType; - using Backend = typename typetraits::TypeInfo::Backend; - - static constexpr bool allowVectorisation = - typetraits::TypeInfo::allowVectorisation; - static constexpr bool isArray = typetraits::IsArrayContainer::value; - static constexpr bool isHost = std::is_same_v; - static constexpr bool isOpenCL = std::is_same_v; - static constexpr bool isCUDA = std::is_same_v; - - /// Default constructor should never be used - Transpose() = delete; - - /// Create a Transpose object from an array/operation - /// \param array The array to copy - /// \param axes The transposition axes - Transpose(const T &array, const ShapeType &axes, Scalar alpha = Scalar(1.0)); - - /// Copy a Transpose object - Transpose(const Transpose &other) = default; - - /// Move constructor - Transpose(Transpose &&other) noexcept = default; - - /// Assign another Transpose object to this one - /// \param other The Transpose to assign - /// \return *this; - Transpose &operator=(const Transpose &other) = default; - - /// Access sub-array of this Transpose object - /// \param index Array index - /// \return ArrayView - ArrayView operator[](int64_t index) const; - - /// Get the shape of this Transpose object - /// \return Shape - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const; - - /// Return the number of dimensions of the Transpose object - /// \return Number of dimensions - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const; - - /// Access a scalar at a given index in the object. 
The index will be converted into - /// a multi-dimensional index using the shape of the object, and counts in row-major - /// order - /// \param index Index of the scalar - /// \return Scalar type at the given index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalar(int64_t index) const; - - /// \brief Return the axes of the Transpose object - /// \return `ShapeType` containing the axes - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const ShapeType &axes() const; - - /// \brief Return the alpha value of the Transpose object - /// \return Alpha scaling factor - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const Scalar &alpha() const; - - /// \brief Return the untransposed array object - /// \return Array object - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const ArrayType &array() const; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ArrayType &array(); - - template - LIBRAPID_ALWAYS_INLINE void applyTo(ArrayRef &out) const; - - /// Evaluate the Transpose object and return the result. Depending on your use case, - /// calling this function mid-expression might result in better performance, but you - /// should always test the available options before making a decision. 
- /// \return Evaluated expression - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; - - /// Return a string representation of the Transpose object, formatting each scalar with - /// the given format string - /// \param format Format string - /// \return Stringified object - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; - - private: - ArrayType m_array; - ShapeType m_inputShape; - ShapeType m_outputShape; - ShapeType m_axes; - Scalar m_alpha; - }; - - template - Transpose::Transpose(const T &array, const ShapeType &axes, Scalar alpha) : - m_array(array), m_inputShape(array.shape()), m_axes(axes), m_alpha(alpha) { - LIBRAPID_ASSERT(m_inputShape.ndim() == m_axes.ndim(), - "Shape and axes must have the same number of dimensions"); - - m_outputShape = m_inputShape; - for (size_t i = 0; i < m_inputShape.ndim(); i++) { - m_outputShape[i] = m_inputShape[m_axes[i]]; - } - } - - template - auto Transpose::shape() const -> ShapeType { - return m_outputShape; - } - - template - auto Transpose::ndim() const -> int64_t { - return m_outputShape.ndim(); - } - - template - auto Transpose::axes() const -> const ShapeType & { - return m_axes; - } - - template - auto Transpose::alpha() const -> const Scalar & { - return m_alpha; - } - - template - auto Transpose::array() const -> const ArrayType & { - return m_array; - } - - template - auto Transpose::array() -> ArrayType & { - return m_array; - } - - template - template - void Transpose::applyTo(ArrayRef &out) const { - bool inplace = ((void *)&out) == ((void *)&m_array); - LIBRAPID_ASSERT(!inplace, "Cannot transpose inplace"); - LIBRAPID_ASSERT(out.shape() == m_outputShape, "Transpose assignment shape mismatch"); - - if constexpr (isArray) { - if constexpr (isHost) { - auto *__restrict outPtr = out.storage().begin(); - auto *__restrict inPtr = m_array.storage().begin(); - int64_t blockSize = global::cacheLineSize / sizeof(Scalar); - - if (m_inputShape.ndim() == 2) { - 
detail::cpu::transposeImpl( - outPtr, inPtr, m_inputShape[0], m_inputShape[1], m_alpha, blockSize); - - } else { - LIBRAPID_NOT_IMPLEMENTED - } - } + namespace cuda { + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(Scalar *__restrict out, Scalar *__restrict in, + int64_t rows, int64_t cols, Alpha alpha, + int64_t blockSize) { + LIBRAPID_NOT_IMPLEMENTED + } + + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(float *__restrict out, float *__restrict in, + int64_t rows, int64_t cols, Alpha alpha, + int64_t) { + float zero = 0.0f; + cublasSafeCall(cublasSgeam(global::cublasHandle, + CUBLAS_OP_T, + CUBLAS_OP_N, + (int)rows, + (int)cols, + &alpha, + in, + (int)cols, + &zero, + in, + (int)cols, + out, + rows)); + } + + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(double *__restrict out, double *__restrict in, + int64_t rows, int64_t cols, Alpha alpha, + int64_t) { + double zero = 0.0; + cublasSafeCall(cublasDgeam(global::cublasHandle, + CUBLAS_OP_T, + CUBLAS_OP_N, + rows, + cols, + &alpha, + in, + cols, + &zero, + in, + cols, + out, + rows)); + } + + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(Complex *__restrict out, + Complex *__restrict in, int64_t rows, + int64_t cols, Complex alpha, int64_t) { + cuComplex alphaCu {alpha.real(), alpha.imag()}; + cuComplex zero {0.0f, 0.0f}; + cublasSafeCall(cublasCgeam(global::cublasHandle, + CUBLAS_OP_T, + CUBLAS_OP_N, + rows, + cols, + &alphaCu, + reinterpret_cast(in), + cols, + &zero, + reinterpret_cast(in), + cols, + reinterpret_cast(out), + rows)); + } + + template + LIBRAPID_ALWAYS_INLINE void transposeImpl(Complex *__restrict out, + Complex *__restrict in, int64_t rows, + int64_t cols, Complex alpha, int64_t) { + cuDoubleComplex alphaCu {alpha.real(), alpha.imag()}; + cuDoubleComplex zero {0.0, 0.0}; + cublasSafeCall(cublasZgeam(global::cublasHandle, + CUBLAS_OP_T, + CUBLAS_OP_N, + rows, + cols, + &alphaCu, + reinterpret_cast(in), + cols, + &zero, + reinterpret_cast(in), + cols, + 
reinterpret_cast(out), + rows)); + } + } // namespace cuda +#endif // LIBRAPID_HAS_CUDA + } // namespace detail + + namespace array { + template + class Transpose { + public: + using ArrayType = T; + using BaseType = typename std::decay_t; + using Scalar = typename typetraits::TypeInfo::Scalar; + using Reference = BaseType &; + using ConstReference = const BaseType &; + using ShapeType = typename BaseType::ShapeType; + using Backend = typename typetraits::TypeInfo::Backend; + + static constexpr bool allowVectorisation = + typetraits::TypeInfo::allowVectorisation; + static constexpr bool isArray = typetraits::IsArrayContainer::value; + static constexpr bool isHost = std::is_same_v; + static constexpr bool isOpenCL = std::is_same_v; + static constexpr bool isCUDA = std::is_same_v; + + /// Default constructor should never be used + Transpose() = delete; + + /// Create a Transpose object from an array/operation + /// \param array The array to copy + /// \param axes The transposition axes + Transpose(const T &array, const ShapeType &axes, Scalar alpha = Scalar(1.0)); + + /// Copy a Transpose object + Transpose(const Transpose &other) = default; + + /// Move constructor + Transpose(Transpose &&other) noexcept = default; + + /// Assign another Transpose object to this one + /// \param other The Transpose to assign + /// \return *this; + Transpose &operator=(const Transpose &other) = default; + + /// Access sub-array of this Transpose object + /// \param index Array index + /// \return ArrayView + ArrayView operator[](int64_t index) const; + + /// Get the shape of this Transpose object + /// \return Shape + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const; + + /// Return the number of dimensions of the Transpose object + /// \return Number of dimensions + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const; + + /// Access a scalar at a given index in the object. 
The index will be converted into + /// a multi-dimensional index using the shape of the object, and counts in row-major + /// order + /// \param index Index of the scalar + /// \return Scalar type at the given index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto scalar(int64_t index) const; + + /// \brief Return the axes of the Transpose object + /// \return `ShapeType` containing the axes + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const ShapeType &axes() const; + + /// \brief Return the alpha value of the Transpose object + /// \return Alpha scaling factor + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const Scalar &alpha() const; + + /// \brief Return the untransposed array object + /// \return Array object + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const ArrayType &array() const; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ArrayType &array(); + + template + LIBRAPID_ALWAYS_INLINE void applyTo(ArrayRef &out) const; + + /// Evaluate the Transpose object and return the result. Depending on your use case, + /// calling this function mid-expression might result in better performance, but you + /// should always test the available options before making a decision. 
+ /// \return Evaluated expression + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const; + + /// Return a string representation of the Transpose object, formatting each scalar with + /// the given format string + /// \param format Format string + /// \return Stringified object + LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + + private: + ArrayType m_array; + ShapeType m_inputShape; + ShapeType m_outputShape; + ShapeType m_axes; + Scalar m_alpha; + }; + + template + Transpose::Transpose(const T &array, const ShapeType &axes, Scalar alpha) : + m_array(array), m_inputShape(array.shape()), m_axes(axes), m_alpha(alpha) { + LIBRAPID_ASSERT(m_inputShape.ndim() == m_axes.ndim(), + "Shape and axes must have the same number of dimensions"); + + m_outputShape = m_inputShape; + for (size_t i = 0; i < m_inputShape.ndim(); i++) { + m_outputShape[i] = m_inputShape[m_axes[i]]; + } + } + + template + auto Transpose::shape() const -> ShapeType { + return m_outputShape; + } + + template + auto Transpose::ndim() const -> int64_t { + return m_outputShape.ndim(); + } + + template + auto Transpose::axes() const -> const ShapeType & { + return m_axes; + } + + template + auto Transpose::alpha() const -> const Scalar & { + return m_alpha; + } + + template + auto Transpose::array() const -> const ArrayType & { + return m_array; + } + + template + auto Transpose::array() -> ArrayType & { + return m_array; + } + + template + template + void Transpose::applyTo(ArrayRef &out) const { + bool inplace = ((void *)&out) == ((void *)&m_array); + LIBRAPID_ASSERT(!inplace, "Cannot transpose inplace"); + LIBRAPID_ASSERT(out.shape() == m_outputShape, "Transpose assignment shape mismatch"); + + if constexpr (isArray) { + if constexpr (isHost) { + auto *__restrict outPtr = out.storage().begin(); + auto *__restrict inPtr = m_array.storage().begin(); + int64_t blockSize = global::cacheLineSize / sizeof(Scalar); + + if (m_inputShape.ndim() == 2) { + 
detail::cpu::transposeImpl( + outPtr, inPtr, m_inputShape[0], m_inputShape[1], m_alpha, blockSize); + + } else { + LIBRAPID_NOT_IMPLEMENTED + } + } #if defined(LIBRAPID_HAS_OPENCL) - else if constexpr (isOpenCL) { - cl::Buffer &outBuffer = out.storage().data(); - const cl::Buffer &inBuffer = m_array.storage().data(); - - if (m_inputShape.ndim() == 2) { - detail::opencl::transposeImpl( - outBuffer, inBuffer, m_inputShape[0], m_inputShape[1], m_alpha, 0); - } else { - LIBRAPID_NOT_IMPLEMENTED - } - } + else if constexpr (isOpenCL) { + cl::Buffer &outBuffer = out.storage().data(); + const cl::Buffer &inBuffer = m_array.storage().data(); + + if (m_inputShape.ndim() == 2) { + detail::opencl::transposeImpl( + outBuffer, inBuffer, m_inputShape[0], m_inputShape[1], m_alpha, 0); + } else { + LIBRAPID_NOT_IMPLEMENTED + } + } #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - else { - if (m_inputShape.ndim() == 2) { - int64_t blockSize = global::cacheLineSize / sizeof(Scalar); - auto *__restrict outPtr = out.storage().begin().get(); - auto *__restrict inPtr = m_array.storage().begin().get(); - detail::cuda::transposeImpl( - outPtr, inPtr, m_inputShape[0], m_inputShape[1], m_alpha, blockSize); - } else { - LIBRAPID_NOT_IMPLEMENTED - } - } + else { + if (m_inputShape.ndim() == 2) { + int64_t blockSize = global::cacheLineSize / sizeof(Scalar); + auto *__restrict outPtr = out.storage().begin().get(); + auto *__restrict inPtr = m_array.storage().begin().get(); + detail::cuda::transposeImpl( + outPtr, inPtr, m_inputShape[0], m_inputShape[1], m_alpha, blockSize); + } else { + LIBRAPID_NOT_IMPLEMENTED + } + } #endif // LIBRAPID_HAS_CUDA - } else { - LIBRAPID_NOT_IMPLEMENTED - } - } - - template - auto Transpose::eval() const { - using NonConstArrayType = std::remove_const_t; - NonConstArrayType res(m_outputShape); - applyTo(res); - return res; - } - - template - std::string Transpose::str(const std::string &format) const { - return eval().str(format); - } - }; // 
namespace array - - template, - typename std::enable_if_t::type == - detail::LibRapidType::ArrayContainer, int> = 0> - auto transpose(T &&array, const ShapeType &axes = ShapeType()) { - // If axes is empty, transpose the array in reverse order - ShapeType newAxes = axes; - if (axes.ndim() == 0) { - newAxes = ShapeType::zeros(array.ndim()); - for (size_t i = 0; i < array.ndim(); i++) { - newAxes[i] = array.ndim() - i - 1; - } - } - - return array::Transpose(array, newAxes); - } - - template, - typename std::enable_if_t::type != - detail::LibRapidType::ArrayContainer, int> = 0> - auto transpose(const T &function, const ShapeType &axes = ShapeType()) { - // If axes is empty, transpose the array in reverse order - auto array = function.eval(); - ShapeType newAxes = axes; - if (axes.ndim() == 0) { - newAxes = ShapeType::zeros(array.ndim()); - for (size_t i = 0; i < array.ndim(); i++) { - newAxes[i] = array.ndim() - i - 1; - } - } - - return array::Transpose(array, newAxes); - } - - namespace typetraits { - template - struct HasCustomEval, ScalarType>> - : std::true_type {}; - - template - struct HasCustomEval>> : std::true_type {}; - }; // namespace typetraits - - namespace detail { - // If assigning an operation of the form aT * b, where a is a matrix and b is a scalar, - // we can replace this with Transpose(a, b) to get better performance - - // aT * b - template - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer &destination, - const Function, - ScalarType> &function) { - auto axes = std::get<0>(function.args()).axes(); - auto alpha = std::get<0>(function.args()).alpha(); - destination = array::Transpose( - std::get<0>(function.args()).array(), axes, alpha * std::get<1>(function.args())); - } - - template - LIBRAPID_ALWAYS_INLINE void - assignParallel(array::ArrayContainer &destination, - const Function, - ScalarType> &function) { - // The assign function runs in parallel if possible by default, so just call that - assign(destination, function); - } - - // 
a * bT - template - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer &destination, - const Function> &function) { - auto axes = std::get<1>(function.args()).axes(); - auto alpha = std::get<1>(function.args()).alpha(); - destination = array::Transpose( - std::get<1>(function.args()).array(), axes, alpha * std::get<0>(function.args())); - } - - template - LIBRAPID_ALWAYS_INLINE void - assignParallel(array::ArrayContainer &destination, - const Function> &function) { - assign(destination, function); - } - } // namespace detail + } else { + LIBRAPID_NOT_IMPLEMENTED + } + } + + template + auto Transpose::eval() const { + using NonConstArrayType = std::remove_const_t; + NonConstArrayType res(m_outputShape); + applyTo(res); + return res; + } + + template + std::string Transpose::str(const std::string &format) const { + return eval().str(format); + } + }; // namespace array + + template, + typename std::enable_if_t< + typetraits::TypeInfo::type == detail::LibRapidType::ArrayContainer, int> = 0> + auto transpose(T &&array, const ShapeType &axes = ShapeType()) { + // If axes is empty, transpose the array in reverse order + ShapeType newAxes = axes; + if (axes.ndim() == 0) { + newAxes = ShapeType::zeros(array.ndim()); + for (size_t i = 0; i < array.ndim(); i++) { newAxes[i] = array.ndim() - i - 1; } + } + + return array::Transpose(array, newAxes); + } + + template, + typename std::enable_if_t< + typetraits::TypeInfo::type != detail::LibRapidType::ArrayContainer, int> = 0> + auto transpose(const T &function, const ShapeType &axes = ShapeType()) { + // If axes is empty, transpose the array in reverse order + auto array = function.eval(); + ShapeType newAxes = axes; + if (axes.ndim() == 0) { + newAxes = ShapeType::zeros(array.ndim()); + for (size_t i = 0; i < array.ndim(); i++) { newAxes[i] = array.ndim() - i - 1; } + } + + return array::Transpose(array, newAxes); + } + + namespace typetraits { + template + struct HasCustomEval, ScalarType>> + : std::true_type {}; + + 
template + struct HasCustomEval>> : std::true_type {}; + }; // namespace typetraits + + namespace detail { + // If assigning an operation of the form aT * b, where a is a matrix and b is a scalar, + // we can replace this with Transpose(a, b) to get better performance + + // aT * b + template + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer &destination, + const Function, + ScalarType> &function) { + auto axes = std::get<0>(function.args()).axes(); + auto alpha = std::get<0>(function.args()).alpha(); + destination = array::Transpose( + std::get<0>(function.args()).array(), axes, alpha * std::get<1>(function.args())); + } + + template + LIBRAPID_ALWAYS_INLINE void + assignParallel(array::ArrayContainer &destination, + const Function, + ScalarType> &function) { + // The assign function runs in parallel if possible by default, so just call that + assign(destination, function); + } + + // a * bT + template + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer &destination, + const Function> &function) { + auto axes = std::get<1>(function.args()).axes(); + auto alpha = std::get<1>(function.args()).alpha(); + destination = array::Transpose( + std::get<1>(function.args()).array(), axes, alpha * std::get<0>(function.args())); + } + + template + LIBRAPID_ALWAYS_INLINE void + assignParallel(array::ArrayContainer &destination, + const Function> &function) { + assign(destination, function); + } + } // namespace detail } // namespace librapid // Support FMT printing diff --git a/librapid/include/librapid/array/operations.hpp b/librapid/include/librapid/array/operations.hpp index 690dab6d..ca0d3512 100644 --- a/librapid/include/librapid/array/operations.hpp +++ b/librapid/include/librapid/array/operations.hpp @@ -2,1082 +2,1082 @@ #define LIBRAPID_ARRAY_OPERATIONS_HPP #define LIBRAPID_BINARY_FUNCTOR(NAME_, OP_) \ - struct NAME_ { \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &lhs, \ - const V &rhs) const { \ - return lhs 
OP_ rhs; \ - } \ + struct NAME_ { \ + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &lhs, \ + const V &rhs) const { \ + return lhs OP_ rhs; \ + } \ \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto packet(const Packet &lhs, \ - const Packet &rhs) const { \ - return lhs OP_ rhs; \ - } \ - } + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto packet(const Packet &lhs, \ + const Packet &rhs) const { \ + return lhs OP_ rhs; \ + } \ + } #define LIBRAPID_BINARY_COMPARISON_FUNCTOR(NAME_, OP_) \ - struct NAME_ { \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &lhs, \ - const V &rhs) const { \ - return (typename std::common_type_t)(lhs OP_ rhs); \ - } \ + struct NAME_ { \ + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &lhs, \ + const V &rhs) const { \ + return (typename std::common_type_t)(lhs OP_ rhs); \ + } \ \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto packet(const Packet &lhs, \ - const Packet &rhs) const { \ - return Packet(lhs OP_ rhs); \ - } \ - } + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto packet(const Packet &lhs, \ + const Packet &rhs) const { \ + return Packet(lhs OP_ rhs); \ + } \ + } #define LIBRAPID_UNARY_KERNEL_GETTER \ - template \ - static constexpr const char *getKernelName(std::tuple args) { \ - static_assert(sizeof...(Args) == 1, "Invalid number of arguments for unary operation"); \ - return kernelName; \ - } + template \ + static constexpr const char *getKernelName(std::tuple args) { \ + static_assert(sizeof...(Args) == 1, "Invalid number of arguments for unary operation"); \ + return kernelName; \ + } #define LIBRAPID_BINARY_KERNEL_GETTER \ - template \ - static constexpr const char *getKernelNameImpl(std::tuple args) { \ - if constexpr (TypeInfo>::type != detail::LibRapidType::Scalar && \ - TypeInfo>::type != detail::LibRapidType::Scalar) { \ - return kernelName; \ - } else if constexpr 
(TypeInfo>::type == detail::LibRapidType::Scalar) { \ - return kernelNameScalarLhs; \ - } else if constexpr (TypeInfo>::type == detail::LibRapidType::Scalar) { \ - return kernelNameScalarRhs; \ - } else { \ - return kernelName; \ - } \ - } \ + template \ + static constexpr const char *getKernelNameImpl(std::tuple args) { \ + if constexpr (TypeInfo>::type != detail::LibRapidType::Scalar && \ + TypeInfo>::type != detail::LibRapidType::Scalar) { \ + return kernelName; \ + } else if constexpr (TypeInfo>::type == detail::LibRapidType::Scalar) { \ + return kernelNameScalarLhs; \ + } else if constexpr (TypeInfo>::type == detail::LibRapidType::Scalar) { \ + return kernelNameScalarRhs; \ + } else { \ + return kernelName; \ + } \ + } \ \ - template \ - static constexpr const char *getKernelName(std::tuple args) { \ - static_assert(sizeof...(Args) == 2, "Invalid number of arguments for binary operation"); \ - return getKernelNameImpl(args); \ - } + template \ + static constexpr const char *getKernelName(std::tuple args) { \ + static_assert(sizeof...(Args) == 2, "Invalid number of arguments for binary operation"); \ + return getKernelNameImpl(args); \ + } #define LIBRAPID_UNARY_SHAPE_EXTRACTOR \ - template \ - LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShape( \ - const std::tuple &args) { \ - static_assert(sizeof...(Args) == 1, "Invalid number of arguments for unary operation"); \ - return std::get<0>(args).shape(); \ - } + template \ + LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShape( \ + const std::tuple &args) { \ + static_assert(sizeof...(Args) == 1, "Invalid number of arguments for unary operation"); \ + return std::get<0>(args).shape(); \ + } #define LIBRAPID_BINARY_SHAPE_EXTRACTOR \ - template \ - LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShapeImpl( \ - const std::tuple &tup) { \ - if constexpr (TypeInfo>::type != detail::LibRapidType::Scalar && \ - TypeInfo>::type != detail::LibRapidType::Scalar) { \ - 
LIBRAPID_ASSERT(std::get<0>(tup).shape() == std::get<1>(tup).shape(), \ - "Shapes must match for binary operations"); \ - return std::get<0>(tup).shape(); \ - } else if constexpr (TypeInfo>::type == \ - detail::LibRapidType::Scalar) { \ - return std::get<1>(tup).shape(); \ - } else { \ - return std::get<0>(tup).shape(); \ - } \ - } \ + template \ + LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShapeImpl( \ + const std::tuple &tup) { \ + if constexpr (TypeInfo>::type != detail::LibRapidType::Scalar && \ + TypeInfo>::type != detail::LibRapidType::Scalar) { \ + LIBRAPID_ASSERT(std::get<0>(tup).shape() == std::get<1>(tup).shape(), \ + "Shapes must match for binary operations"); \ + return std::get<0>(tup).shape(); \ + } else if constexpr (TypeInfo>::type == \ + detail::LibRapidType::Scalar) { \ + return std::get<1>(tup).shape(); \ + } else { \ + return std::get<0>(tup).shape(); \ + } \ + } \ \ - template \ - LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShape( \ - const std::tuple &args) { \ - static_assert(sizeof...(Args) == 2, "Invalid number of arguments for binary operation"); \ - return getShapeImpl(args); \ - } + template \ + LIBRAPID_NODISCARD static LIBRAPID_ALWAYS_INLINE auto getShape( \ + const std::tuple &args) { \ + static_assert(sizeof...(Args) == 2, "Invalid number of arguments for binary operation"); \ + return getShapeImpl(args); \ + } #define LIBRAPID_UNARY_FUNCTOR(NAME, OP) \ - struct NAME { \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &arg) const { \ - return (T)(OP(arg)); \ - } \ + struct NAME { \ + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &arg) const { \ + return (T)(OP(arg)); \ + } \ \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto packet(const Packet &arg) const { \ - return OP(arg); \ - } \ - }; + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto packet(const Packet &arg) const { \ + return OP(arg); \ + } \ + }; namespace 
librapid { - namespace detail { - /// Construct a new function object with the given functor type and arguments. - /// \tparam desc Functor descriptor - /// \tparam Functor Function type - /// \tparam Args Argument types - /// \param args Arguments passed to the function - /// \return A new Function instance - template - auto makeFunction(const Args &...args) { - using OperationType = Function; - return OperationType(Functor(), args...); - } - - LIBRAPID_BINARY_FUNCTOR(Plus, +); // a + b - LIBRAPID_BINARY_FUNCTOR(Minus, -); // a - b - LIBRAPID_BINARY_FUNCTOR(Multiply, *); // a * b - LIBRAPID_BINARY_FUNCTOR(Divide, /); // a / b - - LIBRAPID_BINARY_COMPARISON_FUNCTOR(LessThan, <); // a < b - LIBRAPID_BINARY_COMPARISON_FUNCTOR(GreaterThan, >); // a > b - LIBRAPID_BINARY_COMPARISON_FUNCTOR(LessThanEqual, <=); // a <= b - LIBRAPID_BINARY_COMPARISON_FUNCTOR(GreaterThanEqual, >=); // a >= b - LIBRAPID_BINARY_COMPARISON_FUNCTOR(ElementWiseEqual, ==); // a == b - LIBRAPID_BINARY_COMPARISON_FUNCTOR(ElementWiseNotEqual, !=); // a != b - - LIBRAPID_UNARY_FUNCTOR(Neg, -); - - LIBRAPID_UNARY_FUNCTOR(Sin, ::librapid::sin); // sin(a) - LIBRAPID_UNARY_FUNCTOR(Cos, ::librapid::cos); // cos(a) - LIBRAPID_UNARY_FUNCTOR(Tan, ::librapid::tan); // tan(a) - LIBRAPID_UNARY_FUNCTOR(Asin, ::librapid::asin); // asin(a) - LIBRAPID_UNARY_FUNCTOR(Acos, ::librapid::acos); // acos(a) - LIBRAPID_UNARY_FUNCTOR(Atan, ::librapid::atan); // atan(a) - LIBRAPID_UNARY_FUNCTOR(Sinh, ::librapid::sinh); // sinh(a) - LIBRAPID_UNARY_FUNCTOR(Cosh, ::librapid::cosh); // cosh(a) - LIBRAPID_UNARY_FUNCTOR(Tanh, ::librapid::tanh); // tanh(a) - - LIBRAPID_UNARY_FUNCTOR(Exp, ::librapid::exp); // exp(a) - LIBRAPID_UNARY_FUNCTOR(Log, ::librapid::log); // log(a) - LIBRAPID_UNARY_FUNCTOR(Log2, ::librapid::log2); // log2(a) - LIBRAPID_UNARY_FUNCTOR(Log10, ::librapid::log10); // log10(a) - LIBRAPID_UNARY_FUNCTOR(Sqrt, ::librapid::sqrt); // sqrt(a) - LIBRAPID_UNARY_FUNCTOR(Cbrt, ::librapid::cbrt); // cbrt(a) - 
LIBRAPID_UNARY_FUNCTOR(Abs, ::librapid::abs); // abs(a) - LIBRAPID_UNARY_FUNCTOR(Floor, ::librapid::floor); // floor(a) - LIBRAPID_UNARY_FUNCTOR(Ceil, ::librapid::ceil); // ceil(a) - - } // namespace detail - - namespace typetraits { - /// Merge together two Descriptor types. Two trivial operations will result in - /// another trivial operation, while any other combination will result in a Combined - /// operation. \tparam Descriptor1 The first descriptor \tparam Descriptor2 The - /// second descriptor - template - struct DescriptorMerger { - using Type = ::librapid::detail::descriptor::Combined; - }; - - template - struct DescriptorMerger { - using Type = Descriptor1; - }; - - /// Extracts the Descriptor type of the provided type. - /// \tparam T The type to extract the descriptor from - template - struct DescriptorExtractor { - using Type = ::librapid::detail::descriptor::Trivial; - }; - - /// Extracts the Descriptor type of an ArrayContainer object. In this case, the - /// Descriptor is Trivial \tparam ShapeType The shape type of the ArrayContainer - /// \tparam StorageType The storage type of the ArrayContainer - template - struct DescriptorExtractor> { - using Type = ::librapid::detail::descriptor::Trivial; - }; - - /// Extracts the Descriptor type of an ArrayView object - /// \tparam T The Array type of the ArrayView - template - struct DescriptorExtractor> { - using Type = ::librapid::detail::descriptor::Trivial; - }; - - /// Extracts the Descriptor type of a Function object - /// \tparam Descriptor The descriptor of the Function - /// \tparam Functor The functor type of the Function - /// \tparam Args The argument types of the Function - template - struct DescriptorExtractor<::librapid::detail::Function> { - using Type = Descriptor; - }; - - /// Return the combined Descriptor type of the provided types - /// \tparam First The first type to merge - /// \tparam Rest The remaining types - template - struct DescriptorType; - - namespace impl { - /// A 
`constexpr` function which supports the DescriptorType for multi-type - /// inputs \tparam Rest \return - template - constexpr auto descriptorExtractor() { - if constexpr (sizeof...(Rest) > 0) { - using ReturnType = typename DescriptorType::Type; - return ReturnType {}; - } else { - using ReturnType = ::librapid::detail::descriptor::Trivial; - return ReturnType {}; - } - } - } // namespace impl - - /// Allows a number of Descriptor types to be merged together into a single - /// Descriptor type. The Descriptors used are extracted from the ***typenames*** of - /// the provided types. \tparam First The first type to merge \tparam Rest The - /// remaining types - template - struct DescriptorType { - using FirstType = std::decay_t; - using FirstDescriptor = typename DescriptorExtractor::Type; - using RestDescriptor = decltype(impl::descriptorExtractor()); - - using Type = typename DescriptorMerger::Type; - }; - - /// A simplification of the DescriptorType to reduce code size - /// \tparam Args Input types - /// \see DescriptorType - template - using DescriptorType_t = typename DescriptorType::Type; - - template<> - struct TypeInfo<::librapid::detail::Plus> { - static constexpr const char *name = "plus"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "addArrays"; - static constexpr const char *kernelNameScalarRhs = "addArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "addArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Minus> { - static constexpr const char *name = "minus"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "subArrays"; - static constexpr const char *kernelNameScalarRhs = "subArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "subArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - 
template<> - struct TypeInfo<::librapid::detail::Multiply> { - static constexpr const char *name = "multiply"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "mulArrays"; - static constexpr const char *kernelNameScalarRhs = "mulArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "mulArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Divide> { - static constexpr const char *name = "divide"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "divArrays"; - static constexpr const char *kernelNameScalarRhs = "divArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "divArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::LessThan> { - static constexpr const char *name = "less than"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "lessThanArrays"; - static constexpr const char *kernelNameScalarRhs = "lessThanArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "lessThanArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::GreaterThan> { - static constexpr const char *name = "greater than"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "greaterThanArrays"; - static constexpr const char *kernelNameScalarRhs = "greaterThanArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "greaterThanArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::LessThanEqual> { - static constexpr const char *name = "less than or equal"; - static constexpr const char 
*filename = "arithmetic"; - static constexpr const char *kernelName = "lessThanEqualArrays"; - static constexpr const char *kernelNameScalarRhs = "lessThanEqualArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "lessThanEqualArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::GreaterThanEqual> { - static constexpr const char *name = "greater than or equal"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "greaterThanEqualArrays"; - static constexpr const char *kernelNameScalarRhs = "greaterThanEqualArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "greaterThanEqualArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::ElementWiseEqual> { - static constexpr const char *name = "element wise equal"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "elementWiseEqualArrays"; - static constexpr const char *kernelNameScalarRhs = "elementWiseEqualArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "elementWiseEqualArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::ElementWiseNotEqual> { - static constexpr const char *name = "element wise not equal"; - static constexpr const char *filename = "arithmetic"; - static constexpr const char *kernelName = "elementWiseNotEqualArrays"; - static constexpr const char *kernelNameScalarRhs = "elementWiseNotEqualArraysScalarRhs"; - static constexpr const char *kernelNameScalarLhs = "elementWiseNotEqualArraysScalarLhs"; - LIBRAPID_BINARY_KERNEL_GETTER - LIBRAPID_BINARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Neg> { - static constexpr const char *name = "negate"; - static constexpr 
const char *filename = "negate"; - static constexpr const char *kernelName = "negateArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Sin> { - static constexpr const char *name = "sin"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "sinArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Cos> { - static constexpr const char *name = "cos"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "cosArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Tan> { - static constexpr const char *name = "tan"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "tanArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Asin> { - static constexpr const char *name = "arcsin"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "asinArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Acos> { - static constexpr const char *name = "arcos"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "acosArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Atan> { - static constexpr const char *name = "arctan"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "atanArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Sinh> { - static constexpr const 
char *name = "hyperbolic sine"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "sinhArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Cosh> { - static constexpr const char *name = "hyperbolic cosine"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "coshArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Tanh> { - static constexpr const char *name = "hyperbolic tangent"; - static constexpr const char *filename = "trigonometry"; - static constexpr const char *kernelName = "tanhArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Exp> { - static constexpr const char *name = "exponent"; - static constexpr const char *filename = "expLogPow"; - static constexpr const char *kernelName = "expArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Log> { - static constexpr const char *name = "logarithm"; - static constexpr const char *filename = "expLogPow"; - static constexpr const char *kernelName = "logArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Log2> { - static constexpr const char *name = "logarithm base 2"; - static constexpr const char *filename = "expLogPow"; - static constexpr const char *kernelName = "log2Arrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Log10> { - static constexpr const char *name = "logarithm base 10"; - static constexpr const char *filename = "expLogPow"; - static constexpr const char *kernelName = "log10Arrays"; - LIBRAPID_UNARY_KERNEL_GETTER - 
LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Sqrt> { - static constexpr const char *name = "square root"; - static constexpr const char *filename = "expLogPow"; - static constexpr const char *kernelName = "sqrtArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Cbrt> { - static constexpr const char *name = "cube root"; - static constexpr const char *filename = "expLogPow"; - static constexpr const char *kernelName = "cbrtArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Abs> { - static constexpr const char *name = "absolute value"; - static constexpr const char *filename = "abs"; - static constexpr const char *kernelName = "absArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Floor> { - static constexpr const char *name = "floor"; - static constexpr const char *filename = "floorCeilRound"; - static constexpr const char *kernelName = "floorArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - - template<> - struct TypeInfo<::librapid::detail::Ceil> { - static constexpr const char *name = "ceiling"; - static constexpr const char *filename = "floorCeilRound"; - static constexpr const char *kernelName = "ceilArrays"; - LIBRAPID_UNARY_KERNEL_GETTER - LIBRAPID_UNARY_SHAPE_EXTRACTOR - }; - } // namespace typetraits - - namespace detail { - template - constexpr bool isArrayOp() { - return (typetraits::IsArrayContainer>::value || - typetraits::IsLibRapidType>::value); - } - - template - constexpr bool isArrayOpArray() { - return (typetraits::TypeInfo>::type != LibRapidType::Scalar) && - (typetraits::TypeInfo>::type != LibRapidType::Scalar) && - typetraits::IsLibRapidType>::value && - typetraits::IsLibRapidType>::value; - } - - template - constexpr bool 
isArrayOpWithScalar() { - return (typetraits::IsLibRapidType>::value && - typetraits::TypeInfo>::type == LibRapidType::Scalar) || - (typetraits::TypeInfo>::type == LibRapidType::Scalar && - typetraits::IsLibRapidType>::value); - } - } // namespace detail - - namespace array { -#define IS_ARRAY_OP detail::isArrayOp() -#define IS_ARRAY_OP_ARRAY detail::isArrayOpArray() + namespace detail { + /// Construct a new function object with the given functor type and arguments. + /// \tparam desc Functor descriptor + /// \tparam Functor Function type + /// \tparam Args Argument types + /// \param args Arguments passed to the function + /// \return A new Function instance + template + auto makeFunction(const Args &...args) { + using OperationType = Function; + return OperationType(Functor(), args...); + } + + LIBRAPID_BINARY_FUNCTOR(Plus, +); // a + b + LIBRAPID_BINARY_FUNCTOR(Minus, -); // a - b + LIBRAPID_BINARY_FUNCTOR(Multiply, *); // a * b + LIBRAPID_BINARY_FUNCTOR(Divide, /); // a / b + + LIBRAPID_BINARY_COMPARISON_FUNCTOR(LessThan, <); // a < b + LIBRAPID_BINARY_COMPARISON_FUNCTOR(GreaterThan, >); // a > b + LIBRAPID_BINARY_COMPARISON_FUNCTOR(LessThanEqual, <=); // a <= b + LIBRAPID_BINARY_COMPARISON_FUNCTOR(GreaterThanEqual, >=); // a >= b + LIBRAPID_BINARY_COMPARISON_FUNCTOR(ElementWiseEqual, ==); // a == b + LIBRAPID_BINARY_COMPARISON_FUNCTOR(ElementWiseNotEqual, !=); // a != b + + LIBRAPID_UNARY_FUNCTOR(Neg, -); + + LIBRAPID_UNARY_FUNCTOR(Sin, ::librapid::sin); // sin(a) + LIBRAPID_UNARY_FUNCTOR(Cos, ::librapid::cos); // cos(a) + LIBRAPID_UNARY_FUNCTOR(Tan, ::librapid::tan); // tan(a) + LIBRAPID_UNARY_FUNCTOR(Asin, ::librapid::asin); // asin(a) + LIBRAPID_UNARY_FUNCTOR(Acos, ::librapid::acos); // acos(a) + LIBRAPID_UNARY_FUNCTOR(Atan, ::librapid::atan); // atan(a) + LIBRAPID_UNARY_FUNCTOR(Sinh, ::librapid::sinh); // sinh(a) + LIBRAPID_UNARY_FUNCTOR(Cosh, ::librapid::cosh); // cosh(a) + LIBRAPID_UNARY_FUNCTOR(Tanh, ::librapid::tanh); // tanh(a) + + 
LIBRAPID_UNARY_FUNCTOR(Exp, ::librapid::exp); // exp(a) + LIBRAPID_UNARY_FUNCTOR(Log, ::librapid::log); // log(a) + LIBRAPID_UNARY_FUNCTOR(Log2, ::librapid::log2); // log2(a) + LIBRAPID_UNARY_FUNCTOR(Log10, ::librapid::log10); // log10(a) + LIBRAPID_UNARY_FUNCTOR(Sqrt, ::librapid::sqrt); // sqrt(a) + LIBRAPID_UNARY_FUNCTOR(Cbrt, ::librapid::cbrt); // cbrt(a) + LIBRAPID_UNARY_FUNCTOR(Abs, ::librapid::abs); // abs(a) + LIBRAPID_UNARY_FUNCTOR(Floor, ::librapid::floor); // floor(a) + LIBRAPID_UNARY_FUNCTOR(Ceil, ::librapid::ceil); // ceil(a) + + } // namespace detail + + namespace typetraits { + /// Merge together two Descriptor types. Two trivial operations will result in + /// another trivial operation, while any other combination will result in a Combined + /// operation. \tparam Descriptor1 The first descriptor \tparam Descriptor2 The + /// second descriptor + template + struct DescriptorMerger { + using Type = ::librapid::detail::descriptor::Combined; + }; + + template + struct DescriptorMerger { + using Type = Descriptor1; + }; + + /// Extracts the Descriptor type of the provided type. + /// \tparam T The type to extract the descriptor from + template + struct DescriptorExtractor { + using Type = ::librapid::detail::descriptor::Trivial; + }; + + /// Extracts the Descriptor type of an ArrayContainer object. 
In this case, the + /// Descriptor is Trivial \tparam ShapeType The shape type of the ArrayContainer + /// \tparam StorageType The storage type of the ArrayContainer + template + struct DescriptorExtractor> { + using Type = ::librapid::detail::descriptor::Trivial; + }; + + /// Extracts the Descriptor type of an ArrayView object + /// \tparam T The Array type of the ArrayView + template + struct DescriptorExtractor> { + using Type = ::librapid::detail::descriptor::Trivial; + }; + + /// Extracts the Descriptor type of a Function object + /// \tparam Descriptor The descriptor of the Function + /// \tparam Functor The functor type of the Function + /// \tparam Args The argument types of the Function + template + struct DescriptorExtractor<::librapid::detail::Function> { + using Type = Descriptor; + }; + + /// Return the combined Descriptor type of the provided types + /// \tparam First The first type to merge + /// \tparam Rest The remaining types + template + struct DescriptorType; + + namespace impl { + /// A `constexpr` function which supports the DescriptorType for multi-type + /// inputs \tparam Rest \return + template + constexpr auto descriptorExtractor() { + if constexpr (sizeof...(Rest) > 0) { + using ReturnType = typename DescriptorType::Type; + return ReturnType {}; + } else { + using ReturnType = ::librapid::detail::descriptor::Trivial; + return ReturnType {}; + } + } + } // namespace impl + + /// Allows a number of Descriptor types to be merged together into a single + /// Descriptor type. The Descriptors used are extracted from the ***typenames*** of + /// the provided types. 
\tparam First The first type to merge \tparam Rest The + /// remaining types + template + struct DescriptorType { + using FirstType = std::decay_t; + using FirstDescriptor = typename DescriptorExtractor::Type; + using RestDescriptor = decltype(impl::descriptorExtractor()); + + using Type = typename DescriptorMerger::Type; + }; + + /// A simplification of the DescriptorType to reduce code size + /// \tparam Args Input types + /// \see DescriptorType + template + using DescriptorType_t = typename DescriptorType::Type; + + template<> + struct TypeInfo<::librapid::detail::Plus> { + static constexpr const char *name = "plus"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "addArrays"; + static constexpr const char *kernelNameScalarRhs = "addArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "addArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Minus> { + static constexpr const char *name = "minus"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "subArrays"; + static constexpr const char *kernelNameScalarRhs = "subArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "subArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Multiply> { + static constexpr const char *name = "multiply"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "mulArrays"; + static constexpr const char *kernelNameScalarRhs = "mulArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "mulArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Divide> { + static constexpr const char *name = "divide"; + static constexpr const char 
*filename = "arithmetic"; + static constexpr const char *kernelName = "divArrays"; + static constexpr const char *kernelNameScalarRhs = "divArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "divArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::LessThan> { + static constexpr const char *name = "less than"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "lessThanArrays"; + static constexpr const char *kernelNameScalarRhs = "lessThanArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "lessThanArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::GreaterThan> { + static constexpr const char *name = "greater than"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "greaterThanArrays"; + static constexpr const char *kernelNameScalarRhs = "greaterThanArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "greaterThanArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::LessThanEqual> { + static constexpr const char *name = "less than or equal"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "lessThanEqualArrays"; + static constexpr const char *kernelNameScalarRhs = "lessThanEqualArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "lessThanEqualArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::GreaterThanEqual> { + static constexpr const char *name = "greater than or equal"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = 
"greaterThanEqualArrays"; + static constexpr const char *kernelNameScalarRhs = "greaterThanEqualArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "greaterThanEqualArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::ElementWiseEqual> { + static constexpr const char *name = "element wise equal"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "elementWiseEqualArrays"; + static constexpr const char *kernelNameScalarRhs = "elementWiseEqualArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "elementWiseEqualArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::ElementWiseNotEqual> { + static constexpr const char *name = "element wise not equal"; + static constexpr const char *filename = "arithmetic"; + static constexpr const char *kernelName = "elementWiseNotEqualArrays"; + static constexpr const char *kernelNameScalarRhs = "elementWiseNotEqualArraysScalarRhs"; + static constexpr const char *kernelNameScalarLhs = "elementWiseNotEqualArraysScalarLhs"; + LIBRAPID_BINARY_KERNEL_GETTER + LIBRAPID_BINARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Neg> { + static constexpr const char *name = "negate"; + static constexpr const char *filename = "negate"; + static constexpr const char *kernelName = "negateArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Sin> { + static constexpr const char *name = "sin"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "sinArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Cos> { + static constexpr const char *name = "cos"; + static 
constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "cosArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Tan> { + static constexpr const char *name = "tan"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "tanArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Asin> { + static constexpr const char *name = "arcsin"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "asinArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Acos> { + static constexpr const char *name = "arcos"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "acosArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Atan> { + static constexpr const char *name = "arctan"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "atanArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Sinh> { + static constexpr const char *name = "hyperbolic sine"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "sinhArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Cosh> { + static constexpr const char *name = "hyperbolic cosine"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "coshArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct 
TypeInfo<::librapid::detail::Tanh> { + static constexpr const char *name = "hyperbolic tangent"; + static constexpr const char *filename = "trigonometry"; + static constexpr const char *kernelName = "tanhArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Exp> { + static constexpr const char *name = "exponent"; + static constexpr const char *filename = "expLogPow"; + static constexpr const char *kernelName = "expArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Log> { + static constexpr const char *name = "logarithm"; + static constexpr const char *filename = "expLogPow"; + static constexpr const char *kernelName = "logArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Log2> { + static constexpr const char *name = "logarithm base 2"; + static constexpr const char *filename = "expLogPow"; + static constexpr const char *kernelName = "log2Arrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Log10> { + static constexpr const char *name = "logarithm base 10"; + static constexpr const char *filename = "expLogPow"; + static constexpr const char *kernelName = "log10Arrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Sqrt> { + static constexpr const char *name = "square root"; + static constexpr const char *filename = "expLogPow"; + static constexpr const char *kernelName = "sqrtArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Cbrt> { + static constexpr const char *name = "cube root"; + static constexpr const char *filename = "expLogPow"; + static constexpr const char *kernelName = "cbrtArrays"; + 
LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Abs> { + static constexpr const char *name = "absolute value"; + static constexpr const char *filename = "abs"; + static constexpr const char *kernelName = "absArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Floor> { + static constexpr const char *name = "floor"; + static constexpr const char *filename = "floorCeilRound"; + static constexpr const char *kernelName = "floorArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + + template<> + struct TypeInfo<::librapid::detail::Ceil> { + static constexpr const char *name = "ceiling"; + static constexpr const char *filename = "floorCeilRound"; + static constexpr const char *kernelName = "ceilArrays"; + LIBRAPID_UNARY_KERNEL_GETTER + LIBRAPID_UNARY_SHAPE_EXTRACTOR + }; + } // namespace typetraits + + namespace detail { + template + constexpr bool isArrayOp() { + return (typetraits::IsArrayContainer>::value || + typetraits::IsLibRapidType>::value); + } + + template + constexpr bool isArrayOpArray() { + return (typetraits::TypeInfo>::type != LibRapidType::Scalar) && + (typetraits::TypeInfo>::type != LibRapidType::Scalar) && + typetraits::IsLibRapidType>::value && + typetraits::IsLibRapidType>::value; + } + + template + constexpr bool isArrayOpWithScalar() { + return (typetraits::IsLibRapidType>::value && + typetraits::TypeInfo>::type == LibRapidType::Scalar) || + (typetraits::TypeInfo>::type == LibRapidType::Scalar && + typetraits::IsLibRapidType>::value); + } + } // namespace detail + + namespace array { +#define IS_ARRAY_OP detail::isArrayOp() +#define IS_ARRAY_OP_ARRAY detail::isArrayOpArray() #define IS_ARRAY_OP_WITH_SCALAR detail::isArrayOpWithScalar() - /// \brief Element-wise array addition - /// - /// Performs element-wise addition on two arrays. 
They must both be the same size - /// and of the same data type. - /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise sum of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator+(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Plus, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, detail::Plus>(lhs, - rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator+(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Plus, LHS, RHS> { - return detail::makeFunction, detail::Plus>(lhs, - rhs); - } - - /// \brief Element-wise array subtraction - /// - /// Performs element-wise subtraction on two arrays. They must both be the same size - /// and of the same data type. - /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise difference of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator-(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Minus, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, detail::Minus>(lhs, - rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator-(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Minus, LHS, RHS> { - return detail::makeFunction, detail::Minus>(lhs, - rhs); - } - - /// \brief Element-wise array multiplication - /// - /// Performs element-wise multiplication on two arrays. They must both be the same - /// size and of the same data type. 
- /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise product of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator*(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Multiply, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, detail::Multiply>( - lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator*(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Multiply, LHS, RHS> { - return detail::makeFunction, detail::Multiply>( - lhs, rhs); - } - - /// \brief Element-wise array division - /// - /// Performs element-wise division on two arrays. They must both be the same size - /// and of the same data type. - /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise division of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator/(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Divide, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, detail::Divide>( - lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator/(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Divide, LHS, RHS> { - return detail::makeFunction, detail::Divide>( - lhs, rhs); - } - - /// \brief Element-wise array comparison, checking whether a < b for all a, b in - /// input arrays - /// - /// Performs an element-wise comparison on two arrays, checking if the first value - /// is less than the second. 
They must both be the same size and of the same data - /// type. - /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise comparison of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator<(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::LessThan, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, detail::LessThan>( - lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator<(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::LessThan, LHS, RHS> { - return detail::makeFunction, detail::LessThan>( - lhs, rhs); - } - - /// \brief Element-wise array comparison, checking whether a > b for all a, b in - /// input arrays - /// - /// Performs an element-wise comparison on two arrays, checking if the first value - /// is greater than the second. They must both be the same size and of the same data - /// type. 
- /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise comparison of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::GreaterThan, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, - detail::GreaterThan>(lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::GreaterThan, LHS, RHS> { - return detail::makeFunction, - detail::GreaterThan>(lhs, rhs); - } - - /// \brief Element-wise array comparison, checking whether a <= b for all a, b in - /// input arrays - /// - /// Performs an element-wise comparison on two arrays, checking if the first value - /// is less than or equal to the second. They must both be the same size and of the - /// same data type. 
- /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise comparison of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator<=(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::LessThanEqual, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, - detail::LessThanEqual>(lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator<=(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::LessThanEqual, LHS, RHS> { - return detail::makeFunction, - detail::LessThanEqual>(lhs, rhs); - } - - /// \brief Element-wise array comparison, checking whether a >= b for all a, b in - /// input arrays - /// - /// Performs an element-wise comparison on two arrays, checking if the first value - /// is greater than or equal to the second. They must both be the same size and of - /// the same data type. 
- /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise comparison of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>=(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::GreaterThanEqual, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, - detail::GreaterThanEqual>(lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>=(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::GreaterThanEqual, LHS, RHS> { - return detail::makeFunction, - detail::GreaterThanEqual>(lhs, rhs); - } - - /// \brief Element-wise array comparison, checking whether a == b for all a, b in - /// input arrays - /// - /// Performs an element-wise comparison on two arrays, checking if the first value - /// is equal to the second. They must both be the same size and of the same data - /// type. 
- /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise comparison of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::ElementWiseEqual, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, - detail::ElementWiseEqual>(lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::ElementWiseEqual, LHS, RHS> { - return detail::makeFunction, - detail::ElementWiseEqual>(lhs, rhs); - } - - /// \brief Element-wise array comparison, checking whether a != b for all a, b in - /// input arrays - /// - /// Performs an element-wise comparison on two arrays, checking if the first value - /// is not equal to the second. They must both be the same size and of the same data - /// type. 
- /// - /// \tparam LHS Type of the LHS element - /// \tparam RHS Type of the RHS element - /// \param lhs The first array - /// \param rhs The second array - /// \return The element-wise comparison of the two arrays - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::ElementWiseNotEqual, LHS, RHS> { - LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); - return detail::makeFunction, - detail::ElementWiseNotEqual>(lhs, rhs); - } - - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const LHS &lhs, const RHS &rhs) - LIBRAPID_RELEASE_NOEXCEPT->detail::Function, - detail::ElementWiseNotEqual, LHS, RHS> { - return detail::makeFunction, - detail::ElementWiseNotEqual>(lhs, rhs); - } - - /// \brief Negate each element in the array - /// \tparam VAL Type to negate - /// \param val The input array or function - /// \return Negation function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto - operator-(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Neg, VAL> { - return detail::makeFunction, detail::Neg>(val); - } - } // namespace array - - /// \brief Calculate the sine of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sin(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Sine function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sin(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Sin, VAL> { - return detail::makeFunction, detail::Sin>(val); - } - - /// \brief Calculate the cosine of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \cos(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Cosine function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cos(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Cos, VAL> { - return detail::makeFunction, detail::Cos>(val); - } - - /// \brief Calculate the tangent of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \tan(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Tangent function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tan(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Tan, VAL> { - return detail::makeFunction, detail::Tan>(val); - } - - /// \brief Calculate the arcsine of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sin^{-1}(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Arcsine function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto asin(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Asin, VAL> { - return detail::makeFunction, detail::Asin>(val); - } - - /// \brief Calculate the arccosine of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \cos^{-1}(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Arccosine function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto acos(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Acos, VAL> { - return detail::makeFunction, detail::Acos>(val); - } - - /// \brief Calculate the arctangent of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \tan^{-1}(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Arctangent function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto atan(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Atan, VAL> { - return detail::makeFunction, detail::Atan>(val); - } - - /// \brief Calculate the hyperbolic sine of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sinh(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Hyperbolic sine function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sinh(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Sinh, VAL> { - return detail::makeFunction, detail::Sinh>(val); - } - - /// \brief Calculate the hyperbolic cosine of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \cosh(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Hyperbolic cosine function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cosh(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Cosh, VAL> { - return detail::makeFunction, detail::Cosh>(val); - } - - /// \brief Calculate the hyperbolic tangent of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \tanh(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Hyperbolic tangent function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tanh(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Tanh, VAL> { - return detail::makeFunction, detail::Tanh>(val); - } - - /// \brief Raise e to the power of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = e^{A_i}\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Exponential function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto exp(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Exp, VAL> { - return detail::makeFunction, detail::Exp>(val); - } - - // \brief Compute the natural logarithm of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \ln(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Natural logarithm function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Log, VAL> { - return detail::makeFunction, detail::Log>(val); - } - - /// \brief Compute the base 10 logarithm of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \log_{10}(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Base 10 logarithm function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log10(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Log10, VAL> { - return detail::makeFunction, detail::Log10>(val); - } - - /// \brief Compute the base 2 logarithm of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \log_{2}(A_i)\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Base 2 logarithm function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log2(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Log2, VAL> { - return detail::makeFunction, detail::Log2>(val); - } - - /// \brief Compute the square root of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sqrt{A_i}\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Square root function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrt(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Sqrt, VAL> { - return detail::makeFunction, detail::Sqrt>(val); - } - - /// \brief Compute the cube root of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sqrt[3]{A_i}\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Cube root function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cbrt(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Cbrt, VAL> { - return detail::makeFunction, detail::Cbrt>(val); - } - - /// \brief Compute the absolute value of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = |A_i|\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Absolute value function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto abs(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Abs, VAL> { - return detail::makeFunction, detail::Abs>(val); - } - - /// \brief Compute the floor of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \lfloor A_i \rfloor\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Floor function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto floor(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Floor, VAL> { - return detail::makeFunction, detail::Floor>(val); - } - - /// \brief Compute the ceiling of each element in the array - /// - /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \lceil A_i \rceil\f$ - /// - /// \tparam VAL Type of the input - /// \param val The input array or function - /// \return Ceiling function object - template = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ceil(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT - ->detail::Function, detail::Ceil, VAL> { - return detail::makeFunction, detail::Ceil>(val); - } + /// \brief Element-wise array addition + /// + /// Performs element-wise addition on two arrays. They must both be the same size + /// and of the same data type. + /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise sum of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator+(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Plus, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, detail::Plus>(lhs, + rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator+(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Plus, LHS, RHS> { + return detail::makeFunction, detail::Plus>(lhs, + rhs); + } + + /// \brief Element-wise array subtraction + /// + /// Performs element-wise subtraction on two arrays. They must both be the same size + /// and of the same data type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise difference of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator-(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Minus, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, detail::Minus>(lhs, + rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator-(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Minus, LHS, RHS> { + return detail::makeFunction, detail::Minus>(lhs, + rhs); + } + + /// \brief Element-wise array multiplication + /// + /// Performs element-wise multiplication on two arrays. They must both be the same + /// size and of the same data type. + /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise product of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator*(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Multiply, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, detail::Multiply>( + lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator*(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Multiply, LHS, RHS> { + return detail::makeFunction, detail::Multiply>( + lhs, rhs); + } + + /// \brief Element-wise array division + /// + /// Performs element-wise division on two arrays. They must both be the same size + /// and of the same data type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise division of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator/(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Divide, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, detail::Divide>( + lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator/(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Divide, LHS, RHS> { + return detail::makeFunction, detail::Divide>( + lhs, rhs); + } + + /// \brief Element-wise array comparison, checking whether a < b for all a, b in + /// input arrays + /// + /// Performs an element-wise comparison on two arrays, checking if the first value + /// is less than the second. They must both be the same size and of the same data + /// type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise comparison of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator<(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::LessThan, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, detail::LessThan>( + lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator<(const LHS &lhs, const RHS &rhs) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::LessThan, LHS, RHS> { + return detail::makeFunction, detail::LessThan>( + lhs, rhs); + } + + /// \brief Element-wise array comparison, checking whether a > b for all a, b in + /// input arrays + /// + /// Performs an element-wise comparison on two arrays, checking if the first value + /// is greater than the second. They must both be the same size and of the same data + /// type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise comparison of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::GreaterThan, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, + detail::GreaterThan>(lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::GreaterThan, LHS, RHS> { + return detail::makeFunction, + detail::GreaterThan>(lhs, rhs); + } + + /// \brief Element-wise array comparison, checking whether a <= b for all a, b in + /// input arrays + /// + /// Performs an element-wise comparison on two arrays, checking if the first value + /// is less than or equal to the second. They must both be the same size and of the + /// same data type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise comparison of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator<=(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::LessThanEqual, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, + detail::LessThanEqual>(lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator<=(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::LessThanEqual, LHS, RHS> { + return detail::makeFunction, + detail::LessThanEqual>(lhs, rhs); + } + + /// \brief Element-wise array comparison, checking whether a >= b for all a, b in + /// input arrays + /// + /// Performs an element-wise comparison on two arrays, checking if the first value + /// is greater than or equal to the second. They must both be the same size and of + /// the same data type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise comparison of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>=(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::GreaterThanEqual, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, + detail::GreaterThanEqual>(lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator>=(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::GreaterThanEqual, LHS, RHS> { + return detail::makeFunction, + detail::GreaterThanEqual>(lhs, rhs); + } + + /// \brief Element-wise array comparison, checking whether a == b for all a, b in + /// input arrays + /// + /// Performs an element-wise comparison on two arrays, checking if the first value + /// is equal to the second. They must both be the same size and of the same data + /// type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise comparison of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::ElementWiseEqual, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, + detail::ElementWiseEqual>(lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator==(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::ElementWiseEqual, LHS, RHS> { + return detail::makeFunction, + detail::ElementWiseEqual>(lhs, rhs); + } + + /// \brief Element-wise array comparison, checking whether a != b for all a, b in + /// input arrays + /// + /// Performs an element-wise comparison on two arrays, checking if the first value + /// is not equal to the second. They must both be the same size and of the same data + /// type. 
+ /// + /// \tparam LHS Type of the LHS element + /// \tparam RHS Type of the RHS element + /// \param lhs The first array + /// \param rhs The second array + /// \return The element-wise comparison of the two arrays + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::ElementWiseNotEqual, LHS, RHS> { + LIBRAPID_ASSERT(lhs.shape().operator==(rhs.shape()), "Shapes must be equal"); + return detail::makeFunction, + detail::ElementWiseNotEqual>(lhs, rhs); + } + + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator!=(const LHS &lhs, const RHS &rhs) + LIBRAPID_RELEASE_NOEXCEPT->detail::Function, + detail::ElementWiseNotEqual, LHS, RHS> { + return detail::makeFunction, + detail::ElementWiseNotEqual>(lhs, rhs); + } + + /// \brief Negate each element in the array + /// \tparam VAL Type to negate + /// \param val The input array or function + /// \return Negation function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto + operator-(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Neg, VAL> { + return detail::makeFunction, detail::Neg>(val); + } + } // namespace array + + /// \brief Calculate the sine of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sin(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Sine function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sin(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Sin, VAL> { + return detail::makeFunction, detail::Sin>(val); + } + + /// \brief Calculate the cosine of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \cos(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Cosine function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cos(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Cos, VAL> { + return detail::makeFunction, detail::Cos>(val); + } + + /// \brief Calculate the tangent of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \tan(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Tangent function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tan(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Tan, VAL> { + return detail::makeFunction, detail::Tan>(val); + } + + /// \brief Calculate the arcsine of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sin^{-1}(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Arcsine function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto asin(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Asin, VAL> { + return detail::makeFunction, detail::Asin>(val); + } + + /// \brief Calculate the arccosine of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \cos^{-1}(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Arccosine function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto acos(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Acos, VAL> { + return detail::makeFunction, detail::Acos>(val); + } + + /// \brief Calculate the arctangent of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \tan^{-1}(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Arctangent function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto atan(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Atan, VAL> { + return detail::makeFunction, detail::Atan>(val); + } + + /// \brief Calculate the hyperbolic sine of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sinh(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Hyperbolic sine function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sinh(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Sinh, VAL> { + return detail::makeFunction, detail::Sinh>(val); + } + + /// \brief Calculate the hyperbolic cosine of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \cosh(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Hyperbolic cosine function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cosh(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Cosh, VAL> { + return detail::makeFunction, detail::Cosh>(val); + } + + /// \brief Calculate the hyperbolic tangent of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \tanh(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Hyperbolic tangent function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tanh(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Tanh, VAL> { + return detail::makeFunction, detail::Tanh>(val); + } + + /// \brief Raise e to the power of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = e^{A_i}\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Exponential function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto exp(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Exp, VAL> { + return detail::makeFunction, detail::Exp>(val); + } + + // \brief Compute the natural logarithm of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \ln(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Natural logarithm function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Log, VAL> { + return detail::makeFunction, detail::Log>(val); + } + + /// \brief Compute the base 10 logarithm of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \log_{10}(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Base 10 logarithm function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log10(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Log10, VAL> { + return detail::makeFunction, detail::Log10>(val); + } + + /// \brief Compute the base 2 logarithm of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \log_{2}(A_i)\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Base 2 logarithm function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log2(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Log2, VAL> { + return detail::makeFunction, detail::Log2>(val); + } + + /// \brief Compute the square root of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sqrt{A_i}\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Square root function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrt(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Sqrt, VAL> { + return detail::makeFunction, detail::Sqrt>(val); + } + + /// \brief Compute the cube root of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \sqrt[3]{A_i}\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Cube root function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cbrt(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Cbrt, VAL> { + return detail::makeFunction, detail::Cbrt>(val); + } + + /// \brief Compute the absolute value of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = |A_i|\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Absolute value function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto abs(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Abs, VAL> { + return detail::makeFunction, detail::Abs>(val); + } + + /// \brief Compute the floor of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... 
\} \f$ \text{ where } \f$R_i = \lfloor A_i \rfloor\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Floor function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto floor(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Floor, VAL> { + return detail::makeFunction, detail::Floor>(val); + } + + /// \brief Compute the ceiling of each element in the array + /// + /// \f$R = \{ R_0, R_1, R_2, ... \} \f$ \text{ where } \f$R_i = \lceil A_i \rceil\f$ + /// + /// \tparam VAL Type of the input + /// \param val The input array or function + /// \return Ceiling function object + template = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ceil(const VAL &val) LIBRAPID_RELEASE_NOEXCEPT + ->detail::Function, detail::Ceil, VAL> { + return detail::makeFunction, detail::Ceil>(val); + } } // namespace librapid #endif // LIBRAPID_ARRAY_OPERATIONS_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/pseudoConstructors.hpp b/librapid/include/librapid/array/pseudoConstructors.hpp index c22d661e..58499e4b 100644 --- a/librapid/include/librapid/array/pseudoConstructors.hpp +++ b/librapid/include/librapid/array/pseudoConstructors.hpp @@ -2,227 +2,227 @@ #define LIBRAPID_ARRAY_PSEUDO_CONSTRUCTORS_HPP namespace librapid { - /// \brief Force the input to be evaluated to an Array - /// - /// When given a scalar or Array type, this function will return the input unchanged. When given - /// a Function, it will evaluate the function and return the result. This is useful for - /// functions which require an Array instance as input and cannot function with function types. - /// - /// Note that the input is not copied or moved, so the returned Array will be a reference to the - /// input. 
- /// - /// \tparam T Input type - /// \param other Input - /// \return Evaluated input - template - auto evaluated(const T &other) { - return other; - } - - template - auto evaluated(const array::ArrayContainer &other) { - return other; - } - - template - auto evaluated(const detail::Function &other) { - return other.eval(); - } - - /// \brief Create a new array with the same type and shape as the input array, but without - /// initializing any of the data - /// \tparam T Input array type - /// \param other Input array - /// \return New array - template - auto emptyLike(const T &other) { - using Scalar = typename typetraits::TypeInfo::Scalar; - using Backend = typename typetraits::TypeInfo::Backend; - return Array(other.shape()); - } - - /// \brief Create an Array filled with zeros - /// - /// Create an array with a specified shape, scalar type and Backend, and fill it with zeros. - /// - /// \tparam Scalar Scalar type of the Array - /// \tparam Backend Backend type of the Array - /// \tparam T Scalar type of the Shape - /// \tparam N Maximum number of dimensions of the Shape - /// \param shape Shape of the Array - /// \return Array filled with zeros - template - Array zeros(const Shape &shape) { - return Array(shape, Scalar(0)); - } - - /// \brief Create an Array filled with zeros, with the same type and shape as the input array - /// - /// \tparam T Input array type - /// \param other Input array - /// \return New array - template - auto zerosLike(const T &other) { - using Scalar = typename typetraits::TypeInfo::Scalar; - using Backend = typename typetraits::TypeInfo::Backend; - return zeros(other.shape()); - } - - /// \brief Create an Array filled with ones - /// - /// Create an array with a specified shape, scalar type and Backend, and fill it with ones. 
- /// - /// \tparam Scalar Scalar type of the Array - /// \tparam Backend Backend type of the Array - /// \tparam T Scalar type of the Shape - /// \tparam N Maximum number of dimensions of the Shape - /// \param shape Shape of the Array - /// \return Array filled with ones - template - Array ones(const Shape &shape) { - return Array(shape, Scalar(1)); - } - - /// \brief Create an Array filled with ones, with the same type and shape as the input array - /// - /// \tparam T Input array type - /// \param other Input array - /// \return New array - template - auto onesLike(const T &other) { - using Scalar = typename typetraits::TypeInfo::Scalar; - using Backend = typename typetraits::TypeInfo::Backend; - return ones(other.shape()); - } - - /// \brief Create an Array filled, in order, with the numbers 0 to N-1 - /// - /// Create a new Array object with a given shape, where each value is filled with a number from - /// 0 to N-1, where N is the total number of elements in the array. The values are filled in - /// the same order as the array is stored in memory. - /// - /// \tparam Scalar Scalar type of the Array - /// \tparam Backend Backend type of the Array - /// \tparam T Scalar type of the Shape - /// \tparam N Maximum number of dimensions of the Shape - /// \param shape Shape of the Array - /// \return Array filled with numbers from 0 to N-1 - template - Array ordered(const Shape &shape) { - Array result(shape); - for (size_t i = 0; i < shape.size(); i++) { result.storage()[i] = Scalar(i); } - return result; - } - - /// \brief Create a 1-dimensional Array from a range of numbers and a step size - /// - /// Provided with a start value and a stop value, create a 1-dimensional Array with - /// \f$\lfloor \frac{stop - start}{step} \rfloor \f$ elements, where each element is - /// \f$start + i \times step\f$, for \f$i \in [0, \lfloor \frac{stop - start}{step} \rfloor)\f$. 
- /// - /// \tparam Scalar Scalar type of the Array - /// \tparam Backend Backend for the Array - /// \tparam Start Scalar type of the start value - /// \tparam Stop Scalar type of the stop value - /// \tparam Step Scalar type of the step size - /// \param start First value in the range - /// \param stop Second value in the range - /// \param step Step size between values in the range - /// \return Array - template - Array arange(Start start, Stop stop, Step step) { - LIBRAPID_ASSERT(step != 0, "Step size cannot be zero"); - LIBRAPID_ASSERT((stop - start) / step > 0, "Step size is invalid for the specified range"); - - Shape shape = {(int64_t)::librapid::abs((stop - start) / step)}; - Array result(shape); - for (size_t i = 0; i < shape.size(); i++) { - result.storage()[i] = Scalar(start + i * step); - } - return result; - } - - template - Array arange(T start, T stop) { - LIBRAPID_ASSERT((stop - start) > 0, "Step size is invalid for the specified range"); - - Shape shape = {(int64_t)::librapid::abs(stop - start)}; - Array result(shape); - for (size_t i = 0; i < shape.size(); i++) { result.storage()[i] = Scalar(start + i); } - return result; - } - - template - Array arange(T stop) { - Shape shape = {(int64_t)::librapid::abs(stop)}; - Array result(shape); - for (size_t i = 0; i < shape.size(); i++) { result.storage()[i] = Scalar(i); } - return result; - } - - /// \brief Create a 1-dimensional Array with a specified number of elements, evenly spaced - /// between two values - /// - /// Create a 1-dimensional Array with a specified number of elements, evenly spaced between - /// two values. If \p includeEnd is true, the last element of the Array will be equal to - /// \p stop, otherwise it will be equal to \p stop - \f$\frac{stop - start}{num}\f$. 
- /// - /// \tparam Scalar Scalar type of the Array - /// \tparam Backend Backend for the Array - /// \tparam Start Scalar type of the start value - /// \tparam Stop Scalar type of the stop value - /// \param start First value in the range - /// \param stop Second value in the range - /// \param num Number of elements in the Array - /// \param includeEnd Whether or not to include the end value in the Array - /// \return Linearly spaced Array - template - Array linspace(Start start, Stop stop, int64_t num, bool includeEnd = true) { - LIBRAPID_ASSERT(num > 0, "Number of samples must be greater than zero"); - - auto startCast = static_cast(start); - auto stopCast = static_cast(stop); - auto den = static_cast(num - includeEnd); - Shape shape = {num}; - Array result(shape); - for (size_t i = 0; i < shape.size(); i++) { - result.storage()[i] = startCast + (stopCast - startCast) * static_cast(i) / den; - } - return result; - } - - template - Array logspace(Start start, Stop stop, int64_t num, bool includeEnd = true) { - LIBRAPID_ASSERT(num > 0, "Number of samples must be greater than zero"); - - auto logLower = ::librapid::log(static_cast(start)); - auto logUpper = ::librapid::log(static_cast(stop)); - - Shape shape = {num}; - Array result(shape); - - for (size_t i = 0; i < shape.size(); i++) { - result.storage()[i] = - ::librapid::exp(logLower + (logUpper - logLower) * static_cast(i) / - static_cast(num - includeEnd)); - } - - return result; - } - - template - Array random(const ShapeType &shape, Lower lower = 0, Upper upper = 1) { - Array result(shape); - fillRandom(result, lower, upper); - return result; - } + /// \brief Force the input to be evaluated to an Array + /// + /// When given a scalar or Array type, this function will return the input unchanged. When given + /// a Function, it will evaluate the function and return the result. This is useful for + /// functions which require an Array instance as input and cannot function with function types. 
+ /// + /// Note that the input is not copied or moved, so the returned Array will be a reference to the + /// input. + /// + /// \tparam T Input type + /// \param other Input + /// \return Evaluated input + template + auto evaluated(const T &other) { + return other; + } + + template + auto evaluated(const array::ArrayContainer &other) { + return other; + } + + template + auto evaluated(const detail::Function &other) { + return other.eval(); + } + + /// \brief Create a new array with the same type and shape as the input array, but without + /// initializing any of the data + /// \tparam T Input array type + /// \param other Input array + /// \return New array + template + auto emptyLike(const T &other) { + using Scalar = typename typetraits::TypeInfo::Scalar; + using Backend = typename typetraits::TypeInfo::Backend; + return Array(other.shape()); + } + + /// \brief Create an Array filled with zeros + /// + /// Create an array with a specified shape, scalar type and Backend, and fill it with zeros. + /// + /// \tparam Scalar Scalar type of the Array + /// \tparam Backend Backend type of the Array + /// \tparam T Scalar type of the Shape + /// \tparam N Maximum number of dimensions of the Shape + /// \param shape Shape of the Array + /// \return Array filled with zeros + template + Array zeros(const Shape &shape) { + return Array(shape, Scalar(0)); + } + + /// \brief Create an Array filled with zeros, with the same type and shape as the input array + /// + /// \tparam T Input array type + /// \param other Input array + /// \return New array + template + auto zerosLike(const T &other) { + using Scalar = typename typetraits::TypeInfo::Scalar; + using Backend = typename typetraits::TypeInfo::Backend; + return zeros(other.shape()); + } + + /// \brief Create an Array filled with ones + /// + /// Create an array with a specified shape, scalar type and Backend, and fill it with ones. 
+ /// + /// \tparam Scalar Scalar type of the Array + /// \tparam Backend Backend type of the Array + /// \tparam T Scalar type of the Shape + /// \tparam N Maximum number of dimensions of the Shape + /// \param shape Shape of the Array + /// \return Array filled with ones + template + Array ones(const Shape &shape) { + return Array(shape, Scalar(1)); + } + + /// \brief Create an Array filled with ones, with the same type and shape as the input array + /// + /// \tparam T Input array type + /// \param other Input array + /// \return New array + template + auto onesLike(const T &other) { + using Scalar = typename typetraits::TypeInfo::Scalar; + using Backend = typename typetraits::TypeInfo::Backend; + return ones(other.shape()); + } + + /// \brief Create an Array filled, in order, with the numbers 0 to N-1 + /// + /// Create a new Array object with a given shape, where each value is filled with a number from + /// 0 to N-1, where N is the total number of elements in the array. The values are filled in + /// the same order as the array is stored in memory. + /// + /// \tparam Scalar Scalar type of the Array + /// \tparam Backend Backend type of the Array + /// \tparam T Scalar type of the Shape + /// \tparam N Maximum number of dimensions of the Shape + /// \param shape Shape of the Array + /// \return Array filled with numbers from 0 to N-1 + template + Array ordered(const Shape &shape) { + Array result(shape); + for (size_t i = 0; i < shape.size(); i++) { result.storage()[i] = Scalar(i); } + return result; + } + + /// \brief Create a 1-dimensional Array from a range of numbers and a step size + /// + /// Provided with a start value and a stop value, create a 1-dimensional Array with + /// \f$\lfloor \frac{stop - start}{step} \rfloor \f$ elements, where each element is + /// \f$start + i \times step\f$, for \f$i \in [0, \lfloor \frac{stop - start}{step} \rfloor)\f$. 
+ /// + /// \tparam Scalar Scalar type of the Array + /// \tparam Backend Backend for the Array + /// \tparam Start Scalar type of the start value + /// \tparam Stop Scalar type of the stop value + /// \tparam Step Scalar type of the step size + /// \param start First value in the range + /// \param stop Second value in the range + /// \param step Step size between values in the range + /// \return Array + template + Array arange(Start start, Stop stop, Step step) { + LIBRAPID_ASSERT(step != 0, "Step size cannot be zero"); + LIBRAPID_ASSERT((stop - start) / step > 0, "Step size is invalid for the specified range"); + + Shape shape = {(int64_t)::librapid::abs((stop - start) / step)}; + Array result(shape); + for (size_t i = 0; i < shape.size(); i++) { + result.storage()[i] = Scalar(start + i * step); + } + return result; + } + + template + Array arange(T start, T stop) { + LIBRAPID_ASSERT((stop - start) > 0, "Step size is invalid for the specified range"); + + Shape shape = {(int64_t)::librapid::abs(stop - start)}; + Array result(shape); + for (size_t i = 0; i < shape.size(); i++) { result.storage()[i] = Scalar(start + i); } + return result; + } + + template + Array arange(T stop) { + Shape shape = {(int64_t)::librapid::abs(stop)}; + Array result(shape); + for (size_t i = 0; i < shape.size(); i++) { result.storage()[i] = Scalar(i); } + return result; + } + + /// \brief Create a 1-dimensional Array with a specified number of elements, evenly spaced + /// between two values + /// + /// Create a 1-dimensional Array with a specified number of elements, evenly spaced between + /// two values. If \p includeEnd is true, the last element of the Array will be equal to + /// \p stop, otherwise it will be equal to \p stop - \f$\frac{stop - start}{num}\f$. 
+ /// + /// \tparam Scalar Scalar type of the Array + /// \tparam Backend Backend for the Array + /// \tparam Start Scalar type of the start value + /// \tparam Stop Scalar type of the stop value + /// \param start First value in the range + /// \param stop Second value in the range + /// \param num Number of elements in the Array + /// \param includeEnd Whether or not to include the end value in the Array + /// \return Linearly spaced Array + template + Array linspace(Start start, Stop stop, int64_t num, bool includeEnd = true) { + LIBRAPID_ASSERT(num > 0, "Number of samples must be greater than zero"); + + auto startCast = static_cast(start); + auto stopCast = static_cast(stop); + auto den = static_cast(num - includeEnd); + Shape shape = {num}; + Array result(shape); + for (size_t i = 0; i < shape.size(); i++) { + result.storage()[i] = startCast + (stopCast - startCast) * static_cast(i) / den; + } + return result; + } + + template + Array logspace(Start start, Stop stop, int64_t num, bool includeEnd = true) { + LIBRAPID_ASSERT(num > 0, "Number of samples must be greater than zero"); + + auto logLower = ::librapid::log(static_cast(start)); + auto logUpper = ::librapid::log(static_cast(stop)); + + Shape shape = {num}; + Array result(shape); + + for (size_t i = 0; i < shape.size(); i++) { + result.storage()[i] = + ::librapid::exp(logLower + (logUpper - logLower) * static_cast(i) / + static_cast(num - includeEnd)); + } + + return result; + } + + template + Array random(const ShapeType &shape, Lower lower = 0, Upper upper = 1) { + Array result(shape); + fillRandom(result, lower, upper); + return result; + } } // namespace librapid #endif // LIBRAPID_ARRAY_PSEUDO_CONSTRUCTORS_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/sizetype.hpp b/librapid/include/librapid/array/sizetype.hpp index c164dea0..5f7f1ec7 100644 --- a/librapid/include/librapid/array/sizetype.hpp +++ b/librapid/include/librapid/array/sizetype.hpp @@ -7,345 +7,345 @@ */ 
namespace librapid { - namespace typetraits { - LIBRAPID_DEFINE_AS_TYPE(typename T COMMA size_t N, Shape); - } - - template - class Shape { - public: - using SizeType = T; - static constexpr size_t MaxDimensions = N; - - /// Default constructor - Shape() = default; - - /// Create a shape object from the dimensions of a FixedStorage object. This is used - // mainly internally, but may serve some purpose I haven't yet thought of. - /// \tparam Scalar Scalar type of the FixedStorage object - /// \tparam Dimensions Dimensions of the FixedStorage object - /// \param fixed The FixedStorage object - template - explicit Shape(const FixedStorage &fixed); - - /// Create a Shape object from a list of values - /// \tparam V Scalar type of the values - /// \param vals The dimensions for the object - template::value> = 0> - Shape(const std::initializer_list &vals); - - /// Create a Shape object from a vector of values - /// \tparam V Scalar type of the values - /// \param vals The dimensions for the object - template::value> = 0> - explicit Shape(const std::vector &vals); - - /// Create a copy of a Shape object - /// \param other Shape object to copy - Shape(const Shape &other) = default; - - /// Create a Shape from an RValue - /// \param other Temporary Shape object to copy - Shape(Shape &&other) noexcept = default; - - /// Create a Shape object from one with a different type and number of dimensions. - /// \tparam V Scalar type of the values - /// \tparam Dim Number of dimensions - /// \param other Shape object to copy - template - Shape(const Shape &other); - - /// Create a Shape object from one with a different type and number of dimensions, moving it - /// instead of copying it. 
- /// \tparam V Scalar type of the values - /// \tparam Dim Number of dimensions - /// \param other Temporary Shape object to move - template - Shape(Shape &&other) noexcept; - - /// Assign a Shape object to this object - /// \tparam V Scalar type of the Shape - /// \param vals Dimensions of the Shape - /// \return *this - template::value> = 0> - Shape &operator=(const std::initializer_list &vals); - - /// Assign a Shape object to this object - /// \tparam V Scalar type of the Shape - /// \param vals Dimensions of the Shape - /// \return *this - template::value> = 0> - Shape &operator=(const std::vector &vals); - - /// Assign an RValue Shape to this object - /// \param other RValue to move - /// \return - Shape &operator=(Shape &&other) noexcept = default; - - /// Assign a Shape to this object - /// \param other Shape to copy - /// \return - Shape &operator=(const Shape &other) = default; - - /// Return a Shape object with \p dims dimensions, all initialized to zero. - /// \param dims Number of dimensions - /// \return New Shape object - static Shape zeros(size_t dims); - - /// Return a Shape object with \p dims dimensions, all initialized to one. 
- /// \param dims Number of dimensions - /// \return New Shape object - static Shape ones(size_t dims); - - /// Access an element of the Shape object - /// \tparam Index Typename of the index - /// \param index Index to access - /// \return The value at the index - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const T &operator[](Index index) const; - - /// Access an element of the Shape object - /// \tparam Index Typename of the index - /// \param index Index to access - /// \return A reference to the value at the index - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T &operator[](Index index); - - /// Compare two Shape objects, returning true if and only if they are identical - /// \param other Shape object to compare - /// \return true if the objects are identical - LIBRAPID_ALWAYS_INLINE bool operator==(const Shape &other) const; - - /// Compare two Shape objects, returning true if and only if they are not identical - /// \param other Shape object to compare - /// \return true if the objects are not identical - LIBRAPID_ALWAYS_INLINE bool operator!=(const Shape &other) const; - - /// Return the number of dimensions in the Shape object - /// \return Number of dimensions - LIBRAPID_NODISCARD T ndim() const; - - /// Return a subshape of the Shape object - /// \param start Starting index - /// \param end Ending index - /// \return Subshape - LIBRAPID_NODISCARD Shape subshape(size_t start, size_t end) const; - - /// Return the number of elements the Shape object represents - /// \return Number of elements - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T size() const; - - /// Convert a Shape object into a string representation - /// \return A string representation of the Shape object - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; - - protected: - T m_dims; - std::array m_data; - }; - - namespace detail { - template - Shape shapeFromFixedStorage(const FixedStorage &) { - return Shape({Dims...}); - } - } // namespace detail - - 
template - template - Shape::Shape(const FixedStorage &) : m_data({Dimensions...}) {} - - template - template::value>> - Shape::Shape(const std::initializer_list &vals) : m_dims(vals.size()) { - LIBRAPID_ASSERT(vals.size() <= N, - "Shape object is limited to {} dimensions. Cannot initialize with {}", - N, - vals.size()); - for (size_t i = 0; i < vals.size(); ++i) { m_data[i] = *(vals.begin() + i); } - } - - template - template::value>> - Shape::Shape(const std::vector &vals) : m_dims(vals.size()) { - LIBRAPID_ASSERT(vals.size() <= N, - "Shape object is limited to {} dimensions. Cannot initialize with {}", - N, - vals.size()); - for (size_t i = 0; i < vals.size(); ++i) { m_data[i] = vals[i]; } - } - - template - template - Shape::Shape(const Shape &other) : m_dims(other.ndim()) { - LIBRAPID_ASSERT(other.ndim() <= N, - "Shape object is limited to {} dimensions. Cannot initialize with {}", - N, - other.ndim()); - for (size_t i = 0; i < m_dims; ++i) { m_data[i] = other[i]; } - } - - template - template - Shape::Shape(Shape &&other) noexcept : m_dims(other.ndim()) { - LIBRAPID_ASSERT(other.ndim() <= N, - "Shape object is limited to {} dimensions. Cannot initialize with {}", - N, - other.ndim()); - for (size_t i = 0; i < m_dims; ++i) { m_data[i] = other[i]; } - } - - template - template::value>> - Shape &Shape::operator=(const std::initializer_list &vals) { - LIBRAPID_ASSERT(vals.size() <= N, - "Shape object is limited to {} dimensions. Cannot initialize with {}", - N, - vals.size()); - m_dims = vals.size(); - for (int64_t i = 0; i < vals.size(); ++i) { m_data[i] = *(vals.begin() + i); } - return *this; - } - - template - template::value>> - Shape &Shape::operator=(const std::vector &vals) { - LIBRAPID_ASSERT(vals.size() <= N, - "Shape object is limited to {} dimensions. 
Cannot initialize with {}", - N, - vals.size()); - m_dims = vals.size(); - for (int64_t i = 0; i < vals.size(); ++i) { m_data[i] = vals[i]; } - return *this; - } - - template - Shape Shape::zeros(size_t dims) { - Shape res; - res.m_dims = dims; - for (size_t i = 0; i < dims; ++i) res.m_data[i] = 0; - return res; - } - - template - Shape Shape::ones(size_t dims) { - Shape res; - res.m_dims = dims; - for (size_t i = 0; i < dims; ++i) res.m_data[i] = 1; - return res; - } - - template - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const T &Shape::operator[](Index index) const { - LIBRAPID_ASSERT(static_cast(index) < m_dims, "Index out of bounds"); - LIBRAPID_ASSERT(index >= 0, "Index out of bounds"); - return m_data[index]; - } - - template - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T &Shape::operator[](Index index) { - LIBRAPID_ASSERT(static_cast(index) < m_dims, "Index out of bounds"); - LIBRAPID_ASSERT(index >= 0, "Index out of bounds"); - return m_data[index]; - } - - template - LIBRAPID_ALWAYS_INLINE bool Shape::operator==(const Shape &other) const { - if (m_dims != other.m_dims) return false; - for (size_t i = 0; i < m_dims; ++i) { - if (m_data[i] != other.m_data[i]) return false; - } - return true; - } - - template - LIBRAPID_ALWAYS_INLINE bool Shape::operator!=(const Shape &other) const { - return !(*this == other); - } - - template - LIBRAPID_NODISCARD T Shape::ndim() const { - return m_dims; - } - - template - LIBRAPID_NODISCARD auto Shape::subshape(size_t start, size_t end) const -> Shape { - LIBRAPID_ASSERT(start <= end, "Start index must be less than end index"); - LIBRAPID_ASSERT(end <= m_dims, - "End index must be less than or equal to the number of dimensions"); - LIBRAPID_ASSERT(start >= 0, "Start index must be greater than or equal to 0"); - - Shape res; - res.m_dims = end - start; - for (size_t i = 0; i < res.m_dims; ++i) res.m_data[i] = m_data[i + start]; - return res; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE 
T Shape::size() const { - T res = 1; - for (size_t i = 0; i < m_dims; ++i) res *= m_data[i]; - return res; - } - - template - std::string Shape::str(const std::string &format) const { - std::string result("("); - for (size_t i = 0; i < m_dims; ++i) { - result += fmt::format(format, m_data[i]); - if (i < m_dims - 1) result += std::string(", "); - } - return std::operator+(result, std::string(")")); - } - - /// Returns true if all inputs have the same shape - /// \tparam T1 Type of the first input - /// \tparam N1 Number of dimensions of the first input - /// \tparam T2 Type of the second input - /// \tparam N2 Number of dimensions of the second input - /// \tparam Tn Type of the remaining (optional) inputs - /// \tparam Nn Number of dimensions of the remaining (optional) inputs - /// \param first First input - /// \param second Second input - /// \param shapes Remaining (optional) inputs - /// \return True if all inputs have the same shape, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_INLINE bool shapesMatch(const Shape &first, - const Shape &second, - const Shape &...shapes) { - if constexpr (sizeof...(Tn) == 0) { - return first == second; - } else { - return first == second && shapesMatch(first, shapes...); - } - } - - /// \sa shapesMatch - template - LIBRAPID_NODISCARD LIBRAPID_INLINE bool - shapesMatch(const std::tuple, Shape, Shape...> &shapes) { - if constexpr (sizeof...(Tn) == 0) { - return std::get<0>(shapes) == std::get<1>(shapes); - } else { - return std::get<0>(shapes) == std::get<1>(shapes) && - shapesMatch(std::apply( - [](auto, auto, auto... 
rest) { return std::make_tuple(rest...); }, shapes)); - } - } - - namespace typetraits { - template - struct IsSizeType { - using value = std::false_type; - }; - - template - struct IsSizeType> { - using value = std::true_type; - }; - } // namespace typetraits + namespace typetraits { + LIBRAPID_DEFINE_AS_TYPE(typename T COMMA size_t N, Shape); + } + + template + class Shape { + public: + using SizeType = T; + static constexpr size_t MaxDimensions = N; + + /// Default constructor + Shape() = default; + + /// Create a shape object from the dimensions of a FixedStorage object. This is used + // mainly internally, but may serve some purpose I haven't yet thought of. + /// \tparam Scalar Scalar type of the FixedStorage object + /// \tparam Dimensions Dimensions of the FixedStorage object + /// \param fixed The FixedStorage object + template + explicit Shape(const FixedStorage &fixed); + + /// Create a Shape object from a list of values + /// \tparam V Scalar type of the values + /// \param vals The dimensions for the object + template::value> = 0> + Shape(const std::initializer_list &vals); + + /// Create a Shape object from a vector of values + /// \tparam V Scalar type of the values + /// \param vals The dimensions for the object + template::value> = 0> + explicit Shape(const std::vector &vals); + + /// Create a copy of a Shape object + /// \param other Shape object to copy + Shape(const Shape &other) = default; + + /// Create a Shape from an RValue + /// \param other Temporary Shape object to copy + Shape(Shape &&other) noexcept = default; + + /// Create a Shape object from one with a different type and number of dimensions. + /// \tparam V Scalar type of the values + /// \tparam Dim Number of dimensions + /// \param other Shape object to copy + template + Shape(const Shape &other); + + /// Create a Shape object from one with a different type and number of dimensions, moving it + /// instead of copying it. 
+ /// \tparam V Scalar type of the values + /// \tparam Dim Number of dimensions + /// \param other Temporary Shape object to move + template + Shape(Shape &&other) noexcept; + + /// Assign a Shape object to this object + /// \tparam V Scalar type of the Shape + /// \param vals Dimensions of the Shape + /// \return *this + template::value> = 0> + Shape &operator=(const std::initializer_list &vals); + + /// Assign a Shape object to this object + /// \tparam V Scalar type of the Shape + /// \param vals Dimensions of the Shape + /// \return *this + template::value> = 0> + Shape &operator=(const std::vector &vals); + + /// Assign an RValue Shape to this object + /// \param other RValue to move + /// \return + Shape &operator=(Shape &&other) noexcept = default; + + /// Assign a Shape to this object + /// \param other Shape to copy + /// \return + Shape &operator=(const Shape &other) = default; + + /// Return a Shape object with \p dims dimensions, all initialized to zero. + /// \param dims Number of dimensions + /// \return New Shape object + static Shape zeros(size_t dims); + + /// Return a Shape object with \p dims dimensions, all initialized to one. 
+ /// \param dims Number of dimensions + /// \return New Shape object + static Shape ones(size_t dims); + + /// Access an element of the Shape object + /// \tparam Index Typename of the index + /// \param index Index to access + /// \return The value at the index + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const T &operator[](Index index) const; + + /// Access an element of the Shape object + /// \tparam Index Typename of the index + /// \param index Index to access + /// \return A reference to the value at the index + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T &operator[](Index index); + + /// Compare two Shape objects, returning true if and only if they are identical + /// \param other Shape object to compare + /// \return true if the objects are identical + LIBRAPID_ALWAYS_INLINE bool operator==(const Shape &other) const; + + /// Compare two Shape objects, returning true if and only if they are not identical + /// \param other Shape object to compare + /// \return true if the objects are not identical + LIBRAPID_ALWAYS_INLINE bool operator!=(const Shape &other) const; + + /// Return the number of dimensions in the Shape object + /// \return Number of dimensions + LIBRAPID_NODISCARD T ndim() const; + + /// Return a subshape of the Shape object + /// \param start Starting index + /// \param end Ending index + /// \return Subshape + LIBRAPID_NODISCARD Shape subshape(size_t start, size_t end) const; + + /// Return the number of elements the Shape object represents + /// \return Number of elements + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T size() const; + + /// Convert a Shape object into a string representation + /// \return A string representation of the Shape object + LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + + protected: + T m_dims; + std::array m_data; + }; + + namespace detail { + template + Shape shapeFromFixedStorage(const FixedStorage &) { + return Shape({Dims...}); + } + } // namespace detail + + 
template + template + Shape::Shape(const FixedStorage &) : m_data({Dimensions...}) {} + + template + template::value>> + Shape::Shape(const std::initializer_list &vals) : m_dims(vals.size()) { + LIBRAPID_ASSERT(vals.size() <= N, + "Shape object is limited to {} dimensions. Cannot initialize with {}", + N, + vals.size()); + for (size_t i = 0; i < vals.size(); ++i) { m_data[i] = *(vals.begin() + i); } + } + + template + template::value>> + Shape::Shape(const std::vector &vals) : m_dims(vals.size()) { + LIBRAPID_ASSERT(vals.size() <= N, + "Shape object is limited to {} dimensions. Cannot initialize with {}", + N, + vals.size()); + for (size_t i = 0; i < vals.size(); ++i) { m_data[i] = vals[i]; } + } + + template + template + Shape::Shape(const Shape &other) : m_dims(other.ndim()) { + LIBRAPID_ASSERT(other.ndim() <= N, + "Shape object is limited to {} dimensions. Cannot initialize with {}", + N, + other.ndim()); + for (size_t i = 0; i < m_dims; ++i) { m_data[i] = other[i]; } + } + + template + template + Shape::Shape(Shape &&other) noexcept : m_dims(other.ndim()) { + LIBRAPID_ASSERT(other.ndim() <= N, + "Shape object is limited to {} dimensions. Cannot initialize with {}", + N, + other.ndim()); + for (size_t i = 0; i < m_dims; ++i) { m_data[i] = other[i]; } + } + + template + template::value>> + Shape &Shape::operator=(const std::initializer_list &vals) { + LIBRAPID_ASSERT(vals.size() <= N, + "Shape object is limited to {} dimensions. Cannot initialize with {}", + N, + vals.size()); + m_dims = vals.size(); + for (int64_t i = 0; i < vals.size(); ++i) { m_data[i] = *(vals.begin() + i); } + return *this; + } + + template + template::value>> + Shape &Shape::operator=(const std::vector &vals) { + LIBRAPID_ASSERT(vals.size() <= N, + "Shape object is limited to {} dimensions. 
Cannot initialize with {}", + N, + vals.size()); + m_dims = vals.size(); + for (int64_t i = 0; i < vals.size(); ++i) { m_data[i] = vals[i]; } + return *this; + } + + template + Shape Shape::zeros(size_t dims) { + Shape res; + res.m_dims = dims; + for (size_t i = 0; i < dims; ++i) res.m_data[i] = 0; + return res; + } + + template + Shape Shape::ones(size_t dims) { + Shape res; + res.m_dims = dims; + for (size_t i = 0; i < dims; ++i) res.m_data[i] = 1; + return res; + } + + template + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const T &Shape::operator[](Index index) const { + LIBRAPID_ASSERT(static_cast(index) < m_dims, "Index out of bounds"); + LIBRAPID_ASSERT(index >= 0, "Index out of bounds"); + return m_data[index]; + } + + template + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T &Shape::operator[](Index index) { + LIBRAPID_ASSERT(static_cast(index) < m_dims, "Index out of bounds"); + LIBRAPID_ASSERT(index >= 0, "Index out of bounds"); + return m_data[index]; + } + + template + LIBRAPID_ALWAYS_INLINE bool Shape::operator==(const Shape &other) const { + if (m_dims != other.m_dims) return false; + for (size_t i = 0; i < m_dims; ++i) { + if (m_data[i] != other.m_data[i]) return false; + } + return true; + } + + template + LIBRAPID_ALWAYS_INLINE bool Shape::operator!=(const Shape &other) const { + return !(*this == other); + } + + template + LIBRAPID_NODISCARD T Shape::ndim() const { + return m_dims; + } + + template + LIBRAPID_NODISCARD auto Shape::subshape(size_t start, size_t end) const -> Shape { + LIBRAPID_ASSERT(start <= end, "Start index must be less than end index"); + LIBRAPID_ASSERT(end <= m_dims, + "End index must be less than or equal to the number of dimensions"); + LIBRAPID_ASSERT(start >= 0, "Start index must be greater than or equal to 0"); + + Shape res; + res.m_dims = end - start; + for (size_t i = 0; i < res.m_dims; ++i) res.m_data[i] = m_data[i + start]; + return res; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE 
T Shape::size() const { + T res = 1; + for (size_t i = 0; i < m_dims; ++i) res *= m_data[i]; + return res; + } + + template + std::string Shape::str(const std::string &format) const { + std::string result("("); + for (size_t i = 0; i < m_dims; ++i) { + result += fmt::format(format, m_data[i]); + if (i < m_dims - 1) result += std::string(", "); + } + return std::operator+(result, std::string(")")); + } + + /// Returns true if all inputs have the same shape + /// \tparam T1 Type of the first input + /// \tparam N1 Number of dimensions of the first input + /// \tparam T2 Type of the second input + /// \tparam N2 Number of dimensions of the second input + /// \tparam Tn Type of the remaining (optional) inputs + /// \tparam Nn Number of dimensions of the remaining (optional) inputs + /// \param first First input + /// \param second Second input + /// \param shapes Remaining (optional) inputs + /// \return True if all inputs have the same shape, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_INLINE bool shapesMatch(const Shape &first, + const Shape &second, + const Shape &...shapes) { + if constexpr (sizeof...(Tn) == 0) { + return first == second; + } else { + return first == second && shapesMatch(first, shapes...); + } + } + + /// \sa shapesMatch + template + LIBRAPID_NODISCARD LIBRAPID_INLINE bool + shapesMatch(const std::tuple, Shape, Shape...> &shapes) { + if constexpr (sizeof...(Tn) == 0) { + return std::get<0>(shapes) == std::get<1>(shapes); + } else { + return std::get<0>(shapes) == std::get<1>(shapes) && + shapesMatch(std::apply( + [](auto, auto, auto... 
rest) { return std::make_tuple(rest...); }, shapes)); + } + } + + namespace typetraits { + template + struct IsSizeType { + using value = std::false_type; + }; + + template + struct IsSizeType> { + using value = std::true_type; + }; + } // namespace typetraits } // namespace librapid // Support FMT printing diff --git a/librapid/include/librapid/array/storage.hpp b/librapid/include/librapid/array/storage.hpp index 47253b23..64a360f8 100644 --- a/librapid/include/librapid/array/storage.hpp +++ b/librapid/include/librapid/array/storage.hpp @@ -7,875 +7,875 @@ */ namespace librapid { - namespace typetraits { - template - struct TypeInfo> { - static constexpr bool isLibRapidType = true; - using Scalar = Scalar_; - using Backend = backend::CPU; - }; - - template - struct TypeInfo> { - static constexpr bool isLibRapidType = true; - using Scalar = Scalar_; - using Backend = backend::CPU; - }; - - LIBRAPID_DEFINE_AS_TYPE(typename Scalar, Storage); - } // namespace typetraits - - template - class Storage { - public: - using Scalar = Scalar_; - using Packet = typename typetraits::TypeInfo::Packet; - static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; - using RawPointer = Scalar *; - using ConstRawPointer = const Scalar *; - using Pointer = std::shared_ptr; - using ConstPointer = std::shared_ptr; - using Reference = Scalar &; - using ConstReference = const Scalar &; - using SizeType = size_t; - using DifferenceType = ptrdiff_t; - using Iterator = RawPointer; - using ConstIterator = ConstRawPointer; - using ReverseIterator = std::reverse_iterator; - using ConstReverseIterator = std::reverse_iterator; - - /// Default constructor - Storage() = default; - - /// Create a Storage object with \p size elements - /// \param size Number of elements to allocate - LIBRAPID_ALWAYS_INLINE explicit Storage(SizeType size); - - LIBRAPID_ALWAYS_INLINE explicit Storage(Scalar *begin, Scalar *end, bool ownsData); - - /// Create a Storage object with \p size elements, each 
initialized - /// to \p value. - /// \param size Number of elements to allocate - /// \param value Value to initialize each element to - LIBRAPID_ALWAYS_INLINE Storage(SizeType size, ConstReference value); - - /// Create a Storage object from another Storage object. Additionally - /// a custom allocator can be used. The data is **NOT** copied -- it is referenced - /// by the new Storage object. For a deep copy, use the ``copy()`` method. - /// \param other Storage object to copy - LIBRAPID_ALWAYS_INLINE Storage(const Storage &other); - - /// Move a Storage object into this object. - /// \param other Storage object to move - LIBRAPID_ALWAYS_INLINE Storage(Storage &&other) noexcept; - - /// Create a Storage object from an std::initializer_list - /// \tparam V Type of the elements in the initializer list - /// \param list Initializer list to copy - /// \param alloc Allocator to use - template - LIBRAPID_ALWAYS_INLINE Storage(const std::initializer_list &list); - - /// Create a Storage object from a std::vector - /// \tparam V Type of the elements in the vector - /// \param vec Vector to copy - template - LIBRAPID_ALWAYS_INLINE explicit Storage(const std::vector &vec); - - template - static Storage fromData(const std::initializer_list &vec); - - template - static Storage fromData(const std::vector &vec); - - /// Assignment operator for a Storage object - /// \param other Storage object to copy - /// \return *this - LIBRAPID_ALWAYS_INLINE Storage &operator=(const Storage &other); - - /// Move assignment operator for a Storage object - /// \param other Storage object to move - /// \return *this - LIBRAPID_ALWAYS_INLINE Storage &operator=(Storage &&other) LIBRAPID_RELEASE_NOEXCEPT; - - /// Free a Storage object - ~Storage(); - - /// \brief Set this storage object to reference the same data as \p other - /// \param other Storage object to reference - void set(const Storage &other); - - /// \brief Return a Storage object on the host with the same data as this Storage object 
- /// (mainly for use with CUDA or OpenCL) - /// \return - Storage toHostStorage() const; - - /// \brief Same as `toHostStorage()` but does not necessarily copy the data - /// \return Storage object on the host - Storage toHostStorageUnsafe() const; - - /// \brief Create a deep copy of this Storage object - /// \return Deep copy of this Storage object - Storage copy() const; - - template - static ShapeType defaultShape(); - - /// Resize a Storage object to \p size elements. Existing elements - /// are preserved. - /// \param size New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); - - /// Resize a Storage object to \p size elements. Existing elements - /// are not preserved - /// \param size New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, int); - - /// Return the number of elements in the Storage object - /// \return - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const noexcept; - - /// Const access to the element at index \p index - /// \param index Index of the element to access - /// \return Const reference to the element at index \p index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReference operator[](SizeType index) const; - - /// Access to the element at index \p index - /// \param index Index of the element to access - /// \return Reference to the element at index \p index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Reference operator[](SizeType index); - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer data() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE RawPointer begin() noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE RawPointer end() noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator begin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator end() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cbegin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE 
ConstIterator cend() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rbegin() noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rend() noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rbegin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rend() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crbegin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crend() const noexcept; - - private: - /// Copy data from \p begin to \p end into this Storage object - /// \tparam P Pointer type - /// \param begin Beginning of data to copy - /// \param end End of data to copy - template - LIBRAPID_ALWAYS_INLINE void initData(P begin, P end); - - template - LIBRAPID_ALWAYS_INLINE void initData(P begin, SizeType size); - - /// Resize the Storage Object to \p newSize elements, retaining existing - /// data. - /// \param newSize New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize, int); - - /// Resize the Storage object to \p newSize elements. Note this does not - /// initialize the new elements or maintain existing data. 
- /// \param newSize New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); - - // #if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) - // alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr; - // #else - // Pointer m_begin = nullptr; // Pointer to the beginning of the data - // #endif - - Pointer m_begin = nullptr; - - SizeType m_size = 0; // Number of elements in the Storage object - bool m_ownsData = true; // Whether this Storage object owns the data it points to - }; - - template - class FixedStorage { - public: - using Scalar = Scalar_; - using Pointer = Scalar *; - using ConstPointer = const Scalar *; - using Reference = Scalar &; - using ConstReference = const Scalar &; - using SizeType = size_t; - using DifferenceType = ptrdiff_t; - static constexpr SizeType Size = product(); - using Iterator = typename std::array()>::iterator; - using ConstIterator = typename std::array()>::const_iterator; - using ReverseIterator = std::reverse_iterator; - using ConstReverseIterator = std::reverse_iterator; - - /// Default constructor - FixedStorage(); - - /// Create a FixedStorage object filled with \p value - /// \param value Value to fill the FixedStorage object with - LIBRAPID_ALWAYS_INLINE explicit FixedStorage(const Scalar &value); - - /// Create a FixedStorage object from another FixedStorage object - /// \param other FixedStorage object to copy - LIBRAPID_ALWAYS_INLINE FixedStorage(const FixedStorage &other); - - /// Move constructor for a FixedStorage object - /// \param other FixedStorage object to move - LIBRAPID_ALWAYS_INLINE FixedStorage(FixedStorage &&other) noexcept; - - /// Create a FixedStorage object from a std::initializer_list - /// \tparam V Type of the elements in the initializer list - /// \param list Initializer list to copy - LIBRAPID_ALWAYS_INLINE explicit FixedStorage(const std::initializer_list &list); - - /// Create a FixedStorage object from a std::vector - /// \tparam V Type of the 
elements in the vector - /// \param vec Vector to copy - LIBRAPID_ALWAYS_INLINE explicit FixedStorage(const std::vector &vec); - - /// Assignment operator for a FixedStorage object - /// \param other FixedStorage object to copy - /// \return *this - LIBRAPID_ALWAYS_INLINE FixedStorage &operator=(const FixedStorage &other); - - /// Move assignment operator for a FixedStorage object - /// \param other FixedStorage object to move - /// \return *this - LIBRAPID_ALWAYS_INLINE FixedStorage &operator=(FixedStorage &&other) noexcept; - - /// Free a FixedStorage object - ~FixedStorage() = default; - - template - static ShapeType defaultShape(); - - /// Resize a Storage object to \p size elements. Existing elements - /// are preserved. - /// \param size New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); - - /// Resize a Storage object to \p size elements. Existing elements - /// are not preserved - /// \param size New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, int); - - /// Return the number of elements in the FixedStorage object - /// \return Number of elements in the FixedStorage object - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const noexcept; - - /// \brief Create a copy of the FixedStorage object - /// \return Copy of the FixedStorage object - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE FixedStorage copy() const; - - /// Const access to the element at index \p index - /// \param index Index of the element to access - /// \return Const reference to the element at index \p index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReference operator[](SizeType index) const; - - /// Access to the element at index \p index - /// \param index Index of the element to access - /// \return Reference to the element at index \p index - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Reference operator[](SizeType index); - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer data() const noexcept; - - 
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator begin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator end() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cbegin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cend() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rbegin() noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rend() noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rbegin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rend() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crbegin() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crend() const noexcept; - - private: + namespace typetraits { + template + struct TypeInfo> { + static constexpr bool isLibRapidType = true; + using Scalar = Scalar_; + using Backend = backend::CPU; + }; + + template + struct TypeInfo> { + static constexpr bool isLibRapidType = true; + using Scalar = Scalar_; + using Backend = backend::CPU; + }; + + LIBRAPID_DEFINE_AS_TYPE(typename Scalar, Storage); + } // namespace typetraits + + template + class Storage { + public: + using Scalar = Scalar_; + using Packet = typename typetraits::TypeInfo::Packet; + static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; + using RawPointer = Scalar *; + using ConstRawPointer = const Scalar *; + using Pointer = std::shared_ptr; + using ConstPointer = std::shared_ptr; + using Reference = Scalar &; + using ConstReference = const Scalar &; + using SizeType = size_t; + using DifferenceType = ptrdiff_t; + using Iterator = RawPointer; + using ConstIterator = ConstRawPointer; + using ReverseIterator = std::reverse_iterator; + 
using ConstReverseIterator = std::reverse_iterator; + + /// Default constructor + Storage() = default; + + /// Create a Storage object with \p size elements + /// \param size Number of elements to allocate + LIBRAPID_ALWAYS_INLINE explicit Storage(SizeType size); + + LIBRAPID_ALWAYS_INLINE explicit Storage(Scalar *begin, Scalar *end, bool ownsData); + + /// Create a Storage object with \p size elements, each initialized + /// to \p value. + /// \param size Number of elements to allocate + /// \param value Value to initialize each element to + LIBRAPID_ALWAYS_INLINE Storage(SizeType size, ConstReference value); + + /// Create a Storage object from another Storage object. Additionally + /// a custom allocator can be used. The data is **NOT** copied -- it is referenced + /// by the new Storage object. For a deep copy, use the ``copy()`` method. + /// \param other Storage object to copy + LIBRAPID_ALWAYS_INLINE Storage(const Storage &other); + + /// Move a Storage object into this object. + /// \param other Storage object to move + LIBRAPID_ALWAYS_INLINE Storage(Storage &&other) noexcept; + + /// Create a Storage object from an std::initializer_list + /// \tparam V Type of the elements in the initializer list + /// \param list Initializer list to copy + /// \param alloc Allocator to use + template + LIBRAPID_ALWAYS_INLINE Storage(const std::initializer_list &list); + + /// Create a Storage object from a std::vector + /// \tparam V Type of the elements in the vector + /// \param vec Vector to copy + template + LIBRAPID_ALWAYS_INLINE explicit Storage(const std::vector &vec); + + template + static Storage fromData(const std::initializer_list &vec); + + template + static Storage fromData(const std::vector &vec); + + /// Assignment operator for a Storage object + /// \param other Storage object to copy + /// \return *this + LIBRAPID_ALWAYS_INLINE Storage &operator=(const Storage &other); + + /// Move assignment operator for a Storage object + /// \param other Storage object 
to move + /// \return *this + LIBRAPID_ALWAYS_INLINE Storage &operator=(Storage &&other) LIBRAPID_RELEASE_NOEXCEPT; + + /// Free a Storage object + ~Storage(); + + /// \brief Set this storage object to reference the same data as \p other + /// \param other Storage object to reference + void set(const Storage &other); + + /// \brief Return a Storage object on the host with the same data as this Storage object + /// (mainly for use with CUDA or OpenCL) + /// \return + Storage toHostStorage() const; + + /// \brief Same as `toHostStorage()` but does not necessarily copy the data + /// \return Storage object on the host + Storage toHostStorageUnsafe() const; + + /// \brief Create a deep copy of this Storage object + /// \return Deep copy of this Storage object + Storage copy() const; + + template + static ShapeType defaultShape(); + + /// Resize a Storage object to \p size elements. Existing elements + /// are preserved. + /// \param size New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); + + /// Resize a Storage object to \p size elements. 
Existing elements + /// are not preserved + /// \param size New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, int); + + /// Return the number of elements in the Storage object + /// \return + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const noexcept; + + /// Const access to the element at index \p index + /// \param index Index of the element to access + /// \return Const reference to the element at index \p index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReference operator[](SizeType index) const; + + /// Access to the element at index \p index + /// \param index Index of the element to access + /// \return Reference to the element at index \p index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Reference operator[](SizeType index); + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer data() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE RawPointer begin() noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE RawPointer end() noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator begin() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator end() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cbegin() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cend() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rbegin() noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rend() noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rbegin() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rend() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crbegin() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crend() const noexcept; + + private: + /// Copy data from \p begin to \p end into this Storage object + /// \tparam P Pointer type + /// \param begin 
Beginning of data to copy + /// \param end End of data to copy + template + LIBRAPID_ALWAYS_INLINE void initData(P begin, P end); + + template + LIBRAPID_ALWAYS_INLINE void initData(P begin, SizeType size); + + /// Resize the Storage Object to \p newSize elements, retaining existing + /// data. + /// \param newSize New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize, int); + + /// Resize the Storage object to \p newSize elements. Note this does not + /// initialize the new elements or maintain existing data. + /// \param newSize New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); + + // #if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) + // alignas(LIBRAPID_DEFAULT_MEM_ALIGN) Pointer m_begin = nullptr; + // #else + // Pointer m_begin = nullptr; // Pointer to the beginning of the data + // #endif + + Pointer m_begin = nullptr; + + SizeType m_size = 0; // Number of elements in the Storage object + bool m_ownsData = true; // Whether this Storage object owns the data it points to + }; + + template + class FixedStorage { + public: + using Scalar = Scalar_; + using Pointer = Scalar *; + using ConstPointer = const Scalar *; + using Reference = Scalar &; + using ConstReference = const Scalar &; + using SizeType = size_t; + using DifferenceType = ptrdiff_t; + static constexpr SizeType Size = product(); + using Iterator = typename std::array()>::iterator; + using ConstIterator = typename std::array()>::const_iterator; + using ReverseIterator = std::reverse_iterator; + using ConstReverseIterator = std::reverse_iterator; + + /// Default constructor + FixedStorage(); + + /// Create a FixedStorage object filled with \p value + /// \param value Value to fill the FixedStorage object with + LIBRAPID_ALWAYS_INLINE explicit FixedStorage(const Scalar &value); + + /// Create a FixedStorage object from another FixedStorage object + /// \param other FixedStorage object to copy + LIBRAPID_ALWAYS_INLINE 
FixedStorage(const FixedStorage &other); + + /// Move constructor for a FixedStorage object + /// \param other FixedStorage object to move + LIBRAPID_ALWAYS_INLINE FixedStorage(FixedStorage &&other) noexcept; + + /// Create a FixedStorage object from a std::initializer_list + /// \tparam V Type of the elements in the initializer list + /// \param list Initializer list to copy + LIBRAPID_ALWAYS_INLINE explicit FixedStorage(const std::initializer_list &list); + + /// Create a FixedStorage object from a std::vector + /// \tparam V Type of the elements in the vector + /// \param vec Vector to copy + LIBRAPID_ALWAYS_INLINE explicit FixedStorage(const std::vector &vec); + + /// Assignment operator for a FixedStorage object + /// \param other FixedStorage object to copy + /// \return *this + LIBRAPID_ALWAYS_INLINE FixedStorage &operator=(const FixedStorage &other); + + /// Move assignment operator for a FixedStorage object + /// \param other FixedStorage object to move + /// \return *this + LIBRAPID_ALWAYS_INLINE FixedStorage &operator=(FixedStorage &&other) noexcept; + + /// Free a FixedStorage object + ~FixedStorage() = default; + + template + static ShapeType defaultShape(); + + /// Resize a Storage object to \p size elements. Existing elements + /// are preserved. + /// \param size New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); + + /// Resize a Storage object to \p size elements. 
Existing elements + /// are not preserved + /// \param size New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, int); + + /// Return the number of elements in the FixedStorage object + /// \return Number of elements in the FixedStorage object + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const noexcept; + + /// \brief Create a copy of the FixedStorage object + /// \return Copy of the FixedStorage object + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE FixedStorage copy() const; + + /// Const access to the element at index \p index + /// \param index Index of the element to access + /// \return Const reference to the element at index \p index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReference operator[](SizeType index) const; + + /// Access to the element at index \p index + /// \param index Index of the element to access + /// \return Reference to the element at index \p index + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Reference operator[](SizeType index); + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer data() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator begin() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator end() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cbegin() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cend() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rbegin() noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ReverseIterator rend() noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rbegin() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator rend() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crbegin() const 
noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstReverseIterator crend() const noexcept; + + private: #if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) - alignas(LIBRAPID_DEFAULT_MEM_ALIGN) std::array m_data; + alignas(LIBRAPID_DEFAULT_MEM_ALIGN) std::array m_data; #else - // No memory alignment on Apple platforms or if it is disabled - std::array m_data; + // No memory alignment on Apple platforms or if it is disabled + std::array m_data; #endif - }; - - // Trait implementations - namespace typetraits { - template - struct IsStorage : std::false_type {}; - - template - struct IsStorage> : std::true_type {}; - - template - struct IsFixedStorage : std::false_type {}; - - template - struct IsFixedStorage> : std::true_type {}; - } // namespace typetraits - - namespace detail { - /// Safely deallocate memory for \p size elements, using an std::allocator \p alloc. If the - /// object cannot be trivially destroyed, the destructor will be called on each element of - /// the data, ensuring that it is safe to free the allocated memory. - /// \tparam A The allocator type - /// \param alloc The allocator object - /// \param ptr The pointer to free - /// \param size The number of elements of type \p in the memory block - template - void safeDeallocate(T *ptr, size_t size) { - if constexpr (!std::is_trivially_destructible_v) { - for (size_t i = 0; i < size; ++i) { ptr[i].~T(); } - } + }; + + // Trait implementations + namespace typetraits { + template + struct IsStorage : std::false_type {}; + + template + struct IsStorage> : std::true_type {}; + + template + struct IsFixedStorage : std::false_type {}; + + template + struct IsFixedStorage> : std::true_type {}; + } // namespace typetraits + + namespace detail { + /// Safely deallocate memory for \p size elements, using an std::allocator \p alloc. 
If the + /// object cannot be trivially destroyed, the destructor will be called on each element of + /// the data, ensuring that it is safe to free the allocated memory. + /// \tparam A The allocator type + /// \param alloc The allocator object + /// \param ptr The pointer to free + /// \param size The number of elements of type \p in the memory block + template + void safeDeallocate(T *ptr, size_t size) { + if constexpr (!std::is_trivially_destructible_v) { + for (size_t i = 0; i < size; ++i) { ptr[i].~T(); } + } #if defined(LIBRAPID_BLAS_MKLBLAS) - mkl_free(ptr); + mkl_free(ptr); #elif defined(LIBRAPID_APPLE) - free(ptr); + free(ptr); #elif defined(LIBRAPID_NATIVE_ARCH) && defined(LIBRAPID_MSVC) - _aligned_free(ptr); + _aligned_free(ptr); #else - free(ptr); + free(ptr); #endif - } - - /// Safely allocate memory for \p size elements using the allocator \p alloc. If the data - /// can be trivially default constructed, then the constructor is not called and no data - /// is initialized. Otherwise, the correct default constructor will be called for each - /// element in the data, making sure the returned pointer is safe to use. - /// \tparam A The allocator type to use - /// \param alloc The allocator object to use - /// \param size Number of elements to allocate - /// \return Pointer to the first element - /// \see safeDeallocate - template - std::shared_ptr safeAllocate(size_t size) { - using RawPointer = T *; - using Pointer = std::shared_ptr; + } + + /// Safely allocate memory for \p size elements using the allocator \p alloc. If the data + /// can be trivially default constructed, then the constructor is not called and no data + /// is initialized. Otherwise, the correct default constructor will be called for each + /// element in the data, making sure the returned pointer is safe to use. 
+ /// \tparam A The allocator type to use + /// \param alloc The allocator object to use + /// \param size Number of elements to allocate + /// \return Pointer to the first element + /// \see safeDeallocate + template + std::shared_ptr safeAllocate(size_t size) { + using RawPointer = T *; + using Pointer = std::shared_ptr; #if defined(LIBRAPID_BLAS_MKLBLAS) - // MKL has its own memory allocation function - auto ptr = static_cast(mkl_malloc(size * sizeof(T), 64)); + // MKL has its own memory allocation function + auto ptr = static_cast(mkl_malloc(size * sizeof(T), 64)); #elif defined(LIBRAPID_APPLE) - // Use posix_memalign - void *_ptr; - auto err = posix_memalign(&_ptr, global::memoryAlignment, size * sizeof(T)); - LIBRAPID_ASSERT(err == 0, "posix_memalign failed with error code {}", err); - auto ptr = static_cast(_ptr); + // Use posix_memalign + void *_ptr; + auto err = posix_memalign(&_ptr, global::memoryAlignment, size * sizeof(T)); + LIBRAPID_ASSERT(err == 0, "posix_memalign failed with error code {}", err); + auto ptr = static_cast(_ptr); #elif defined(LIBRAPID_MSVC) || defined(LIBRAPID_MINGW) - auto ptr = - static_cast(_aligned_malloc(size * sizeof(T), global::memoryAlignment)); + auto ptr = + static_cast(_aligned_malloc(size * sizeof(T), global::memoryAlignment)); #else - auto ptr = static_cast( - std::aligned_alloc(global::memoryAlignment, size * sizeof(T))); + auto ptr = static_cast( + std::aligned_alloc(global::memoryAlignment, size * sizeof(T))); #endif - LIBRAPID_ASSERT( - ptr != nullptr, "Failed to allocate {} bytes of memory", size * sizeof(T)); - - // If the type cannot be trivially constructed, we need to - // initialize each value - if constexpr (!typetraits::TriviallyDefaultConstructible::value && - !std::is_array::value) { - for (RawPointer p = ptr; p != ptr + size; ++p) { new (p) T(); } - } - - return Pointer(ptr, [size](RawPointer ptr) { safeDeallocate(ptr, size); }); - } - - /// Safely copy a pointer to a shared pointer. 
If \p ownsData is true, then the shared - /// pointer will be initialized with a custom deleter that will call safeDeallocate on the - /// pointer. Otherwise, the shared pointer will be initialized with a no-op deleter. - /// \tparam T Type of the pointer - /// \param ptr Raw pointer to copy - /// \param ownsData Whether the shared pointer should own the data - /// \return Shared pointer to the data - template - std::shared_ptr safePointerCopy(T *ptr, size_t size, bool ownsData) { - using RawPointer = T *; - using Pointer = std::shared_ptr; - - if (ownsData) { - return Pointer(ptr, [size](RawPointer ptr) { safeDeallocate(ptr, size); }); - } else { - return Pointer(ptr, [](RawPointer) {}); - } - } - - template - std::shared_ptr safePointerCopy(const std::shared_ptr &ptr, size_t size, - bool ownsData = true) { - using RawPointer = T *; - using Pointer = std::shared_ptr; - - if (ownsData) { - return Pointer(ptr.get(), [size](RawPointer ptr) { safeDeallocate(ptr, size); }); - } else { - return Pointer(ptr.get(), [](RawPointer) {}); - } - } - } // namespace detail - - template - Storage::Storage(SizeType size) : - m_begin(detail::safeAllocate(size)), m_size(size), m_ownsData(true) {} - - template - Storage::Storage(Scalar *begin, Scalar *end, bool ownsData) : - m_begin(detail::safePointerCopy(begin, std::distance(begin, end), ownsData)), - m_size(std::distance(begin, end)), m_ownsData(ownsData) {} - - template - Storage::Storage(SizeType size, ConstReference value) : - m_begin(detail::safeAllocate(size)), m_size(size), m_ownsData(true) { - for (SizeType i = 0; i < size; ++i) { m_begin.get()[i] = value; } - } - - template - Storage::Storage(const Storage &other) : - m_begin(other.m_begin), m_size(other.m_size), m_ownsData(other.m_ownsData) {} - - template - Storage::Storage(Storage &&other) noexcept : - m_begin(std::move(other.m_begin)), m_size(std::move(other.m_size)), - m_ownsData(std::move(other.m_ownsData)) { - other.m_begin = nullptr; - other.m_size = 0; - 
other.m_ownsData = false; - } - - template - template - Storage::Storage(const std::initializer_list &list) : - m_begin(nullptr), m_size(0), m_ownsData(true) { - initData(list.begin(), list.end()); - } - - template - template - Storage::Storage(const std::vector &vector) : - m_begin(nullptr), m_size(0), m_ownsData(true) { - initData(vector.begin(), vector.end()); - } - - template - template - auto Storage::fromData(const std::initializer_list &list) -> Storage { - return Storage(list); - } - - template - template - auto Storage::fromData(const std::vector &vec) -> Storage { - return Storage(vec); - } - - template - Storage &Storage::operator=(const Storage &other) { - if (this != &other) { - if (m_ownsData) { - // If we own the data already, we can just copy the pointer since we know it won't - // affect anything else. The shared pointer deals with the reference counting, so - // we don't need to worry about other arrays that might be using the same data. - m_begin = other.m_begin; - m_size = other.m_size; - } else { - LIBRAPID_ASSERT(m_size == other.m_size, - "Cannot copy storage with {} elements to dependent storage with " - "{} elements", - other.m_size, - m_size); - - // If we don't own the data, the size must be the same since it is being used - // elsewhere, and we can't change it - - if (typetraits::TriviallyDefaultConstructible::value) { - // Use a slightly faster memcpy if the type is trivially default constructible - std::uninitialized_copy(other.begin(), other.end(), m_begin.get()); - } else { - // Otherwise, use the standard copy algorithm - std::copy(other.begin(), other.end(), m_begin.get()); - } - } - } - return *this; - } - - template - Storage &Storage::operator=(Storage &&other) LIBRAPID_RELEASE_NOEXCEPT { - if (this != &other) { - if (m_ownsData) { - std::swap(m_begin, other.m_begin); - std::swap(m_size, other.m_size); - m_ownsData = other.m_ownsData; - } else { - LIBRAPID_ASSERT( - size() == other.size(), - "Mismatched storage sizes. 
Cannot assign storage with {} elements to " - "dependent storage with {} elements", - other.size(), - size()); - - if (typetraits::TriviallyDefaultConstructible::value) { - // Use a slightly faster memcpy if the type is trivially default constructible - std::uninitialized_copy(other.begin(), other.end(), m_begin.get()); - } else { - // Otherwise, use the standard copy algorithm - std::copy(other.begin(), other.end(), m_begin.get()); - } - } - } - return *this; - } - - template - Storage::~Storage() { - // All deallocation is handled by the shared pointer, which has a custom deleter which - // depends on whether the data is owned by the storage object or not. If it is owned, the - // data is deallocated, otherwise it is left alone. - } - - template - template - void Storage::initData(P begin, P end) { - m_size = static_cast(std::distance(begin, end)); - m_begin = detail::safeAllocate(m_size); - - if constexpr (typetraits::TypeInfo::canMemcpy) { - if constexpr (typetraits::TriviallyDefaultConstructible::value) { - // Use a slightly faster memcpy if the type is trivially default constructible - std::uninitialized_copy(begin, end, m_begin.get()); - } else { - // Otherwise, use the standard copy algorithm - std::copy(begin, end, m_begin.get()); - } - } else { - // Since we can't memcpy, we have to copy each element individually - for (SizeType i = 0; i < m_size; ++i) { m_begin.get()[i] = begin[i]; } - } - } - - template - template - void Storage::initData(P begin, SizeType size) { - initData(begin, begin + size); - } - - template - void Storage::set(const Storage &other) { - // We can simply copy the shared pointers across - m_begin = other.m_begin; - m_size = other.m_size; - m_ownsData = other.m_ownsData; - } - - template - auto Storage::toHostStorage() const -> Storage { - return copy(); - } - - template - auto Storage::toHostStorageUnsafe() const -> Storage { - return copy(); - } - - template - auto Storage::copy() const -> Storage { - Storage ret; - 
ret.initData(m_begin.get(), m_size); - return ret; - } - - template - template - auto Storage::defaultShape() -> ShapeType { - return ShapeType({0}); - } - - template - auto Storage::size() const noexcept -> SizeType { - return m_size; - } - - template - void Storage::resize(SizeType newSize) { - resizeImpl(newSize); - } - - template - void Storage::resize(SizeType newSize, int) { - resizeImpl(newSize, 0); - } - - template - LIBRAPID_ALWAYS_INLINE void Storage::resizeImpl(SizeType newSize) { - // Resize and retain data - - if (newSize == size()) return; - LIBRAPID_ASSERT(m_ownsData, "Dependent storage cannot be resized"); - - // Copy the existing data to a new location - Pointer oldBegin = m_begin; - SizeType oldSize = m_size; - - // Allocate a new block of memory - m_begin = detail::safeAllocate(newSize); - m_size = newSize; - - // Copy the data across - if constexpr (typetraits::TriviallyDefaultConstructible::value) { - // Use a slightly faster memcpy if the type is trivially default constructible - std::uninitialized_copy( - oldBegin.get(), oldBegin.get() + ::librapid::min(oldSize, newSize), m_begin.get()); - } else { - // Otherwise, use the standard copy algorithm - std::copy( - oldBegin.get(), oldBegin.get() + ::librapid::min(oldSize, newSize), m_begin.get()); - } - } - - template - LIBRAPID_ALWAYS_INLINE void Storage::resizeImpl(SizeType newSize, int) { - // Resize and discard data - - if (size() == newSize) return; - LIBRAPID_ASSERT(m_ownsData, "Dependent storage cannot be resized"); - - // Allocate a new block of memory - m_begin = detail::safeAllocate(newSize); - m_size = newSize; - } - - template - auto Storage::operator[](Storage::SizeType index) const -> ConstReference { - LIBRAPID_ASSERT(index < size(), "Index {} out of bounds for size {}", index, size()); - return m_begin.get()[index]; - } - - template - auto Storage::operator[](Storage::SizeType index) -> Reference { - LIBRAPID_ASSERT(index < size(), "Index {} out of bounds for size {}", index, 
size()); - return m_begin.get()[index]; - } - - template - auto Storage::data() const noexcept -> Pointer { - return m_begin; - } - - template - auto Storage::begin() noexcept -> RawPointer { - return m_begin.get(); - } - - template - auto Storage::end() noexcept -> RawPointer { - return m_begin.get() + m_size; - } - - template - auto Storage::begin() const noexcept -> ConstIterator { - return m_begin.get(); - } - - template - auto Storage::end() const noexcept -> ConstIterator { - return m_begin.get() + m_size; - } - - template - auto Storage::cbegin() const noexcept -> ConstIterator { - return begin(); - } - - template - auto Storage::cend() const noexcept -> ConstIterator { - return end(); - } - - template - auto Storage::rbegin() noexcept -> ReverseIterator { - return ReverseIterator(m_begin.get() + m_size); - } - - template - auto Storage::rend() noexcept -> ReverseIterator { - return ReverseIterator(m_begin.get()); - } - - template - auto Storage::rbegin() const noexcept -> ConstReverseIterator { - return ConstReverseIterator(m_begin.get() + m_size); - } - - template - auto Storage::rend() const noexcept -> ConstReverseIterator { - return ConstReverseIterator(m_begin.get()); - } - - template - auto Storage::crbegin() const noexcept -> ConstReverseIterator { - return rbegin(); - } - - template - auto Storage::crend() const noexcept -> ConstReverseIterator { - return rend(); - } - - template - FixedStorage::FixedStorage() = default; - - template - FixedStorage::FixedStorage(const Scalar &value) { - for (size_t i = 0; i < Size; ++i) { m_data[i] = value; } - } - - template - FixedStorage::FixedStorage(const FixedStorage &other) = default; - - template - FixedStorage::FixedStorage(FixedStorage &&other) noexcept = default; - - template - FixedStorage::FixedStorage(const std::initializer_list &list) { - LIBRAPID_ASSERT(list.size() == size(), "Initializer list size does not match storage size"); - for (size_t i = 0; i < Size; ++i) { m_data[i] = list.begin()[i]; } - } 
- - template - FixedStorage::FixedStorage(const std::vector &vec) { - LIBRAPID_ASSERT(vec.size() == size(), "Initializer list size does not match storage size"); - for (size_t i = 0; i < Size; ++i) { m_data[i] = vec[i]; } - } - - template - auto FixedStorage::operator=(const FixedStorage &other) -> FixedStorage & { - if (this != &other) { - for (size_t i = 0; i < Size; ++i) { m_data[i] = other.m_data[i]; } - } - return *this; - } - - template - auto FixedStorage::operator=(FixedStorage &&other) noexcept - -> FixedStorage & = default; - - template - template - auto FixedStorage::defaultShape() -> ShapeType { - return ShapeType({D...}); - } - - template - void FixedStorage::resize(SizeType newSize) { - LIBRAPID_ASSERT(newSize == size(), "FixedStorage cannot be resized"); - } - - template - void FixedStorage::resize(SizeType newSize, int) { - LIBRAPID_ASSERT(newSize == size(), "FixedStorage cannot be resized"); - } - - template - auto FixedStorage::size() const noexcept -> SizeType { - return Size; - } - - template - auto FixedStorage::copy() const -> FixedStorage { - return FixedStorage(); - } - - template - auto FixedStorage::operator[](SizeType index) const -> ConstReference { - LIBRAPID_ASSERT(index < size(), "Index out of bounds"); - return m_data[index]; - } - - template - auto FixedStorage::operator[](SizeType index) -> Reference { - LIBRAPID_ASSERT(index < size(), "Index out of bounds"); - return m_data[index]; - } - - template - auto FixedStorage::data() const noexcept -> Pointer { - return const_cast(m_data.data()); - } - - template - auto FixedStorage::begin() noexcept -> Iterator { - return m_data.begin(); - } - - template - auto FixedStorage::end() noexcept -> Iterator { - return m_data.end(); - } - - template - auto FixedStorage::begin() const noexcept -> ConstIterator { - return m_data.begin(); - } - - template - auto FixedStorage::end() const noexcept -> ConstIterator { - return m_data.end(); - } - - template - auto FixedStorage::cbegin() const 
noexcept -> ConstIterator { - return begin(); - } - - template - auto FixedStorage::cend() const noexcept -> ConstIterator { - return end(); - } - - template - auto FixedStorage::rbegin() noexcept -> ReverseIterator { - return ReverseIterator(end()); - } - - template - auto FixedStorage::rend() noexcept -> ReverseIterator { - return ReverseIterator(begin()); - } - - template - auto FixedStorage::rbegin() const noexcept -> ConstReverseIterator { - return ConstReverseIterator(end()); - } - - template - auto FixedStorage::rend() const noexcept -> ConstReverseIterator { - return ConstReverseIterator(begin()); - } - - template - auto FixedStorage::crbegin() const noexcept -> ConstReverseIterator { - return rbegin(); - } - - template - auto FixedStorage::crend() const noexcept -> ConstReverseIterator { - return rend(); - } + LIBRAPID_ASSERT( + ptr != nullptr, "Failed to allocate {} bytes of memory", size * sizeof(T)); + + // If the type cannot be trivially constructed, we need to + // initialize each value + if constexpr (!typetraits::TriviallyDefaultConstructible::value && + !std::is_array::value) { + for (RawPointer p = ptr; p != ptr + size; ++p) { new (p) T(); } + } + + return Pointer(ptr, [size](RawPointer ptr) { safeDeallocate(ptr, size); }); + } + + /// Safely copy a pointer to a shared pointer. If \p ownsData is true, then the shared + /// pointer will be initialized with a custom deleter that will call safeDeallocate on the + /// pointer. Otherwise, the shared pointer will be initialized with a no-op deleter. 
+ /// \tparam T Type of the pointer + /// \param ptr Raw pointer to copy + /// \param ownsData Whether the shared pointer should own the data + /// \return Shared pointer to the data + template + std::shared_ptr safePointerCopy(T *ptr, size_t size, bool ownsData) { + using RawPointer = T *; + using Pointer = std::shared_ptr; + + if (ownsData) { + return Pointer(ptr, [size](RawPointer ptr) { safeDeallocate(ptr, size); }); + } else { + return Pointer(ptr, [](RawPointer) {}); + } + } + + template + std::shared_ptr safePointerCopy(const std::shared_ptr &ptr, size_t size, + bool ownsData = true) { + using RawPointer = T *; + using Pointer = std::shared_ptr; + + if (ownsData) { + return Pointer(ptr.get(), [size](RawPointer ptr) { safeDeallocate(ptr, size); }); + } else { + return Pointer(ptr.get(), [](RawPointer) {}); + } + } + } // namespace detail + + template + Storage::Storage(SizeType size) : + m_begin(detail::safeAllocate(size)), m_size(size), m_ownsData(true) {} + + template + Storage::Storage(Scalar *begin, Scalar *end, bool ownsData) : + m_begin(detail::safePointerCopy(begin, std::distance(begin, end), ownsData)), + m_size(std::distance(begin, end)), m_ownsData(ownsData) {} + + template + Storage::Storage(SizeType size, ConstReference value) : + m_begin(detail::safeAllocate(size)), m_size(size), m_ownsData(true) { + for (SizeType i = 0; i < size; ++i) { m_begin.get()[i] = value; } + } + + template + Storage::Storage(const Storage &other) : + m_begin(other.m_begin), m_size(other.m_size), m_ownsData(other.m_ownsData) {} + + template + Storage::Storage(Storage &&other) noexcept : + m_begin(std::move(other.m_begin)), m_size(std::move(other.m_size)), + m_ownsData(std::move(other.m_ownsData)) { + other.m_begin = nullptr; + other.m_size = 0; + other.m_ownsData = false; + } + + template + template + Storage::Storage(const std::initializer_list &list) : + m_begin(nullptr), m_size(0), m_ownsData(true) { + initData(list.begin(), list.end()); + } + + template + template + 
Storage::Storage(const std::vector &vector) : + m_begin(nullptr), m_size(0), m_ownsData(true) { + initData(vector.begin(), vector.end()); + } + + template + template + auto Storage::fromData(const std::initializer_list &list) -> Storage { + return Storage(list); + } + + template + template + auto Storage::fromData(const std::vector &vec) -> Storage { + return Storage(vec); + } + + template + Storage &Storage::operator=(const Storage &other) { + if (this != &other) { + if (m_ownsData) { + // If we own the data already, we can just copy the pointer since we know it won't + // affect anything else. The shared pointer deals with the reference counting, so + // we don't need to worry about other arrays that might be using the same data. + m_begin = other.m_begin; + m_size = other.m_size; + } else { + LIBRAPID_ASSERT(m_size == other.m_size, + "Cannot copy storage with {} elements to dependent storage with " + "{} elements", + other.m_size, + m_size); + + // If we don't own the data, the size must be the same since it is being used + // elsewhere, and we can't change it + + if (typetraits::TriviallyDefaultConstructible::value) { + // Use a slightly faster memcpy if the type is trivially default constructible + std::uninitialized_copy(other.begin(), other.end(), m_begin.get()); + } else { + // Otherwise, use the standard copy algorithm + std::copy(other.begin(), other.end(), m_begin.get()); + } + } + } + return *this; + } + + template + Storage &Storage::operator=(Storage &&other) LIBRAPID_RELEASE_NOEXCEPT { + if (this != &other) { + if (m_ownsData) { + std::swap(m_begin, other.m_begin); + std::swap(m_size, other.m_size); + m_ownsData = other.m_ownsData; + } else { + LIBRAPID_ASSERT( + size() == other.size(), + "Mismatched storage sizes. 
Cannot assign storage with {} elements to " + "dependent storage with {} elements", + other.size(), + size()); + + if (typetraits::TriviallyDefaultConstructible::value) { + // Use a slightly faster memcpy if the type is trivially default constructible + std::uninitialized_copy(other.begin(), other.end(), m_begin.get()); + } else { + // Otherwise, use the standard copy algorithm + std::copy(other.begin(), other.end(), m_begin.get()); + } + } + } + return *this; + } + + template + Storage::~Storage() { + // All deallocation is handled by the shared pointer, which has a custom deleter which + // depends on whether the data is owned by the storage object or not. If it is owned, the + // data is deallocated, otherwise it is left alone. + } + + template + template + void Storage::initData(P begin, P end) { + m_size = static_cast(std::distance(begin, end)); + m_begin = detail::safeAllocate(m_size); + + if constexpr (typetraits::TypeInfo::canMemcpy) { + if constexpr (typetraits::TriviallyDefaultConstructible::value) { + // Use a slightly faster memcpy if the type is trivially default constructible + std::uninitialized_copy(begin, end, m_begin.get()); + } else { + // Otherwise, use the standard copy algorithm + std::copy(begin, end, m_begin.get()); + } + } else { + // Since we can't memcpy, we have to copy each element individually + for (SizeType i = 0; i < m_size; ++i) { m_begin.get()[i] = begin[i]; } + } + } + + template + template + void Storage::initData(P begin, SizeType size) { + initData(begin, begin + size); + } + + template + void Storage::set(const Storage &other) { + // We can simply copy the shared pointers across + m_begin = other.m_begin; + m_size = other.m_size; + m_ownsData = other.m_ownsData; + } + + template + auto Storage::toHostStorage() const -> Storage { + return copy(); + } + + template + auto Storage::toHostStorageUnsafe() const -> Storage { + return copy(); + } + + template + auto Storage::copy() const -> Storage { + Storage ret; + 
ret.initData(m_begin.get(), m_size); + return ret; + } + + template + template + auto Storage::defaultShape() -> ShapeType { + return ShapeType({0}); + } + + template + auto Storage::size() const noexcept -> SizeType { + return m_size; + } + + template + void Storage::resize(SizeType newSize) { + resizeImpl(newSize); + } + + template + void Storage::resize(SizeType newSize, int) { + resizeImpl(newSize, 0); + } + + template + LIBRAPID_ALWAYS_INLINE void Storage::resizeImpl(SizeType newSize) { + // Resize and retain data + + if (newSize == size()) return; + LIBRAPID_ASSERT(m_ownsData, "Dependent storage cannot be resized"); + + // Copy the existing data to a new location + Pointer oldBegin = m_begin; + SizeType oldSize = m_size; + + // Allocate a new block of memory + m_begin = detail::safeAllocate(newSize); + m_size = newSize; + + // Copy the data across + if constexpr (typetraits::TriviallyDefaultConstructible::value) { + // Use a slightly faster memcpy if the type is trivially default constructible + std::uninitialized_copy( + oldBegin.get(), oldBegin.get() + ::librapid::min(oldSize, newSize), m_begin.get()); + } else { + // Otherwise, use the standard copy algorithm + std::copy( + oldBegin.get(), oldBegin.get() + ::librapid::min(oldSize, newSize), m_begin.get()); + } + } + + template + LIBRAPID_ALWAYS_INLINE void Storage::resizeImpl(SizeType newSize, int) { + // Resize and discard data + + if (size() == newSize) return; + LIBRAPID_ASSERT(m_ownsData, "Dependent storage cannot be resized"); + + // Allocate a new block of memory + m_begin = detail::safeAllocate(newSize); + m_size = newSize; + } + + template + auto Storage::operator[](Storage::SizeType index) const -> ConstReference { + LIBRAPID_ASSERT(index < size(), "Index {} out of bounds for size {}", index, size()); + return m_begin.get()[index]; + } + + template + auto Storage::operator[](Storage::SizeType index) -> Reference { + LIBRAPID_ASSERT(index < size(), "Index {} out of bounds for size {}", index, 
size()); + return m_begin.get()[index]; + } + + template + auto Storage::data() const noexcept -> Pointer { + return m_begin; + } + + template + auto Storage::begin() noexcept -> RawPointer { + return m_begin.get(); + } + + template + auto Storage::end() noexcept -> RawPointer { + return m_begin.get() + m_size; + } + + template + auto Storage::begin() const noexcept -> ConstIterator { + return m_begin.get(); + } + + template + auto Storage::end() const noexcept -> ConstIterator { + return m_begin.get() + m_size; + } + + template + auto Storage::cbegin() const noexcept -> ConstIterator { + return begin(); + } + + template + auto Storage::cend() const noexcept -> ConstIterator { + return end(); + } + + template + auto Storage::rbegin() noexcept -> ReverseIterator { + return ReverseIterator(m_begin.get() + m_size); + } + + template + auto Storage::rend() noexcept -> ReverseIterator { + return ReverseIterator(m_begin.get()); + } + + template + auto Storage::rbegin() const noexcept -> ConstReverseIterator { + return ConstReverseIterator(m_begin.get() + m_size); + } + + template + auto Storage::rend() const noexcept -> ConstReverseIterator { + return ConstReverseIterator(m_begin.get()); + } + + template + auto Storage::crbegin() const noexcept -> ConstReverseIterator { + return rbegin(); + } + + template + auto Storage::crend() const noexcept -> ConstReverseIterator { + return rend(); + } + + template + FixedStorage::FixedStorage() = default; + + template + FixedStorage::FixedStorage(const Scalar &value) { + for (size_t i = 0; i < Size; ++i) { m_data[i] = value; } + } + + template + FixedStorage::FixedStorage(const FixedStorage &other) = default; + + template + FixedStorage::FixedStorage(FixedStorage &&other) noexcept = default; + + template + FixedStorage::FixedStorage(const std::initializer_list &list) { + LIBRAPID_ASSERT(list.size() == size(), "Initializer list size does not match storage size"); + for (size_t i = 0; i < Size; ++i) { m_data[i] = list.begin()[i]; } + } 
+ + template + FixedStorage::FixedStorage(const std::vector &vec) { + LIBRAPID_ASSERT(vec.size() == size(), "Initializer list size does not match storage size"); + for (size_t i = 0; i < Size; ++i) { m_data[i] = vec[i]; } + } + + template + auto FixedStorage::operator=(const FixedStorage &other) -> FixedStorage & { + if (this != &other) { + for (size_t i = 0; i < Size; ++i) { m_data[i] = other.m_data[i]; } + } + return *this; + } + + template + auto FixedStorage::operator=(FixedStorage &&other) noexcept + -> FixedStorage & = default; + + template + template + auto FixedStorage::defaultShape() -> ShapeType { + return ShapeType({D...}); + } + + template + void FixedStorage::resize(SizeType newSize) { + LIBRAPID_ASSERT(newSize == size(), "FixedStorage cannot be resized"); + } + + template + void FixedStorage::resize(SizeType newSize, int) { + LIBRAPID_ASSERT(newSize == size(), "FixedStorage cannot be resized"); + } + + template + auto FixedStorage::size() const noexcept -> SizeType { + return Size; + } + + template + auto FixedStorage::copy() const -> FixedStorage { + return FixedStorage(); + } + + template + auto FixedStorage::operator[](SizeType index) const -> ConstReference { + LIBRAPID_ASSERT(index < size(), "Index out of bounds"); + return m_data[index]; + } + + template + auto FixedStorage::operator[](SizeType index) -> Reference { + LIBRAPID_ASSERT(index < size(), "Index out of bounds"); + return m_data[index]; + } + + template + auto FixedStorage::data() const noexcept -> Pointer { + return const_cast(m_data.data()); + } + + template + auto FixedStorage::begin() noexcept -> Iterator { + return m_data.begin(); + } + + template + auto FixedStorage::end() noexcept -> Iterator { + return m_data.end(); + } + + template + auto FixedStorage::begin() const noexcept -> ConstIterator { + return m_data.begin(); + } + + template + auto FixedStorage::end() const noexcept -> ConstIterator { + return m_data.end(); + } + + template + auto FixedStorage::cbegin() const 
noexcept -> ConstIterator { + return begin(); + } + + template + auto FixedStorage::cend() const noexcept -> ConstIterator { + return end(); + } + + template + auto FixedStorage::rbegin() noexcept -> ReverseIterator { + return ReverseIterator(end()); + } + + template + auto FixedStorage::rend() noexcept -> ReverseIterator { + return ReverseIterator(begin()); + } + + template + auto FixedStorage::rbegin() const noexcept -> ConstReverseIterator { + return ConstReverseIterator(end()); + } + + template + auto FixedStorage::rend() const noexcept -> ConstReverseIterator { + return ConstReverseIterator(begin()); + } + + template + auto FixedStorage::crbegin() const noexcept -> ConstReverseIterator { + return rbegin(); + } + + template + auto FixedStorage::crend() const noexcept -> ConstReverseIterator { + return rend(); + } } // namespace librapid #endif // LIBRAPID_ARRAY_STORAGE_HPP \ No newline at end of file diff --git a/librapid/include/librapid/array/strideTools.hpp b/librapid/include/librapid/array/strideTools.hpp index f061e5a5..283e67ed 100644 --- a/librapid/include/librapid/array/strideTools.hpp +++ b/librapid/include/librapid/array/strideTools.hpp @@ -2,59 +2,59 @@ #define LIBRAPID_ARRAY_STRIDE_TOOLS_HPP namespace librapid { - namespace typetraits { - LIBRAPID_DEFINE_AS_TYPE(typename T COMMA size_t N, Stride); - } + namespace typetraits { + LIBRAPID_DEFINE_AS_TYPE(typename T COMMA size_t N, Stride); + } - /// A Stride is a vector of integers that describes the distance between elements in each - /// dimension of an ArrayContainer object. This can be used to access elements in a non-trivial - /// order, or to access a sub-array of an ArrayContainer object. The Stride class inherits from - /// the Shape class. - /// \tparam T The type of the Stride. Must be an integer type. - /// \tparam N The number of dimensions in the Stride. 
- /// \see Shape - template - class Stride : public Shape { - public: - /// Default Constructor - Stride() = default; + /// A Stride is a vector of integers that describes the distance between elements in each + /// dimension of an ArrayContainer object. This can be used to access elements in a non-trivial + /// order, or to access a sub-array of an ArrayContainer object. The Stride class inherits from + /// the Shape class. + /// \tparam T The type of the Stride. Must be an integer type. + /// \tparam N The number of dimensions in the Stride. + /// \see Shape + template + class Stride : public Shape { + public: + /// Default Constructor + Stride() = default; - /// Construct a Stride from a Shape object. This will assume that the data represented by - /// the Shape object is a contiguous block of memory, and will calculate the corresponding - /// strides based on this. - /// \param shape - Stride(const Shape &shape); + /// Construct a Stride from a Shape object. This will assume that the data represented by + /// the Shape object is a contiguous block of memory, and will calculate the corresponding + /// strides based on this. + /// \param shape + Stride(const Shape &shape); - /// Copy a Stride object - /// \param other The Stride object to copy. - Stride(const Stride &other) = default; + /// Copy a Stride object + /// \param other The Stride object to copy. + Stride(const Stride &other) = default; - /// Move a Stride object - /// \param other The Stride object to move. - Stride(Stride &&other) noexcept = default; + /// Move a Stride object + /// \param other The Stride object to move. + Stride(Stride &&other) noexcept = default; - /// Assign a Stride object to this Stride object. - /// \param other The Stride object to assign. - Stride &operator=(const Stride &other) = default; + /// Assign a Stride object to this Stride object. + /// \param other The Stride object to assign. 
+ Stride &operator=(const Stride &other) = default; - /// Move a Stride object to this Stride object. - /// \param other The Stride object to move. - Stride &operator=(Stride &&other) noexcept = default; - }; + /// Move a Stride object to this Stride object. + /// \param other The Stride object to move. + Stride &operator=(Stride &&other) noexcept = default; + }; - template - Stride::Stride(const Shape &shape) : Shape(shape) { - if (this->m_dims == 0) { - // Edge case for a zero-dimensional array - this->m_data[0] = 1; - return; - } + template + Stride::Stride(const Shape &shape) : Shape(shape) { + if (this->m_dims == 0) { + // Edge case for a zero-dimensional array + this->m_data[0] = 1; + return; + } - T tmp[N] {0}; - tmp[this->m_dims - 1] = 1; - for (size_t i = this->m_dims - 1; i > 0; --i) tmp[i - 1] = tmp[i] * this->m_data[i]; - for (size_t i = 0; i < this->m_dims; ++i) this->m_data[i] = tmp[i]; - } + T tmp[N] {0}; + tmp[this->m_dims - 1] = 1; + for (size_t i = this->m_dims - 1; i > 0; --i) tmp[i - 1] = tmp[i] * this->m_data[i]; + for (size_t i = 0; i < this->m_dims; ++i) this->m_data[i] = tmp[i]; + } } // namespace librapid // Support FMT printing diff --git a/librapid/include/librapid/autodiff/dual.hpp b/librapid/include/librapid/autodiff/dual.hpp index 1e368d83..1ef4a8cf 100644 --- a/librapid/include/librapid/autodiff/dual.hpp +++ b/librapid/include/librapid/autodiff/dual.hpp @@ -2,457 +2,457 @@ #define LIBRAPID_AUTODIFF_DUAL #if defined(LIBRAPID_IN_JITIFY) -# define REQUIRE_SCALAR(TYPE) typename AD_ARG_ = void +# define REQUIRE_SCALAR(TYPE) typename AD_ARG_ = void #else -# define REQUIRE_SCALAR(TYPE) typename std::enable_if_t, int> = 0 +# define REQUIRE_SCALAR(TYPE) typename std::enable_if_t, int> = 0 #endif // LIBRAPID_IN_JITIFY #if !defined(LIBRAPID_IN_JITIFY) namespace librapid { #endif - template - class Dual { - public: - T value; - T derivative; + template + class Dual { + public: + T value; + T derivative; #if defined(LIBRAPID_IN_JITIFY) - using 
Scalar = T; - using Packet = T; - static constexpr uint64_t packetWidth = 1; + using Scalar = T; + using Packet = T; + static constexpr uint64_t packetWidth = 1; #else - using Scalar = typename typetraits::TypeInfo::Scalar; - using Packet = typename typetraits::TypeInfo::Packet; - static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; + using Scalar = typename typetraits::TypeInfo::Scalar; + using Packet = typename typetraits::TypeInfo::Packet; + static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; #endif - Dual() = default; - explicit Dual(T value) : value(value), derivative(T()) {} - Dual(T value, T derivative) : value(value), derivative(derivative) {} - - template - explicit Dual(const Dual &other) : value(other.value), derivative(other.derivative) {} - - template - explicit Dual(Dual &&other) : - value(std::move(other.value)), derivative(std::move(other.derivative)) {} - - template - Dual &operator=(const Dual &other) { - value = other.value; - derivative = other.derivative; - return *this; - } - - template - Dual &operator=(Dual &&other) { - value = std::move(other.value); - derivative = std::move(other.derivative); - return *this; - } - - static constexpr size_t size() { return typetraits::TypeInfo::packetWidth; } - - // template - // LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { - // // Load the data into batches. - // auto casted = reinterpret_cast(ptr); - // - // // Compute interleaved values. - // std::array interleaved; - // for (std::size_t i = 0; i < packetWidth; ++i) { - // interleaved[2 * i] = value.get(i); - // interleaved[2 * i + 1] = derivative.get(i); - // } - // - // // Store the interleaved values back to memory. 
- // std::copy(interleaved.begin(), interleaved.end(), casted); - // } - - // template - // LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { - // // auto casted = reinterpret_cast(ptr); - // // Vc::deinterleave(&value, &derivative, casted, Vc::Aligned); - // - // // Load the data into batches. - // auto casted = reinterpret_cast(ptr); - // - // // Compute interleaved values. - // std::array interleaved; - // std::copy(casted, casted + 2 * packetWidth, interleaved.begin()); - // - // // Store the interleaved values back to memory. - // for (std::size_t i = 0; i < packetWidth; ++i) { - // value.set(i, interleaved[2 * i]); - // derivative.set(i, interleaved[2 * i + 1]); - // } - // } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const Dual &other) { - value += other.value; - derivative += other.derivative; - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator-=(const Dual &other) { - value -= other.value; - derivative -= other.derivative; - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator*=(const Dual &other) { - value *= other.value; - derivative = derivative * other.value + value * other.derivative; - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator/=(const Dual &other) { - value /= other.value; - derivative = - (derivative * other.value - value * other.derivative) / (other.value * other.value); - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const T &other) { - value += other; - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator-=(const T &other) { - value -= other; - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator*=(const T &other) { - value *= other; - derivative *= other; - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator/=(const T &other) { - value /= other; - derivative /= other; - return *this; - } + Dual() = default; + explicit Dual(T value) 
: value(value), derivative(T()) {} + Dual(T value, T derivative) : value(value), derivative(derivative) {} + + template + explicit Dual(const Dual &other) : value(other.value), derivative(other.derivative) {} + + template + explicit Dual(Dual &&other) : + value(std::move(other.value)), derivative(std::move(other.derivative)) {} + + template + Dual &operator=(const Dual &other) { + value = other.value; + derivative = other.derivative; + return *this; + } + + template + Dual &operator=(Dual &&other) { + value = std::move(other.value); + derivative = std::move(other.derivative); + return *this; + } + + static constexpr size_t size() { return typetraits::TypeInfo::packetWidth; } + + // template + // LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { + // // Load the data into batches. + // auto casted = reinterpret_cast(ptr); + // + // // Compute interleaved values. + // std::array interleaved; + // for (std::size_t i = 0; i < packetWidth; ++i) { + // interleaved[2 * i] = value.get(i); + // interleaved[2 * i + 1] = derivative.get(i); + // } + // + // // Store the interleaved values back to memory. + // std::copy(interleaved.begin(), interleaved.end(), casted); + // } + + // template + // LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { + // // auto casted = reinterpret_cast(ptr); + // // Vc::deinterleave(&value, &derivative, casted, Vc::Aligned); + // + // // Load the data into batches. + // auto casted = reinterpret_cast(ptr); + // + // // Compute interleaved values. + // std::array interleaved; + // std::copy(casted, casted + 2 * packetWidth, interleaved.begin()); + // + // // Store the interleaved values back to memory. 
+ // for (std::size_t i = 0; i < packetWidth; ++i) { + // value.set(i, interleaved[2 * i]); + // derivative.set(i, interleaved[2 * i + 1]); + // } + // } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const Dual &other) { + value += other.value; + derivative += other.derivative; + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator-=(const Dual &other) { + value -= other.value; + derivative -= other.derivative; + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator*=(const Dual &other) { + value *= other.value; + derivative = derivative * other.value + value * other.derivative; + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator/=(const Dual &other) { + value /= other.value; + derivative = + (derivative * other.value - value * other.derivative) / (other.value * other.value); + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const T &other) { + value += other; + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator-=(const T &other) { + value -= other; + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator*=(const T &other) { + value *= other; + derivative *= other; + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator/=(const T &other) { + value /= other; + derivative /= other; + return *this; + } #if !defined(LIBRAPID_IN_JITIFY) - std::string str(const std::string &format = "{}") const { - return fmt::format( - "Dual({}, {})", fmt::format(format, value), fmt::format(format, derivative)); - } + std::string str(const std::string &format = "{}") const { + return fmt::format( + "Dual({}, {})", fmt::format(format, value), fmt::format(format, derivative)); + } #endif // !defined(LIBRAPID_IN_JITIFY) - }; - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator+(const Dual &lhs, const Dual &rhs) { - return {lhs.value + rhs.value, lhs.derivative + 
rhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator+(const Dual &lhs, const V &rhs) { - return {lhs.value + rhs, lhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator+(const V &lhs, const Dual &rhs) { - return {lhs + rhs.value, rhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator-(const Dual &lhs, const Dual &rhs) { - return {lhs.value - rhs.value, lhs.derivative - rhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator-(const Dual &lhs, const V &rhs) { - return {lhs.value - rhs, lhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator-(const V &lhs, const Dual &rhs) { - return {lhs - rhs.value, -rhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator*(const Dual &lhs, const Dual &rhs) { - return {lhs.value * rhs.value, lhs.derivative * rhs.value + lhs.value * rhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator*(const Dual &lhs, const V &rhs) { - return {lhs.value * rhs, lhs.derivative * rhs}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator*(const V &lhs, const Dual &rhs) { - return {lhs * rhs.value, lhs * rhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator/(const Dual &lhs, const Dual &rhs) { - return {lhs.value / rhs.value, - (lhs.derivative * rhs.value - lhs.value * rhs.derivative) / - (rhs.value * rhs.value)}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator/(const Dual &lhs, const V &rhs) { - return {lhs.value / rhs, lhs.derivative / rhs}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual - operator/(const V &lhs, const Dual &rhs) { - return {lhs / rhs.value, -lhs * rhs.derivative / (rhs.value * rhs.value)}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual 
operator-(const Dual &lhs) { - return {-lhs.value, -lhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+(const Dual &lhs) { - return {lhs.value, lhs.derivative}; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual sin(const Dual &x) { - using Ret = decltype(::librapid::sin(x.value)); - return Dual(::librapid::sin(x.value), ::librapid::cos(x.value) * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual cos(const Dual &x) { - using Ret = decltype(::librapid::cos(x.value)); - return Dual(::librapid::cos(x.value), -::librapid::sin(x.value) * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual tan(const Dual &x) { - using Ret = decltype(::librapid::tan(x.value)); - auto cosX = ::librapid::cos(x.value); - return Dual(::librapid::tan(x.value), x.derivative / (cosX * cosX)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual asin(const Dual &x) { - using Ret = decltype(::librapid::asin(x.value)); - return Dual(::librapid::asin(x.value), - x.derivative / ::librapid::sqrt(1 - x.value * x.value)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual acos(const Dual &x) { - using Ret = decltype(::librapid::acos(x.value)); - return Dual(::librapid::acos(x.value), - -x.derivative / ::librapid::sqrt(1 - x.value * x.value)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual atan(const Dual &x) { - using Ret = decltype(::librapid::atan(x.value)); - return Dual(::librapid::atan(x.value), x.derivative / (1 + x.value * x.value)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual sinh(const Dual &x) { - using Ret = decltype(::librapid::sinh(x.value)); - return Dual(::librapid::sinh(x.value), ::librapid::cosh(x.value) * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual cosh(const Dual &x) { - using Ret = decltype(::librapid::cosh(x.value)); - return Dual(::librapid::cosh(x.value), 
::librapid::sinh(x.value) * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual tanh(const Dual &x) { - using Ret = decltype(::librapid::tanh(x.value)); - auto coshX = ::librapid::cosh(x.value); - return Dual(::librapid::tanh(x.value), x.derivative / (coshX * coshX)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual asinh(const Dual &x) { - using Ret = decltype(::librapid::asinh(x.value)); - return Dual(::librapid::asinh(x.value), - x.derivative / ::librapid::sqrt(1 + x.value * x.value)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual acosh(const Dual &x) { - using Ret = decltype(::librapid::acosh(x.value)); - return Dual(::librapid::acosh(x.value), - x.derivative / ::librapid::sqrt(x.value * x.value - 1)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual atanh(const Dual &x) { - using Ret = decltype(::librapid::atanh(x.value)); - return Dual(::librapid::atanh(x.value), x.derivative / (1 - x.value * x.value)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual exp(const Dual &x) { - using Ret = decltype(::librapid::exp(x.value)); - auto expX = ::librapid::exp(x.value); - return Dual(expX, expX * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual exp2(const Dual &x) { - using Ret = decltype(::librapid::exp2(x.value)); - auto exp2X = ::librapid::exp2(x.value); - return Dual(exp2X, exp2X * ::librapid::log(2) * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual exp10(const Dual &x) { - using Ret = decltype(::librapid::exp2(x.value)); - auto exp2X = ::librapid::exp10(x.value); - return Dual(exp2X, exp2X * ::librapid::log(10) * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual log(const Dual &x) { - using Ret = decltype(::librapid::log(x.value)); - return Dual(::librapid::log(x.value), x.derivative / x.value); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual log10(const 
Dual &x) { - using Ret = decltype(::librapid::log10(x.value)); - return Dual(::librapid::log10(x.value), - x.derivative / (x.value * ::librapid::log(10))); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual log2(const Dual &x) { - using Ret = decltype(::librapid::log2(x.value)); - return Dual(::librapid::log2(x.value), x.derivative / (x.value * ::librapid::log(2))); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual sqrt(const Dual &x) { - using Ret = decltype(::librapid::sqrt(x.value)); - return Dual(::librapid::sqrt(x.value), x.derivative / (2 * ::librapid::sqrt(x.value))); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual cbrt(const Dual &x) { - using Ret = decltype(::librapid::cbrt(x.value)); - return Dual(::librapid::cbrt(x.value), - x.derivative / (3 * ::librapid::cbrt(x.value * x.value))); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual pow(const Dual &x, const V &y) { - using Ret = decltype(::librapid::pow(x.value, y)); - return Dual(::librapid::pow(x.value, y), - y * ::librapid::pow(x.value, y - 1) * x.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual pow(const V &x, const Dual &y) { - using Ret = decltype(::librapid::pow(x, y.value)); - return Dual(::librapid::pow(x, y.value), - ::librapid::log(x) * ::librapid::pow(x, y.value) * y.derivative); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual pow(const Dual &x, const Dual &y) { - using Ret = decltype(::librapid::pow(x.value, y.value)); - return Dual( - ::librapid::pow(x.value, y.value), - ::librapid::pow(x.value, y.value) * - (y.derivative * ::librapid::log(x.value) + y.value * x.derivative / x.value)); - } + }; + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator+(const Dual &lhs, const Dual &rhs) { + return {lhs.value + rhs.value, lhs.derivative + rhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator+(const Dual &lhs, const V &rhs) { + 
return {lhs.value + rhs, lhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator+(const V &lhs, const Dual &rhs) { + return {lhs + rhs.value, rhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator-(const Dual &lhs, const Dual &rhs) { + return {lhs.value - rhs.value, lhs.derivative - rhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator-(const Dual &lhs, const V &rhs) { + return {lhs.value - rhs, lhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator-(const V &lhs, const Dual &rhs) { + return {lhs - rhs.value, -rhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator*(const Dual &lhs, const Dual &rhs) { + return {lhs.value * rhs.value, lhs.derivative * rhs.value + lhs.value * rhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator*(const Dual &lhs, const V &rhs) { + return {lhs.value * rhs, lhs.derivative * rhs}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator*(const V &lhs, const Dual &rhs) { + return {lhs * rhs.value, lhs * rhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator/(const Dual &lhs, const Dual &rhs) { + return {lhs.value / rhs.value, + (lhs.derivative * rhs.value - lhs.value * rhs.derivative) / + (rhs.value * rhs.value)}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator/(const Dual &lhs, const V &rhs) { + return {lhs.value / rhs, lhs.derivative / rhs}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual + operator/(const V &lhs, const Dual &rhs) { + return {lhs / rhs.value, -lhs * rhs.derivative / (rhs.value * rhs.value)}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator-(const Dual &lhs) { + return {-lhs.value, -lhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE 
Dual operator+(const Dual &lhs) { + return {lhs.value, lhs.derivative}; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual sin(const Dual &x) { + using Ret = decltype(::librapid::sin(x.value)); + return Dual(::librapid::sin(x.value), ::librapid::cos(x.value) * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual cos(const Dual &x) { + using Ret = decltype(::librapid::cos(x.value)); + return Dual(::librapid::cos(x.value), -::librapid::sin(x.value) * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual tan(const Dual &x) { + using Ret = decltype(::librapid::tan(x.value)); + auto cosX = ::librapid::cos(x.value); + return Dual(::librapid::tan(x.value), x.derivative / (cosX * cosX)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual asin(const Dual &x) { + using Ret = decltype(::librapid::asin(x.value)); + return Dual(::librapid::asin(x.value), + x.derivative / ::librapid::sqrt(1 - x.value * x.value)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual acos(const Dual &x) { + using Ret = decltype(::librapid::acos(x.value)); + return Dual(::librapid::acos(x.value), + -x.derivative / ::librapid::sqrt(1 - x.value * x.value)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual atan(const Dual &x) { + using Ret = decltype(::librapid::atan(x.value)); + return Dual(::librapid::atan(x.value), x.derivative / (1 + x.value * x.value)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual sinh(const Dual &x) { + using Ret = decltype(::librapid::sinh(x.value)); + return Dual(::librapid::sinh(x.value), ::librapid::cosh(x.value) * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual cosh(const Dual &x) { + using Ret = decltype(::librapid::cosh(x.value)); + return Dual(::librapid::cosh(x.value), ::librapid::sinh(x.value) * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual tanh(const Dual &x) { + 
using Ret = decltype(::librapid::tanh(x.value)); + auto coshX = ::librapid::cosh(x.value); + return Dual(::librapid::tanh(x.value), x.derivative / (coshX * coshX)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual asinh(const Dual &x) { + using Ret = decltype(::librapid::asinh(x.value)); + return Dual(::librapid::asinh(x.value), + x.derivative / ::librapid::sqrt(1 + x.value * x.value)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual acosh(const Dual &x) { + using Ret = decltype(::librapid::acosh(x.value)); + return Dual(::librapid::acosh(x.value), + x.derivative / ::librapid::sqrt(x.value * x.value - 1)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual atanh(const Dual &x) { + using Ret = decltype(::librapid::atanh(x.value)); + return Dual(::librapid::atanh(x.value), x.derivative / (1 - x.value * x.value)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual exp(const Dual &x) { + using Ret = decltype(::librapid::exp(x.value)); + auto expX = ::librapid::exp(x.value); + return Dual(expX, expX * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual exp2(const Dual &x) { + using Ret = decltype(::librapid::exp2(x.value)); + auto exp2X = ::librapid::exp2(x.value); + return Dual(exp2X, exp2X * ::librapid::log(2) * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual exp10(const Dual &x) { + using Ret = decltype(::librapid::exp2(x.value)); + auto exp2X = ::librapid::exp10(x.value); + return Dual(exp2X, exp2X * ::librapid::log(10) * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual log(const Dual &x) { + using Ret = decltype(::librapid::log(x.value)); + return Dual(::librapid::log(x.value), x.derivative / x.value); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual log10(const Dual &x) { + using Ret = decltype(::librapid::log10(x.value)); + return Dual(::librapid::log10(x.value), + x.derivative / (x.value * 
::librapid::log(10))); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual log2(const Dual &x) { + using Ret = decltype(::librapid::log2(x.value)); + return Dual(::librapid::log2(x.value), x.derivative / (x.value * ::librapid::log(2))); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual sqrt(const Dual &x) { + using Ret = decltype(::librapid::sqrt(x.value)); + return Dual(::librapid::sqrt(x.value), x.derivative / (2 * ::librapid::sqrt(x.value))); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual cbrt(const Dual &x) { + using Ret = decltype(::librapid::cbrt(x.value)); + return Dual(::librapid::cbrt(x.value), + x.derivative / (3 * ::librapid::cbrt(x.value * x.value))); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual pow(const Dual &x, const V &y) { + using Ret = decltype(::librapid::pow(x.value, y)); + return Dual(::librapid::pow(x.value, y), + y * ::librapid::pow(x.value, y - 1) * x.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual pow(const V &x, const Dual &y) { + using Ret = decltype(::librapid::pow(x, y.value)); + return Dual(::librapid::pow(x, y.value), + ::librapid::log(x) * ::librapid::pow(x, y.value) * y.derivative); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual pow(const Dual &x, const Dual &y) { + using Ret = decltype(::librapid::pow(x.value, y.value)); + return Dual( + ::librapid::pow(x.value, y.value), + ::librapid::pow(x.value, y.value) * + (y.derivative * ::librapid::log(x.value) + y.value * x.derivative / x.value)); + } #if !defined(LIBRAPID_IN_JITIFY) - namespace typetraits { - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; - using Scalar = T; - using Packet = std::false_type; // Dual::Packet>; - static constexpr int64_t packetWidth = - 0; // TypeInfo::Scalar>::packetWidth; - using Backend = backend::CPU; - - static constexpr char name[] = "Dual_T"; - - static constexpr bool 
supportsArithmetic = TypeInfo::supportsArithmetic; - static constexpr bool supportsLogical = TypeInfo::supportsLogical; - static constexpr bool supportsBinary = TypeInfo::supportsBinary; - static constexpr bool allowVectorisation = false; // TypeInfo::allowVectorisation; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; - static constexpr int64_t cudaPacketWidth = 1; -# endif // LIBRAPID_HAS_CUDA - - static constexpr bool canAlign = TypeInfo::canAlign; - static constexpr int64_t canMemcpy = TypeInfo::canMemcpy; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; - using Scalar = float; - using Packet = std::false_type; // Dual::Packet>; - static constexpr int64_t packetWidth = - 0; // TypeInfo::Scalar>::packetWidth; - using Backend = backend::CPU; - - static constexpr char name[] = "Dual_float"; - - static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; - static constexpr bool supportsLogical = TypeInfo::supportsLogical; - static constexpr bool supportsBinary = TypeInfo::supportsBinary; - static constexpr bool allowVectorisation = - false; // TypeInfo::allowVectorisation; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; - static constexpr int64_t cudaPacketWidth = 1; -# endif // LIBRAPID_HAS_CUDA - - static constexpr bool canAlign = TypeInfo::canAlign; - static constexpr int64_t canMemcpy = 
TypeInfo::canMemcpy; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - } // namespace typetraits + namespace typetraits { + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; + using Scalar = T; + using Packet = std::false_type; // Dual::Packet>; + static constexpr int64_t packetWidth = + 0; // TypeInfo::Scalar>::packetWidth; + using Backend = backend::CPU; + + static constexpr char name[] = "Dual_T"; + + static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; + static constexpr bool supportsLogical = TypeInfo::supportsLogical; + static constexpr bool supportsBinary = TypeInfo::supportsBinary; + static constexpr bool allowVectorisation = false; // TypeInfo::allowVectorisation; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; + static constexpr int64_t cudaPacketWidth = 1; +# endif // LIBRAPID_HAS_CUDA + + static constexpr bool canAlign = TypeInfo::canAlign; + static constexpr int64_t canMemcpy = TypeInfo::canMemcpy; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return 
NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Dual; + using Scalar = float; + using Packet = std::false_type; // Dual::Packet>; + static constexpr int64_t packetWidth = + 0; // TypeInfo::Scalar>::packetWidth; + using Backend = backend::CPU; + + static constexpr char name[] = "Dual_float"; + + static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; + static constexpr bool supportsLogical = TypeInfo::supportsLogical; + static constexpr bool supportsBinary = TypeInfo::supportsBinary; + static constexpr bool allowVectorisation = + false; // TypeInfo::allowVectorisation; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; + static constexpr int64_t cudaPacketWidth = 1; +# endif // LIBRAPID_HAS_CUDA + + static constexpr bool canAlign = TypeInfo::canAlign; + static constexpr int64_t canMemcpy = TypeInfo::canMemcpy; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + } // namespace typetraits #endif // !LIBRAPID_IN_JITIFY #if !defined(LIBRAPID_IN_JITIFY) @@ -461,12 +461,12 @@ namespace librapid { #if defined(LIBRAPID_HAS_CUDA) namespace jitify::reflection::detail { - template - struct type_reflection<::librapid::Dual> { - inline static std::string name() { - return fmt::format("Dual<{}>", type_reflection::name()); - } - }; + template + struct type_reflection<::librapid::Dual> { + inline static std::string name() { + return fmt::format("Dual<{}>", type_reflection::name()); 
+ } + }; } // namespace jitify::reflection::detail #endif // LIBRAPID_HAS_CUDA diff --git a/librapid/include/librapid/core/config.hpp b/librapid/include/librapid/core/config.hpp index 67d3580b..fcd35365 100644 --- a/librapid/include/librapid/core/config.hpp +++ b/librapid/include/librapid/core/config.hpp @@ -11,248 +11,248 @@ // Detect Release vs Debug builds #if !defined(NDEBUG) -# define LIBRAPID_DEBUG -# define LIBRAPID_RELEASE_NOEXCEPT +# define LIBRAPID_DEBUG +# define LIBRAPID_RELEASE_NOEXCEPT #else -# define LIBRAPID_RELEASE -# define LIBRAPID_RELEASE_NOEXCEPT noexcept +# define LIBRAPID_RELEASE +# define LIBRAPID_RELEASE_NOEXCEPT noexcept #endif // Detect the operating system #if defined(_WIN32) -# define LIBRAPID_WINDOWS // Windows -# define LIBRAPID_OS_NAME "windows" +# define LIBRAPID_WINDOWS // Windows +# define LIBRAPID_OS_NAME "windows" #elif defined(_WIN64) -# define LIBRAPID_WINDOWS // Windows -# define LIBRAPID_OS_NAME "windows" +# define LIBRAPID_WINDOWS // Windows +# define LIBRAPID_OS_NAME "windows" #elif defined(__CYGWIN__) && !defined(_WIN32) -# define LIBRAPID_WINDOWS // Windows (Cygwin POSIX under Microsoft Window) -# define LIBRAPID_OS_NAME "windows" +# define LIBRAPID_WINDOWS // Windows (Cygwin POSIX under Microsoft Window) +# define LIBRAPID_OS_NAME "windows" #elif defined(__ANDROID__) -# define LIBRAPID_ANDROID // Android (implies Linux, so it must come first) -# define LIBRAPID_OS_NAME "android" +# define LIBRAPID_ANDROID // Android (implies Linux, so it must come first) +# define LIBRAPID_OS_NAME "android" #elif defined(__linux__) -# define LIBRAPID_LINUX // Debian, Ubuntu, Gentoo, Fedora, openSUSE, RedHat, Centos and other -# define LIBRAPID_UNIX -# define LIBRAPID_OS_NAME "linux" +# define LIBRAPID_LINUX // Debian, Ubuntu, Gentoo, Fedora, openSUSE, RedHat, Centos and other +# define LIBRAPID_UNIX +# define LIBRAPID_OS_NAME "linux" #elif defined(__unix__) || !defined(__APPLE__) && defined(__MACH__) -# include -# if defined(BSD) -# 
define LIBRAPID_BSD // FreeBSD, NetBSD, OpenBSD, DragonFly BSD -# define LIBRAPID_UNIX -# define LIBRAPID_OS_NAME "bsd" -# endif +# include +# if defined(BSD) +# define LIBRAPID_BSD // FreeBSD, NetBSD, OpenBSD, DragonFly BSD +# define LIBRAPID_UNIX +# define LIBRAPID_OS_NAME "bsd" +# endif #elif defined(__hpux) -# define LIBRAPID_HP_UX // HP-UX -# define LIBRAPID_OS_NAME "hp-ux" +# define LIBRAPID_HP_UX // HP-UX +# define LIBRAPID_OS_NAME "hp-ux" #elif defined(_AIX) -# define LIBRAPID_AIX // IBM AIX -# define LIBRAPID_OS_NAME "aix" +# define LIBRAPID_AIX // IBM AIX +# define LIBRAPID_OS_NAME "aix" #elif defined(__APPLE__) && defined(__MACH__) // Apple OSX and iOS (Darwin) -# define LIBRAPID_APPLE -# define LIBRAPID_UNIX -# include -# if TARGET_IPHONE_SIMULATOR == 1 -# define LIBRAPID_IOS // Apple iOS -# define LIBRAPID_OS_NAME "ios" -# elif TARGET_OS_IPHONE == 1 -# define LIBRAPID_IOS // Apple iOS -# define LIBRAPID_OS_NAME "ios" -# elif TARGET_OS_MAC == 1 -# define LIBRAPID_OSX // Apple OSX -# define LIBRAPID_OS_NAME "osx" -# endif +# define LIBRAPID_APPLE +# define LIBRAPID_UNIX +# include +# if TARGET_IPHONE_SIMULATOR == 1 +# define LIBRAPID_IOS // Apple iOS +# define LIBRAPID_OS_NAME "ios" +# elif TARGET_OS_IPHONE == 1 +# define LIBRAPID_IOS // Apple iOS +# define LIBRAPID_OS_NAME "ios" +# elif TARGET_OS_MAC == 1 +# define LIBRAPID_OSX // Apple OSX +# define LIBRAPID_OS_NAME "osx" +# endif #elif defined(__sun) && defined(__SVR4) -# define LIBRAPID_SOLARIS // Oracle Solaris, Open Indiana -# define LIBRAPID_OS_NAME "solaris" +# define LIBRAPID_SOLARIS // Oracle Solaris, Open Indiana +# define LIBRAPID_OS_NAME "solaris" #else -# define LIBRAPID_UNKNOWN -# define LIBRAPID_OS_NAME "unknown" +# define LIBRAPID_UNKNOWN +# define LIBRAPID_OS_NAME "unknown" #endif // Compiler information #if defined(__GNUC__) -# define LIBRAPID_GNU -# define LIBRAPID_COMPILER_NAME "GNU C/C++ Compiler" +# define LIBRAPID_GNU +# define LIBRAPID_COMPILER_NAME "GNU C/C++ Compiler" #endif 
#if defined(__MINGW32__) -# define LIBRAPID_MINGW -# define LIBRAPID_COMPILER_NAME "Mingw or GNU C/C++ Compiler ported for Windows NT" +# define LIBRAPID_MINGW +# define LIBRAPID_COMPILER_NAME "Mingw or GNU C/C++ Compiler ported for Windows NT" #endif #if defined(__MINGW64__) -# define LIBRAPID_MINGW -# define LIBRAPID_COMPILER_NAME \ - "Mingw or GNU C/C++ Compiler ported for Windows NT - 64 bits only" +# define LIBRAPID_MINGW +# define LIBRAPID_COMPILER_NAME \ + "Mingw or GNU C/C++ Compiler ported for Windows NT - 64 bits only" #endif #if defined(__GFORTRAN__) -# define LIBRAPID_FORTRAN -# define LIBRAPID_COMPILER_NAME "Fortran / GNU Fortran Compiler" +# define LIBRAPID_FORTRAN +# define LIBRAPID_COMPILER_NAME "Fortran / GNU Fortran Compiler" #endif #if defined(__clang__) && !defined(_MSC_VER) -# define LIBRAPID_CLANG -# define LIBRAPID_COMPILER_NAME "Clang / LLVM Compiler" +# define LIBRAPID_CLANG +# define LIBRAPID_COMPILER_NAME "Clang / LLVM Compiler" #endif #if defined(_MSC_VER) -# define LIBRAPID_MSVC -# define LIBRAPID_COMPILER_NAME "Microsoft Visual Studio Compiler MSVC" +# define LIBRAPID_MSVC +# define LIBRAPID_COMPILER_NAME "Microsoft Visual Studio Compiler MSVC" #endif #if defined(_MANAGED) || defined(__cplusplus_cli) -# define LIBRAPID_DOTNET -# define LIBRAPID_COMPILER_NAME "Compilation to C++/CLI .NET (CLR) bytecode" +# define LIBRAPID_DOTNET +# define LIBRAPID_COMPILER_NAME "Compilation to C++/CLI .NET (CLR) bytecode" #endif #if defined(__INTEL_COMPILER) -# define LIBRAPID_INTEL -# define LIBRAPID_COMPILER_NAME "Intel C/C++ Compiler" +# define LIBRAPID_INTEL +# define LIBRAPID_COMPILER_NAME "Intel C/C++ Compiler" #endif #if defined(__PGI) || defined(__PGIC__) -# define LIBRAPID_PORTLAND -# define LIBRAPID_COMPILER_NAME "Portland Group C/C++ Compiler" +# define LIBRAPID_PORTLAND +# define LIBRAPID_COMPILER_NAME "Portland Group C/C++ Compiler" #endif #if defined(__BORLANDC__) -# define LIBRAPID_BORLAND -# define LIBRAPID_COMPILER_NAME "Borland C++ 
Compiler" +# define LIBRAPID_BORLAND +# define LIBRAPID_COMPILER_NAME "Borland C++ Compiler" #endif #if defined(__EMSCRIPTEN__) -# define LIBRAPID_EMSCRIPTEN -# define LIBRAPID_COMPILER_NAME "emscripten (asm.js - web assembly)" +# define LIBRAPID_EMSCRIPTEN +# define LIBRAPID_COMPILER_NAME "emscripten (asm.js - web assembly)" #endif #if defined(__asmjs__) -# define LIBRAPID_ASMJS -# define LIBRAPID_COMPILER_NAME "asm.js" +# define LIBRAPID_ASMJS +# define LIBRAPID_COMPILER_NAME "asm.js" #endif #if defined(__wasm__) -# define LIBRAPID_WASM -# define LIBRAPID_COMPILER_NAME "WebAssembly" +# define LIBRAPID_WASM +# define LIBRAPID_COMPILER_NAME "WebAssembly" #endif #if defined(__NVCC__) -# define LIBRAPID_NVCC -# define LIBRAPID_COMPILER_NAME "NVIDIA NVCC CUDA Compiler" +# define LIBRAPID_NVCC +# define LIBRAPID_COMPILER_NAME "NVIDIA NVCC CUDA Compiler" #endif #if defined(__CLING__) -# define LIBRAPID_CLING -# define LIBRAPID_COMPILER_NAME "CERN's ROOT Cling C++ Interactive Shell" +# define LIBRAPID_CLING +# define LIBRAPID_COMPILER_NAME "CERN's ROOT Cling C++ Interactive Shell" #endif // Instruction sets #define AVX512_2 10 -#define AVX512 9 -#define AVX2 8 -#define AVX 7 -#define SSE4_2 6 -#define SSE4_1 5 -#define SSSE3 4 -#define SSE3 3 -#define SSE2 2 -#define SSE 1 -#define NOARCH 0 +#define AVX512 9 +#define AVX2 8 +#define AVX 7 +#define SSE4_2 6 +#define SSE4_1 5 +#define SSSE3 4 +#define SSE3 3 +#define SSE2 2 +#define SSE 1 +#define NOARCH 0 // Instruction set detection #if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) -# define LIBRAPID_AVX512 -# define LIBRAPID_ARCH AVX512_2 -# define LIBRAPID_ARCH_NAME "AVX512" -# define LIBRAPID_DEFAULT_MEM_ALIGN 256 +# define LIBRAPID_AVX512 +# define LIBRAPID_ARCH AVX512_2 +# define LIBRAPID_ARCH_NAME "AVX512" +# define LIBRAPID_DEFAULT_MEM_ALIGN 256 #elif defined(__AVX512F__) || defined(__AVX512__) -# define LIBRAPID_AVX512 -# define LIBRAPID_ARCH AVX512 -# define LIBRAPID_ARCH_NAME "AVX512" 
-# define LIBRAPID_DEFAULT_MEM_ALIGN 256 +# define LIBRAPID_AVX512 +# define LIBRAPID_ARCH AVX512 +# define LIBRAPID_ARCH_NAME "AVX512" +# define LIBRAPID_DEFAULT_MEM_ALIGN 256 #elif defined(__AVX2__) -# define LIBRAPID_AVX2 -# define LIBRAPID_ARCH AVX2 -# define LIBRAPID_ARCH_NAME "AVX2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 128 +# define LIBRAPID_AVX2 +# define LIBRAPID_ARCH AVX2 +# define LIBRAPID_ARCH_NAME "AVX2" +# define LIBRAPID_DEFAULT_MEM_ALIGN 128 #elif defined(__AVX__) -# define LIBRAPID_AVX -# define LIBRAPID_ARCH AVX -# define LIBRAPID_ARCH_NAME "AVX" -# define LIBRAPID_DEFAULT_MEM_ALIGN 128 +# define LIBRAPID_AVX +# define LIBRAPID_ARCH AVX +# define LIBRAPID_ARCH_NAME "AVX" +# define LIBRAPID_DEFAULT_MEM_ALIGN 128 #elif defined(__SSE4_2__) -# define LIBRAPID_SSE42 -# define LIBRAPID_ARCH SSE4_2 -# define LIBRAPID_ARCH_NAME "SSE4.2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# define LIBRAPID_SSE42 +# define LIBRAPID_ARCH SSE4_2 +# define LIBRAPID_ARCH_NAME "SSE4.2" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE4_1__) -# define LIBRAPID_SSE41 -# define LIBRAPID_ARCH SSE4_1 -# define LIBRAPID_ARCH_NAME "SSE4.1" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# define LIBRAPID_SSE41 +# define LIBRAPID_ARCH SSE4_1 +# define LIBRAPID_ARCH_NAME "SSE4.1" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSSE3__) -# define LIBRAPID_SSSE3 -# define LIBRAPID_ARCH SSSE3 -# define LIBRAPID_ARCH_NAME "SSSE3" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# define LIBRAPID_SSSE3 +# define LIBRAPID_ARCH SSSE3 +# define LIBRAPID_ARCH_NAME "SSSE3" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE3__) -# define LIBRAPID_SSE3 -# define LIBRAPID_ARCH SSE3 -# define LIBRAPID_ARCH_NAME "SSE3" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# define LIBRAPID_SSE3 +# define LIBRAPID_ARCH SSE3 +# define LIBRAPID_ARCH_NAME "SSE3" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE2__) || defined(__x86_64__) -# define LIBRAPID_SSE2 -# define LIBRAPID_ARCH SSE2 -# 
define LIBRAPID_ARCH_NAME "SSE2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# define LIBRAPID_SSE2 +# define LIBRAPID_ARCH SSE2 +# define LIBRAPID_ARCH_NAME "SSE2" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(__SSE__) -# define LIBRAPID_SSE -# define LIBRAPID_ARCH SSE -# define LIBRAPID_ARCH_NAME "SSE" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# define LIBRAPID_SSE +# define LIBRAPID_ARCH SSE +# define LIBRAPID_ARCH_NAME "SSE" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 #elif defined(_M_IX86_FP) // Defined in MS compiler. 1: SSE, 2: SSE2 -# if _M_IX86_FP == 1 -# define LIBRAPID_SSE -# define LIBRAPID_ARCH SSE -# define LIBRAPID_ARCH_NAME "SSE" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 -# elif _M_IX86_FP == 2 -# define LIBRAPID_SSE2 -# define LIBRAPID_ARCH SSE2 -# define LIBRAPID_ARCH_NAME "SSE2" -# define LIBRAPID_DEFAULT_MEM_ALIGN 64 -# endif // _M_IX86_FP +# if _M_IX86_FP == 1 +# define LIBRAPID_SSE +# define LIBRAPID_ARCH SSE +# define LIBRAPID_ARCH_NAME "SSE" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# elif _M_IX86_FP == 2 +# define LIBRAPID_SSE2 +# define LIBRAPID_ARCH SSE2 +# define LIBRAPID_ARCH_NAME "SSE2" +# define LIBRAPID_DEFAULT_MEM_ALIGN 64 +# endif // _M_IX86_FP #else -# define LIBRAPID_ARCH 0 -# define LIBRAPID_ARCH_NAME "None" -# define LIBRAPID_DEFAULT_MEM_ALIGN 32 +# define LIBRAPID_ARCH 0 +# define LIBRAPID_ARCH_NAME "None" +# define LIBRAPID_DEFAULT_MEM_ALIGN 32 #endif // Instruction set detection // Check for 32bit vs 64bit #if _WIN32 || _WIN64 // Check windows -# if _WIN64 -# define LIBRAPID_64BIT -# else -# define LIBRAPID_32BIT -# endif +# if _WIN64 +# define LIBRAPID_64BIT +# else +# define LIBRAPID_32BIT +# endif #elif __GNUC__ -# if __x86_64__ || __ppc64__ -# define LIBRAPID_64BIT -# else -# define LIBRAPID_32BIT -# endif +# if __x86_64__ || __ppc64__ +# define LIBRAPID_64BIT +# else +# define LIBRAPID_32BIT +# endif #else -# LIBRAPID_64BIT // Default to 64bit +# LIBRAPID_64BIT // Default to 64bit #endif // Branch prediction hints #ifdef 
LIBRAPID_20 -# define LIBRAPID_LIKELY [[likely]] -# define LIBRAPID_UNLIKELY [[unlikely]] +# define LIBRAPID_LIKELY [[likely]] +# define LIBRAPID_UNLIKELY [[unlikely]] #else -# define LIBRAPID_LIKELY -# define LIBRAPID_UNLIKELY +# define LIBRAPID_LIKELY +# define LIBRAPID_UNLIKELY #endif // [[nodiscard]] macro @@ -260,28 +260,28 @@ // Nicer FILENAME macro #if defined(FILENAME) -# warning \ - "The macro 'FILENAME' is already defined. LibRapid's logging system might not function correctly as a result" +# warning \ + "The macro 'FILENAME' is already defined. LibRapid's logging system might not function correctly as a result" #else -# ifdef LIBRAPID_OS_WINDOWS -# define FILENAME (strrchr(__FILE__, '\\') ? strrchr(__FILE__, '\\') + 1 : __FILE__) -# else -# define FILENAME (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) -# endif +# ifdef LIBRAPID_OS_WINDOWS +# define FILENAME (strrchr(__FILE__, '\\') ? strrchr(__FILE__, '\\') + 1 : __FILE__) +# else +# define FILENAME (strrchr(__FILE__, '/') ? strrchr(__FILE__, '/') + 1 : __FILE__) +# endif #endif // Nicer FUNCTION macro #if defined(FUNCTION) -# warning \ - "The macro 'FUNCTION' is already defined. LibRapid's logging system might not function correctly as a result" +# warning \ + "The macro 'FUNCTION' is already defined. 
LibRapid's logging system might not function correctly as a result" #else -# if defined(LIBRAPID_MSVC) -# define FUNCTION __FUNCSIG__ -# elif defined(LIBRAPID_GNU) || defined(LIBRAPID_CLANG) || defined(LIBRAPID_CLING) -# define FUNCTION __PRETTY_FUNCTION__ -# else -# define FUNCTION "Function Signature Unknown" -# endif +# if defined(LIBRAPID_MSVC) +# define FUNCTION __FUNCSIG__ +# elif defined(LIBRAPID_GNU) || defined(LIBRAPID_CLANG) || defined(LIBRAPID_CLING) +# define FUNCTION __PRETTY_FUNCTION__ +# else +# define FUNCTION "Function Signature Unknown" +# endif #endif // STRINGIFY @@ -290,90 +290,90 @@ // Assertions, warnings and errors #if defined(LIBRAPID_DEBUG) && !defined(LIBRAPID_ENABLE_ASSERT) -# define LIBRAPID_ENABLE_ASSERT +# define LIBRAPID_ENABLE_ASSERT #endif // LIBRAPID_DEBUG && !LIBRAPID_ASSERT // Warn the user the first time this is called, but never again #if defined(LIBRAPID_ASSERT) -# define LIBRAPID_WARN_ONCE(msg, ...) \ - do { \ - static bool _alerted = false; \ - if (!_alerted) { \ - LIBRAPID_WARN(msg, __VA_ARGS__); \ - _alerted = true; \ - } \ - } while (false) +# define LIBRAPID_WARN_ONCE(msg, ...) 
\ + do { \ + static bool _alerted = false; \ + if (!_alerted) { \ + LIBRAPID_WARN(msg, __VA_ARGS__); \ + _alerted = true; \ + } \ + } while (false) #endif // LIBRAPID_ASSERT #if defined(LIBRAPID_ENABLE_ASSERT) || defined(LIBRAPID_DEBUG) -# define LIBRAPID_NOT_IMPLEMENTED LIBRAPID_ASSERT(false, "Not implemented"); +# define LIBRAPID_NOT_IMPLEMENTED LIBRAPID_ASSERT(false, "Not implemented"); #else -# define LIBRAPID_NOT_IMPLEMENTED throw std::runtime_error("Not implemented"); +# define LIBRAPID_NOT_IMPLEMENTED throw std::runtime_error("Not implemented"); #endif // Compiler-specific attributes #if defined(LIBRAPID_MSVC) -# include "msvcConfig.hpp" +# include "msvcConfig.hpp" #elif defined(LIBRAPID_GNU) || defined(LIBRAPID_CLANG) || defined(LIBRAPID_CLING) -# include "gnuConfig.hpp" +# include "gnuConfig.hpp" #else -# include "genericConfig.hpp" +# include "genericConfig.hpp" #endif #if defined(LIBRAPID_HAS_CUDA) -# include "cudaConfig.hpp" +# include "cudaConfig.hpp" #else namespace librapid::typetraits { - template - struct IsCudaStorage : std::false_type {}; + template + struct IsCudaStorage : std::false_type {}; } // namespace librapid::typetraits #endif #if defined(LIBRAPID_HAS_OPENCL) -# include "openclConfig.hpp" +# include "openclConfig.hpp" #else namespace librapid::typetraits { - template - struct IsOpenCLStorage : std::false_type {}; + template + struct IsOpenCLStorage : std::false_type {}; } // namespace librapid::typetraits #endif namespace librapid::backend { - // Use the CPU for computation (default) - struct CPU {}; + // Use the CPU for computation (default) + struct CPU {}; - // Use the GPU via CUDA - struct CUDA {}; + // Use the GPU via CUDA + struct CUDA {}; - // Use OpenCL - struct OpenCL {}; + // Use OpenCL + struct OpenCL {}; - // Use the fastest device for computation + // Use the fastest device for computation #if defined(LIBRAPID_HAS_CUDA) - using Fastest = CUDA; + using Fastest = CUDA; #elif defined(LIBRAPID_HAS_OPENCL) - using Fastest = 
OpenCL; + using Fastest = OpenCL; #else - using Fastest = CPU; + using Fastest = CPU; #endif - // GPU if available, CPU otherwise + // GPU if available, CPU otherwise #if defined(LIBRAPID_HAS_CUDA) - using CUDAIfAvailable = CUDA; + using CUDAIfAvailable = CUDA; #else - using CUDAIfAvailable = CPU; + using CUDAIfAvailable = CPU; #endif - // OpenCL if available, CPU otherwise + // OpenCL if available, CPU otherwise #if defined(LIBRAPID_HAS_OPENCL) - using OpenCLIfAvailable = OpenCL; + using OpenCLIfAvailable = OpenCL; #else - using OpenCLIfAvailable = CPU; + using OpenCLIfAvailable = CPU; #endif } // namespace librapid::backend #ifndef LIBRAPID_MAX_ARRAY_DIMS -# define LIBRAPID_MAX_ARRAY_DIMS 32 +# define LIBRAPID_MAX_ARRAY_DIMS 32 #endif // LIBRAPID_MAX_ARRAY_DIMS // Code to be run *before* main() diff --git a/librapid/include/librapid/core/core.hpp b/librapid/include/librapid/core/core.hpp index 7f6286a2..23edcc92 100644 --- a/librapid/include/librapid/core/core.hpp +++ b/librapid/include/librapid/core/core.hpp @@ -19,30 +19,30 @@ // Fourier Transform #if defined(LIBRAPID_HAS_FFTW) && !defined(LIBRAPID_HAS_CUDA) // If CUDA is enabled, we use cuFFT -# include +# include #endif // LIBRAPID_HAS_CUDA #if defined(LIBRAPID_MSVC) -#pragma warning(push) -#pragma warning(disable : 4324) -#pragma warning(disable : 4458) -#pragma warning(disable : 4456) +# pragma warning(push) +# pragma warning(disable : 4324) +# pragma warning(disable : 4458) +# pragma warning(disable : 4456) #endif // LIBRAPID_MSVC #include #if defined(LIBRAPID_MSVC) -#pragma warning(pop) +# pragma warning(pop) #endif // LIBRAPID_MSVC #if defined(LIBRAPID_HAS_OPENCL) -# include "../opencl/openclErrorIdentifier.hpp" -# include "../opencl/openclConfigure.hpp" -# include "../opencl/openclKernelProcessor.hpp" +# include "../opencl/openclErrorIdentifier.hpp" +# include "../opencl/openclConfigure.hpp" +# include "../opencl/openclKernelProcessor.hpp" #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) -# 
include "../cuda/cudaKernelProcesor.hpp" +# include "../cuda/cudaKernelProcesor.hpp" #endif // LIBRAPID_HAS_CUDA #endif // LIBRAPID_CORE \ No newline at end of file diff --git a/librapid/include/librapid/core/cudaConfig.hpp b/librapid/include/librapid/core/cudaConfig.hpp index afa8a3e5..28ac08ca 100644 --- a/librapid/include/librapid/core/cudaConfig.hpp +++ b/librapid/include/librapid/core/cudaConfig.hpp @@ -5,29 +5,29 @@ #ifdef LIBRAPID_HAS_CUDA // Under MSVC, supress a few warnings -# ifdef _MSC_VER -# pragma warning(push) -# pragma warning(disable : 4505) // unreferenced local function has been removed -# endif - -# define CUDA_NO_HALF // Ensure the cuda_helpers "half" data type is not defined - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# ifdef _MSC_VER -# pragma warning(pop) -# endif - -# include "../vendor/jitify/jitify.hpp" +# ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable : 4505) // unreferenced local function has been removed +# endif + +# define CUDA_NO_HALF // Ensure the cuda_helpers "half" data type is not defined + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# ifdef _MSC_VER +# pragma warning(pop) +# endif + +# include "../vendor/jitify/jitify.hpp" // cuBLAS API errors const char *getCublasErrorEnum_(cublasStatus_t error); @@ -36,54 +36,54 @@ const char *getCublasErrorEnum_(cublasStatus_t error); // cuBLAS ERROR CHECK // //********************// -# if !defined(cublasSafeCall) -# define cublasSafeCall(err) \ - LIBRAPID_ASSERT_ALWAYS( \ - (err) == CUBLAS_STATUS_SUCCESS, "cuBLAS error: {}", getCublasErrorEnum_(err)) -# endif +# if !defined(cublasSafeCall) +# define cublasSafeCall(err) \ + LIBRAPID_ASSERT_ALWAYS( \ + (err) == CUBLAS_STATUS_SUCCESS, "cuBLAS error: {}", getCublasErrorEnum_(err)) +# endif //********************// // CUDA ERROR CHECK // //********************// -# if 
defined(LIBRAPID_ENABLE_ASSERT) -# define cudaSafeCall(call) \ - LIBRAPID_ASSERT(!(call), "CUDA Assertion Failed: {}", cudaGetErrorString(call)) - -# define jitifyCall(call) \ - do { \ - if ((call) != CUDA_SUCCESS) { \ - const char *str; \ - cuGetErrorName(call, &str); \ - throw std::runtime_error(std::string("CUDA JIT failed: ") + str); \ - } \ - } while (0) -# else -# define cudaSafeCall(call) (call) -# define jitifyCall(call) (call) -# endif - -# ifdef _MSC_VER -# pragma warning(default : 4996) -# endif - -# include "../cuda/helper_cuda.h" -# include "../cuda/helper_functions.h" - -# define CONCAT_IMPL(x, y) x##y -# define CONCAT(x, y) CONCAT_IMPL(x, y) - -# if LIBRAPID_CUDA_FLOAT_VECTOR_WIDTH > 1 -# define CUDA_FLOAT_VECTOR_TYPE CONCAT(jitify::float, LIBRAPID_CUDA_FLOAT_VECTOR_WIDTH) -# else -# define CUDA_FLOAT_VECTOR_TYPE float -# endif - -# if LIBRAPID_CUDA_DOUBLE_VECTOR_WIDTH > 1 -# define CUDA_DOUBLE_VECTOR_TYPE CONCAT(jitify::double, LIBRAPID_CUDA_DOUBLE_VECTOR_WIDTH) -# else -# define CUDA_DOUBLE_VECTOR_TYPE double -# endif +# if defined(LIBRAPID_ENABLE_ASSERT) +# define cudaSafeCall(call) \ + LIBRAPID_ASSERT(!(call), "CUDA Assertion Failed: {}", cudaGetErrorString(call)) + +# define jitifyCall(call) \ + do { \ + if ((call) != CUDA_SUCCESS) { \ + const char *str; \ + cuGetErrorName(call, &str); \ + throw std::runtime_error(std::string("CUDA JIT failed: ") + str); \ + } \ + } while (0) +# else +# define cudaSafeCall(call) (call) +# define jitifyCall(call) (call) +# endif + +# ifdef _MSC_VER +# pragma warning(default : 4996) +# endif + +# include "../cuda/helper_cuda.h" +# include "../cuda/helper_functions.h" + +# define CONCAT_IMPL(x, y) x##y +# define CONCAT(x, y) CONCAT_IMPL(x, y) + +# if LIBRAPID_CUDA_FLOAT_VECTOR_WIDTH > 1 +# define CUDA_FLOAT_VECTOR_TYPE CONCAT(jitify::float, LIBRAPID_CUDA_FLOAT_VECTOR_WIDTH) +# else +# define CUDA_FLOAT_VECTOR_TYPE float +# endif + +# if LIBRAPID_CUDA_DOUBLE_VECTOR_WIDTH > 1 +# define CUDA_DOUBLE_VECTOR_TYPE 
CONCAT(jitify::double, LIBRAPID_CUDA_DOUBLE_VECTOR_WIDTH) +# else +# define CUDA_DOUBLE_VECTOR_TYPE double +# endif #endif // LIBRAPID_HAS_CUDA diff --git a/librapid/include/librapid/core/debugTrap.hpp b/librapid/include/librapid/core/debugTrap.hpp index 8e9d2523..16b03f46 100644 --- a/librapid/include/librapid/core/debugTrap.hpp +++ b/librapid/include/librapid/core/debugTrap.hpp @@ -9,77 +9,77 @@ */ #if !defined(PSNIP_DEBUG_TRAP_H) -# define PSNIP_DEBUG_TRAP_H +# define PSNIP_DEBUG_TRAP_H -# if !defined(PSNIP_NDEBUG) && defined(NDEBUG) && !defined(PSNIP_DEBUG) && \ - !defined(LIBRAPID_ENABLE_ASSERTIONS) -# define PSNIP_NDEBUG 1 -# endif +# if !defined(PSNIP_NDEBUG) && defined(NDEBUG) && !defined(PSNIP_DEBUG) && \ + !defined(LIBRAPID_ENABLE_ASSERTIONS) +# define PSNIP_NDEBUG 1 +# endif -# if defined(__has_builtin) && !defined(__ibmxl__) -# if __has_builtin(__builtin_debugtrap) -# define psnip_trap() __builtin_debugtrap() -# elif __has_builtin(__debugbreak) -# define psnip_trap() __debugbreak() -# endif -# endif -# if !defined(psnip_trap) -# if defined(_MSC_VER) || defined(__INTEL_COMPILER) -# define psnip_trap() __debugbreak() -# elif defined(__ARMCC_VERSION) -# define psnip_trap() __breakpoint(42) -# elif defined(__ibmxl__) || defined(__xlC__) -# include -# define psnip_trap() __trap(42) -# elif defined(__DMC__) && defined(_M_IX86) +# if defined(__has_builtin) && !defined(__ibmxl__) +# if __has_builtin(__builtin_debugtrap) +# define psnip_trap() __builtin_debugtrap() +# elif __has_builtin(__debugbreak) +# define psnip_trap() __debugbreak() +# endif +# endif +# if !defined(psnip_trap) +# if defined(_MSC_VER) || defined(__INTEL_COMPILER) +# define psnip_trap() __debugbreak() +# elif defined(__ARMCC_VERSION) +# define psnip_trap() __breakpoint(42) +# elif defined(__ibmxl__) || defined(__xlC__) +# include +# define psnip_trap() __trap(42) +# elif defined(__DMC__) && defined(_M_IX86) static inline void psnip_trap(void) { __asm int 3h; } -# elif defined(__i386__) || 
defined(__x86_64__) +# elif defined(__i386__) || defined(__x86_64__) static inline void psnip_trap(void) { __asm__ __volatile__("int3"); } -# elif defined(__thumb__) +# elif defined(__thumb__) static inline void psnip_trap(void) { __asm__ __volatile__(".inst 0xde01"); } -# elif defined(__aarch64__) +# elif defined(__aarch64__) static inline void psnip_trap(void) { __asm__ __volatile__(".inst 0xd4200000"); } -# elif defined(__arm__) +# elif defined(__arm__) static inline void psnip_trap(void) { __asm__ __volatile__(".inst 0xe7f001f0"); } -# elif defined(__alpha__) && !defined(__osf__) +# elif defined(__alpha__) && !defined(__osf__) static inline void psnip_trap(void) { __asm__ __volatile__("bpt"); } -# elif defined(_54_) +# elif defined(_54_) static inline void psnip_trap(void) { __asm__ __volatile__("ESTOP"); } -# elif defined(_55_) +# elif defined(_55_) static inline void psnip_trap(void) { - __asm__ __volatile__(";\n .if (.MNEMONIC)\n ESTOP_1\n .else\n ESTOP_1()\n .endif\n NOP"); + __asm__ __volatile__(";\n .if (.MNEMONIC)\n ESTOP_1\n .else\n ESTOP_1()\n .endif\n NOP"); } -# elif defined(_64P_) +# elif defined(_64P_) static inline void psnip_trap(void) { __asm__ __volatile__("SWBP 0"); } -# elif defined(_6x_) +# elif defined(_6x_) static inline void psnip_trap(void) { __asm__ __volatile__("NOP\n .word 0x10000000"); } -# elif defined(__STDC_HOSTED__) && (__STDC_HOSTED__ == 0) && defined(__GNUC__) -# define psnip_trap() __builtin_trap() -# else -# include -# if defined(SIGTRAP) -# define psnip_trap() raise(SIGTRAP) -# else -# define psnip_trap() raise(SIGABRT) -# endif -# endif -# endif +# elif defined(__STDC_HOSTED__) && (__STDC_HOSTED__ == 0) && defined(__GNUC__) +# define psnip_trap() __builtin_trap() +# else +# include +# if defined(SIGTRAP) +# define psnip_trap() raise(SIGTRAP) +# else +# define psnip_trap() raise(SIGABRT) +# endif +# endif +# endif -# if defined(HEDLEY_LIKELY) -# define PSNIP_DBG_LIKELY(expr) HEDLEY_LIKELY(expr) -# elif defined(__GNUC__) && 
(__GNUC__ >= 3) -# define PSNIP_DBG_LIKELY(expr) __builtin_expect(!!(expr), 1) -# else -# define PSNIP_DBG_LIKELY(expr) (!!(expr)) -# endif +# if defined(HEDLEY_LIKELY) +# define PSNIP_DBG_LIKELY(expr) HEDLEY_LIKELY(expr) +# elif defined(__GNUC__) && (__GNUC__ >= 3) +# define PSNIP_DBG_LIKELY(expr) __builtin_expect(!!(expr), 1) +# else +# define PSNIP_DBG_LIKELY(expr) (!!(expr)) +# endif -# if !defined(PSNIP_NDEBUG) || (PSNIP_NDEBUG == 0) -# define psnip_dbg_assert(expr) \ - do { \ - if (!PSNIP_DBG_LIKELY(expr)) { psnip_trap(); } \ - } while (0) -# else -# define psnip_dbg_assert(expr) -# endif +# if !defined(PSNIP_NDEBUG) || (PSNIP_NDEBUG == 0) +# define psnip_dbg_assert(expr) \ + do { \ + if (!PSNIP_DBG_LIKELY(expr)) { psnip_trap(); } \ + } while (0) +# else +# define psnip_dbg_assert(expr) +# endif #endif /* !defined(PSNIP_DEBUG_TRAP_H) */ \ No newline at end of file diff --git a/librapid/include/librapid/core/forward.hpp b/librapid/include/librapid/core/forward.hpp index f7b73a49..5f945ce4 100644 --- a/librapid/include/librapid/core/forward.hpp +++ b/librapid/include/librapid/core/forward.hpp @@ -4,99 +4,97 @@ #ifndef LIBRAPID_DOXYGEN namespace librapid { - template - class Shape; - - template - class Stride; - - template - class Storage; - - template - class FixedStorage; - - template - class OpenCLStorage; - - template - class CudaStorage; - - namespace array { - template - class ArrayContainer; - } - - namespace detail { - /// \brief Identifies which type of function is being used - namespace descriptor { - struct Trivial {}; /// Operation is trivial and can be done with a vectorised loop - struct Transpose {}; /// Operation is a matrix/array transposition - struct Matmul {}; /// Operation is a matrix/array multiplication - struct Combined {}; /// Operation is a combination of the above - } // namespace descriptor - - template - class Function; - - template>::value, - int> = 0> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const 
detail::Function &function); - - template>::value, - int> = 0> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const detail::Function &function); - - template>::value, - int> = 0> - LIBRAPID_ALWAYS_INLINE void assignParallel( - array::ArrayContainer> &lhs, - const detail::Function &function); - - template>::value, - int> = 0> - LIBRAPID_ALWAYS_INLINE void assignParallel( - array::ArrayContainer> &lhs, - const detail::Function &function); - -# if defined(LIBRAPID_HAS_OPENCL) - template>::value, - int> = 0> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const detail::Function &function); - -# endif // LIBRAPID_HAS_CUDA - -# if defined(LIBRAPID_HAS_CUDA) - template>::value, - int> = 0> - LIBRAPID_ALWAYS_INLINE void - assign(array::ArrayContainer> &lhs, - const detail::Function &function); - -# endif // LIBRAPID_HAS_CUDA - } // namespace detail + template + class Shape; + + template + class Stride; + + template + class Storage; + + template + class FixedStorage; + + template + class OpenCLStorage; + + template + class CudaStorage; + + namespace array { + template + class ArrayContainer; + } + + namespace detail { + /// \brief Identifies which type of function is being used + namespace descriptor { + struct Trivial {}; /// Operation is trivial and can be done with a vectorised loop + struct Transpose {}; /// Operation is a matrix/array transposition + struct Matmul {}; /// Operation is a matrix/array multiplication + struct Combined {}; /// Operation is a combination of the above + } // namespace descriptor + + template + class Function; + + template>::value, + int> = 0> + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer> &lhs, + const detail::Function &function); + + template>::value, + int> = 0> + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer> &lhs, + const detail::Function &function); + + template>::value, + int> = 0> + LIBRAPID_ALWAYS_INLINE void + assignParallel(array::ArrayContainer> &lhs, + const 
detail::Function &function); + + template>::value, + int> = 0> + LIBRAPID_ALWAYS_INLINE void assignParallel( + array::ArrayContainer> &lhs, + const detail::Function &function); + +# if defined(LIBRAPID_HAS_OPENCL) + template>::value, + int> = 0> + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer> &lhs, + const detail::Function &function); + +# endif // LIBRAPID_HAS_CUDA + +# if defined(LIBRAPID_HAS_CUDA) + template>::value, + int> = 0> + LIBRAPID_ALWAYS_INLINE void + assign(array::ArrayContainer> &lhs, + const detail::Function &function); + +# endif // LIBRAPID_HAS_CUDA + } // namespace detail } // namespace librapid #endif // LIBRAPID_DOXYGEN diff --git a/librapid/include/librapid/core/genericConfig.hpp b/librapid/include/librapid/core/genericConfig.hpp index de9f2689..38003b20 100644 --- a/librapid/include/librapid/core/genericConfig.hpp +++ b/librapid/include/librapid/core/genericConfig.hpp @@ -1,168 +1,168 @@ #ifndef LIBRAPID_CORE_GNU_CONFIG_HPP #define LIBRAPID_CORE_GNU_CONFIG_HPP -#define LIBRAPID_INLINE inline +#define LIBRAPID_INLINE inline #define LIBRAPID_ALWAYS_INLINE inline #define LIBRAPID_ASSERT_ALWAYS(cond, msg, ...) 
\ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - if (!(cond)) { \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)), \ - (int)strlen(FILENAME), \ - (int)funcName.length(), \ - (int)strlen(#cond), \ - (int)strlen("ASSERTION FAILED")); \ - std::string formatted = fmt::format( \ - "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ - "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ - "{4:>{10}}]\n{5}\n", \ - "ASSERTION FAILED", \ - FILENAME, \ - funcName, \ - __LINE__, \ - #cond, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 14, \ - maxLen + 9, \ - maxLen + 5, \ - maxLen + 9, \ - maxLen + 4); \ - if (librapid::global::throwOnAssert) { \ - throw std::runtime_error(formatted); \ - } else { \ - fmt::print(fmt::fg(fmt::color::red), formatted); \ - psnip_trap(); \ - } \ - } \ - } while (0) + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + if (!(cond)) { \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)), \ + (int)strlen(FILENAME), \ + (int)funcName.length(), \ + (int)strlen(#cond), \ + (int)strlen("ASSERTION FAILED")); \ + std::string formatted = fmt::format( \ + "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ + "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ + "{4:>{10}}]\n{5}\n", \ + "ASSERTION FAILED", \ + FILENAME, \ + funcName, \ + __LINE__, \ + #cond, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 14, \ + maxLen + 9, \ + maxLen + 5, \ + maxLen + 9, \ + maxLen + 4); \ + if (librapid::global::throwOnAssert) { \ + throw std::runtime_error(formatted); \ + } else { \ + fmt::print(fmt::fg(fmt::color::red), formatted); \ + psnip_trap(); \ + } \ + } \ + } while (0) #if defined(LIBRAPID_ENABLE_ASSERT) -# define LIBRAPID_STATUS(msg, ...) 
\ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::green), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "STATUS", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - } while (0) +# define LIBRAPID_STATUS(msg, ...) \ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::green), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "STATUS", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + } while (0) -# define LIBRAPID_WARN(msg, ...) \ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::yellow), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "WARNING", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - } while (0) +# define LIBRAPID_WARN(msg, ...) 
\ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::yellow), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "WARNING", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + } while (0) -# define LIBRAPID_ERROR(msg, ...) \ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapiod::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::red), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "ERROR", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - if (librapid::global::throwOnAssert) { \ - throw std::runtime_error(formatted); \ - } else { \ - fmt::print(fmt::fg(fmt::color::red), formatted); \ - psnip_trap(); \ - } \ - } while (0) +# define LIBRAPID_ERROR(msg, ...) 
\ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapiod::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::red), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "ERROR", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + if (librapid::global::throwOnAssert) { \ + throw std::runtime_error(formatted); \ + } else { \ + fmt::print(fmt::fg(fmt::color::red), formatted); \ + psnip_trap(); \ + } \ + } while (0) -# define LIBRAPID_WASSERT(cond, msg, ...) \ - do { \ - if (!(cond)) { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen(#cond) + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::yellow), \ - "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ - "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ - "{4:>{10}}]\n{5}\n", \ - "WARN ASSERTION FAILED", \ - FILENAME, \ - funcName, \ - __LINE__, \ - #cond, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen + 0, \ - maxLen - 5); \ - } \ - } while (0) +# define LIBRAPID_WASSERT(cond, msg, ...) 
\ + do { \ + if (!(cond)) { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen(#cond) + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::yellow), \ + "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ + "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ + "{4:>{10}}]\n{5}\n", \ + "WARN ASSERTION FAILED", \ + FILENAME, \ + funcName, \ + __LINE__, \ + #cond, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen + 0, \ + maxLen - 5); \ + } \ + } while (0) -# define LIBRAPID_ASSERT(cond, msg, ...) \ - LIBRAPID_ASSERT_ALWAYS(cond, msg __VA_OPT__(, ) __VA_ARGS__) +# define LIBRAPID_ASSERT(cond, msg, ...) \ + LIBRAPID_ASSERT_ALWAYS(cond, msg __VA_OPT__(, ) __VA_ARGS__) #else -# define LIBRAPID_WARN_ONCE(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_STATUS(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_WARN(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_ERROR(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_LOG(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_WASSERT(cond, ...) \ - do { \ - } while (0) -# define LIBRAPID_ASSERT(cond, ...) \ - do { \ - } while (0) +# define LIBRAPID_WARN_ONCE(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_STATUS(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_WARN(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_ERROR(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_LOG(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_WASSERT(cond, ...) \ + do { \ + } while (0) +# define LIBRAPID_ASSERT(cond, ...) 
\ + do { \ + } while (0) #endif // LIBRAPID_ENABLE_ASSERT #define PURE_FUNCTION [[nodiscard]] constexpr diff --git a/librapid/include/librapid/core/global.hpp b/librapid/include/librapid/core/global.hpp index a5718ace..1c8329f9 100644 --- a/librapid/include/librapid/core/global.hpp +++ b/librapid/include/librapid/core/global.hpp @@ -7,82 +7,82 @@ */ namespace librapid { - namespace global { - // Should ASSERT functions error or throw exceptions? - extern bool throwOnAssert; + namespace global { + // Should ASSERT functions error or throw exceptions? + extern bool throwOnAssert; - /// Arrays with more elements than this will run with multithreaded implementations - extern size_t multithreadThreshold; + /// Arrays with more elements than this will run with multithreaded implementations + extern size_t multithreadThreshold; - // Number of columns required for a matrix to be parallelized in GEMM - extern size_t gemmMultithreadThreshold; + // Number of columns required for a matrix to be parallelized in GEMM + extern size_t gemmMultithreadThreshold; - // Number of columns required for a matrix to be parallelized in GEMV - extern size_t gemvMultithreadThreshold; + // Number of columns required for a matrix to be parallelized in GEMV + extern size_t gemvMultithreadThreshold; - // Number of threads used by LibRapid - extern size_t numThreads; + // Number of threads used by LibRapid + extern size_t numThreads; - // Random seed used by LibRapid (when changed, the random number generator is reseeded) - extern size_t randomSeed; + // Random seed used by LibRapid (when changed, the random number generator is reseeded) + extern size_t randomSeed; - // Should the random number generator be reseeded? - extern bool reseed; + // Should the random number generator be reseeded? 
+ extern bool reseed; - // Size of a cache line in bytes - extern size_t cacheLineSize; + // Size of a cache line in bytes + extern size_t cacheLineSize; - // Memory alignment for LibRapid - extern size_t memoryAlignment; + // Memory alignment for LibRapid + extern size_t memoryAlignment; #if defined(LIBRAPID_HAS_OPENCL) - // OpenCL device list - extern std::vector openclDevices; + // OpenCL device list + extern std::vector openclDevices; - // OpenCL context - extern cl::Context openCLContext; + // OpenCL context + extern cl::Context openCLContext; - // OpenCL device - extern cl::Device openCLDevice; + // OpenCL device + extern cl::Device openCLDevice; - // OpenCL command queue - extern cl::CommandQueue openCLQueue; + // OpenCL command queue + extern cl::CommandQueue openCLQueue; - // OpenCL program sources - extern cl::Program::Sources openCLSources; + // OpenCL program sources + extern cl::Program::Sources openCLSources; - // OpenCL program - extern cl::Program openCLProgram; + // OpenCL program + extern cl::Program openCLProgram; - // True if OpenCL has been configured - extern bool openCLConfigured; + // True if OpenCL has been configured + extern bool openCLConfigured; #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - // LibRapid's CUDA stream -- this removes the need for calling cudaDeviceSynchronize() - extern cudaStream_t cudaStream; + // LibRapid's CUDA stream -- this removes the need for calling cudaDeviceSynchronize() + extern cudaStream_t cudaStream; - // LibRapid's CuBLAS handle - extern cublasHandle_t cublasHandle; + // LibRapid's CuBLAS handle + extern cublasHandle_t cublasHandle; - // LibRapid's CuBLASLt handle - extern cublasLtHandle_t cublasLtHandle; + // LibRapid's CuBLASLt handle + extern cublasLtHandle_t cublasLtHandle; - extern uint64_t cublasLtWorkspaceSize; + extern uint64_t cublasLtWorkspaceSize; - // LibRapid's CuBLASLt workspace - extern void *cublasLtWorkspace; + // LibRapid's CuBLASLt workspace + extern void 
*cublasLtWorkspace; - // Jitify cache for CUDA kernels - extern jitify::JitCache jitCache; + // Jitify cache for CUDA kernels + extern jitify::JitCache jitCache; #endif // LIBRAPID_HAS_CUDA - } // namespace global + } // namespace global - void setNumThreads(size_t numThreads); - size_t getNumThreads(); + void setNumThreads(size_t numThreads); + size_t getNumThreads(); - void setSeed(size_t seed); - size_t getSeed(); + void setSeed(size_t seed); + size_t getSeed(); } // namespace librapid #endif // LIBRAPID_CORE_GLOBAL_HPP \ No newline at end of file diff --git a/librapid/include/librapid/core/gnuConfig.hpp b/librapid/include/librapid/core/gnuConfig.hpp index 5395c90c..be214525 100644 --- a/librapid/include/librapid/core/gnuConfig.hpp +++ b/librapid/include/librapid/core/gnuConfig.hpp @@ -1,168 +1,168 @@ #ifndef LIBRAPID_CORE_GNU_CONFIG_HPP #define LIBRAPID_CORE_GNU_CONFIG_HPP -#define LIBRAPID_INLINE inline +#define LIBRAPID_INLINE inline #define LIBRAPID_ALWAYS_INLINE inline __attribute__((always_inline)) #define LIBRAPID_ASSERT_ALWAYS(cond, msg, ...) 
\ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - if (!(cond)) { \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)), \ - (int)strlen(FILENAME), \ - (int)funcName.length(), \ - (int)strlen(#cond), \ - (int)strlen("ASSERTION FAILED")); \ - std::string formatted = fmt::format( \ - "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ - "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ - "{4:>{10}}]\n{5}\n", \ - "ASSERTION FAILED", \ - FILENAME, \ - funcName, \ - __LINE__, \ - #cond, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 14, \ - maxLen + 9, \ - maxLen + 5, \ - maxLen + 9, \ - maxLen + 4); \ - if (librapid::global::throwOnAssert) { \ - throw std::runtime_error(formatted); \ - } else { \ - fmt::print(fmt::fg(fmt::color::red), formatted); \ - psnip_trap(); \ - } \ - } \ - } while (0) + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + if (!(cond)) { \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)), \ + (int)strlen(FILENAME), \ + (int)funcName.length(), \ + (int)strlen(#cond), \ + (int)strlen("ASSERTION FAILED")); \ + std::string formatted = fmt::format( \ + "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ + "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ + "{4:>{10}}]\n{5}\n", \ + "ASSERTION FAILED", \ + FILENAME, \ + funcName, \ + __LINE__, \ + #cond, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 14, \ + maxLen + 9, \ + maxLen + 5, \ + maxLen + 9, \ + maxLen + 4); \ + if (librapid::global::throwOnAssert) { \ + throw std::runtime_error(formatted); \ + } else { \ + fmt::print(fmt::fg(fmt::color::red), formatted); \ + psnip_trap(); \ + } \ + } \ + } while (0) #if defined(LIBRAPID_ENABLE_ASSERT) -# define LIBRAPID_STATUS(msg, ...) 
\ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::green), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "STATUS", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - } while (0) +# define LIBRAPID_STATUS(msg, ...) \ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::green), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "STATUS", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + } while (0) -# define LIBRAPID_WARN(msg, ...) \ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::yellow), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "WARNING", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - } while (0) +# define LIBRAPID_WARN(msg, ...) 
\ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::yellow), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "WARNING", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + } while (0) -# define LIBRAPID_ERROR(msg, ...) \ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - std::string formatted = fmt::format(fmt::fg(fmt::color::red), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "ERROR", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - if (librapid::global::throwOnAssert) { \ - throw std::runtime_error(formatted); \ - } else { \ - fmt::print(fmt::fg(fmt::color::red), formatted); \ - psnip_trap(); \ - } \ - } while (0) +# define LIBRAPID_ERROR(msg, ...) 
\ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + std::string formatted = fmt::format(fmt::fg(fmt::color::red), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "ERROR", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + if (librapid::global::throwOnAssert) { \ + throw std::runtime_error(formatted); \ + } else { \ + fmt::print(fmt::fg(fmt::color::red), formatted); \ + psnip_trap(); \ + } \ + } while (0) -# define LIBRAPID_WASSERT(cond, msg, ...) \ - do { \ - if (!(cond)) { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen(#cond) + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::yellow), \ - "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ - "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ - "{4:>{10}}]\n{5}\n", \ - "WARN ASSERTION FAILED", \ - FILENAME, \ - funcName, \ - __LINE__, \ - #cond, \ - fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen + 0, \ - maxLen - 5); \ - } \ - } while (0) +# define LIBRAPID_WASSERT(cond, msg, ...) 
\ + do { \ + if (!(cond)) { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen(#cond) + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::yellow), \ + "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ + "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ + "{4:>{10}}]\n{5}\n", \ + "WARN ASSERTION FAILED", \ + FILENAME, \ + funcName, \ + __LINE__, \ + #cond, \ + fmt::format(msg __VA_OPT__(, ) __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen + 0, \ + maxLen - 5); \ + } \ + } while (0) -# define LIBRAPID_ASSERT(cond, msg, ...) \ - LIBRAPID_ASSERT_ALWAYS(cond, msg __VA_OPT__(, ) __VA_ARGS__) +# define LIBRAPID_ASSERT(cond, msg, ...) \ + LIBRAPID_ASSERT_ALWAYS(cond, msg __VA_OPT__(, ) __VA_ARGS__) #else -# define LIBRAPID_WARN_ONCE(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_STATUS(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_WARN(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_ERROR(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_LOG(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_WASSERT(cond, ...) \ - do { \ - } while (0) -# define LIBRAPID_ASSERT(cond, ...) \ - do { \ - } while (0) +# define LIBRAPID_WARN_ONCE(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_STATUS(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_WARN(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_ERROR(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_LOG(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_WASSERT(cond, ...) \ + do { \ + } while (0) +# define LIBRAPID_ASSERT(cond, ...) 
\ + do { \ + } while (0) #endif // LIBRAPID_ENABLE_ASSERT #define PURE_FUNCTION __attribute__((const)) diff --git a/librapid/include/librapid/core/helperMacros.hpp b/librapid/include/librapid/core/helperMacros.hpp index 8065a64c..5515ed39 100644 --- a/librapid/include/librapid/core/helperMacros.hpp +++ b/librapid/include/librapid/core/helperMacros.hpp @@ -8,77 +8,77 @@ #define COMMA , #define LIBRAPID_SIMPLE_IO_IMPL(TEMPLATE_, TYPE_) \ - template \ - struct fmt::formatter { \ - char formatStr[32] = {'{', ':'}; \ + template \ + struct fmt::formatter { \ + char formatStr[32] = {'{', ':'}; \ \ - constexpr auto parse(format_parse_context &ctx) -> format_parse_context::iterator { \ - auto it = ctx.begin(); \ - uint64_t index = 0; \ - for (; it != ctx.end(); ++it) { \ - if (*it == '}') break; \ - formatStr[index++] += *it; \ - } \ - formatStr[index] = '}'; \ - return it; \ - } \ + constexpr auto parse(format_parse_context &ctx) -> format_parse_context::iterator { \ + auto it = ctx.begin(); \ + uint64_t index = 0; \ + for (; it != ctx.end(); ++it) { \ + if (*it == '}') break; \ + formatStr[index++] += *it; \ + } \ + formatStr[index] = '}'; \ + return it; \ + } \ \ - template \ - auto format(const TYPE_ &object, FormatContext &ctx) { \ - try { \ - return fmt::format_to(ctx.out(), object.str(formatStr)); \ - } catch (std::exception & e) { return fmt::format_to(ctx.out(), e.what()); } \ - } \ - }; \ + template \ + auto format(const TYPE_ &object, FormatContext &ctx) { \ + try { \ + return fmt::format_to(ctx.out(), object.str(formatStr)); \ + } catch (std::exception & e) { return fmt::format_to(ctx.out(), e.what()); } \ + } \ + }; \ \ - template \ - std::ostream &operator<<(std::ostream &os, const TYPE_ &object) { \ - os << object.str(); \ - return os; \ - } + template \ + std::ostream &operator<<(std::ostream &os, const TYPE_ &object) { \ + os << object.str(); \ + return os; \ + } #define LIBRAPID_SIMPLE_IO_IMPL_NO_TEMPLATE(TYPE_) \ - template<> \ - struct fmt::formatter { \ 
- std::string formatStr = "{}"; \ + template<> \ + struct fmt::formatter { \ + std::string formatStr = "{}"; \ \ - template \ - constexpr auto parse(ParseContext &ctx) { \ - formatStr = "{:"; \ - auto it = ctx.begin(); \ - for (; it != ctx.end(); ++it) { \ - if (*it == '}') break; \ - formatStr += *it; \ - } \ - formatStr += "}"; \ - return it; \ - } \ + template \ + constexpr auto parse(ParseContext &ctx) { \ + formatStr = "{:"; \ + auto it = ctx.begin(); \ + for (; it != ctx.end(); ++it) { \ + if (*it == '}') break; \ + formatStr += *it; \ + } \ + formatStr += "}"; \ + return it; \ + } \ \ - template \ - auto format(const TYPE_ &object, FormatContext &ctx) { \ - try { \ - return fmt::format_to(ctx.out(), object.str(formatStr)); \ - } catch (std::exception & e) { return fmt::format_to(ctx.out(), e.what()); } \ - } \ - }; \ + template \ + auto format(const TYPE_ &object, FormatContext &ctx) { \ + try { \ + return fmt::format_to(ctx.out(), object.str(formatStr)); \ + } catch (std::exception & e) { return fmt::format_to(ctx.out(), e.what()); } \ + } \ + }; \ \ - LIBRAPID_INLINE std::ostream &operator<<(std::ostream &os, const TYPE_ &object) { \ - os << object.str(); \ - return os; \ - } + LIBRAPID_INLINE std::ostream &operator<<(std::ostream &os, const TYPE_ &object) { \ + os << object.str(); \ + return os; \ + } #define LIBRAPID_SIMPLE_IO_NORANGE(TEMPLATE, TYPE) \ - template \ - struct fmt::is_range : std::false_type {}; + template \ + struct fmt::is_range : std::false_type {}; namespace librapid::typetraits { - template - struct IsLibRapidType : std::false_type {}; + template + struct IsLibRapidType : std::false_type {}; } // namespace librapid::typetraits // Define a type as being part of librapid -- this should be contained in the typetraits namespace #define LIBRAPID_DEFINE_AS_TYPE(TEMPLATE_, TYPE_) \ - template \ - struct IsLibRapidType : std::true_type {} + template \ + struct IsLibRapidType : std::true_type {} #endif // LIBRAPID_CORE_HELPER_MACROS \ No 
newline at end of file diff --git a/librapid/include/librapid/core/librapidPch.hpp b/librapid/include/librapid/core/librapidPch.hpp index 79f76913..3390b57f 100644 --- a/librapid/include/librapid/core/librapidPch.hpp +++ b/librapid/include/librapid/core/librapidPch.hpp @@ -32,12 +32,12 @@ #include #if defined(LIBRAPID_HAS_OMP) -# include +# include #endif // LIBRAPID_HAS_OMP #if (defined(_WIN32) || defined(_WIN64)) && !defined(LIBRAPID_NO_WINDOWS_H) -# define WIN32_LEAN_AND_MEAN -# include +# define WIN32_LEAN_AND_MEAN +# include #endif // Remove a few macros @@ -59,17 +59,17 @@ #if !defined(LIBRAPID_MINGW) // MinGW does not implement std::from_chars which is required by scnlib // scnlib -# include -# include +# include +# include #endif // !LIBRAPID_MINGW // Vc -- SIMD instructions #if defined(_MSC_VER) // For Vc, we need to disable the following warnings -# pragma warning(push) -# pragma warning(disable : 4244) // conversion from 'int' to 'float', possible loss of data -# pragma warning(disable : 4324) // structure was padded due to alignment specifier -# pragma warning(disable : 4127) // conditional expression is constant +# pragma warning(push) +# pragma warning(disable : 4244) // conversion from 'int' to 'float', possible loss of data +# pragma warning(disable : 4324) // structure was padded due to alignment specifier +# pragma warning(disable : 4127) // conditional expression is constant #endif // #include @@ -80,13 +80,13 @@ #include #if defined(_MSC_VER) -# pragma warning(pop) +# pragma warning(pop) #endif // MPFR (modified) -- arbitrary precision floating point numbers #if defined(LIBRAPID_USE_MULTIPREC) -# include -# include +# include +# include #endif // LIBRAPID_USE_MULTIPREC #endif // LIBRAPID_CORE_LIBRAPID_PCH_HPP \ No newline at end of file diff --git a/librapid/include/librapid/core/literals.hpp b/librapid/include/librapid/core/literals.hpp index 6560b63f..f3312686 100644 --- a/librapid/include/librapid/core/literals.hpp +++ 
b/librapid/include/librapid/core/literals.hpp @@ -3,10 +3,10 @@ namespace librapid::literals { #if defined(LIBRAPID_USE_MULTIPREC) - /// \brief Creates a multiprecision floating point number from a string literal - /// \param str The string literal to convert - /// \return The multiprecision floating point number - ::librapid::mpfr operator""_f(const char *str, size_t); + /// \brief Creates a multiprecision floating point number from a string literal + /// \param str The string literal to convert + /// \return The multiprecision floating point number + ::librapid::mpfr operator""_f(const char *str, size_t); #endif // LIBRAPID_USE_MULTIPREC } // namespace librapid::literals diff --git a/librapid/include/librapid/core/msvcConfig.hpp b/librapid/include/librapid/core/msvcConfig.hpp index 2fb1e609..86b05764 100644 --- a/librapid/include/librapid/core/msvcConfig.hpp +++ b/librapid/include/librapid/core/msvcConfig.hpp @@ -1,168 +1,168 @@ #ifndef LIBRAPID_CORE_MSVC_CONFIG_HPP #define LIBRAPID_CORE_MSVC_CONFIG_HPP -#define LIBRAPID_INLINE inline +#define LIBRAPID_INLINE inline #define LIBRAPID_ALWAYS_INLINE inline __forceinline #define LIBRAPID_ASSERT_ALWAYS(cond, msg, ...) 
\ - do { \ - if (!(cond)) { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)), \ - (int)strlen(FILENAME), \ - (int)funcName.length(), \ - (int)strlen(#cond), \ - (int)strlen("ASSERTION FAILED")); \ - std::string formatted = fmt::format( \ - "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ - "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ - "{4:>{10}}]\n{5}\n", \ - "ASSERTION FAILED", \ - FILENAME, \ - funcName, \ - __LINE__, \ - #cond, \ - fmt::format(msg, __VA_ARGS__), \ - maxLen + 14, \ - maxLen + 9, \ - maxLen + 5, \ - maxLen + 9, \ - maxLen + 4); \ - if (librapid::global::throwOnAssert) { \ - throw std::runtime_error(formatted); \ - } else { \ - fmt::print(fmt::fg(fmt::color::red), formatted); \ - psnip_trap(); \ - } \ - } \ - } while (0) + do { \ + if (!(cond)) { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)), \ + (int)strlen(FILENAME), \ + (int)funcName.length(), \ + (int)strlen(#cond), \ + (int)strlen("ASSERTION FAILED")); \ + std::string formatted = fmt::format( \ + "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ + "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ + "{4:>{10}}]\n{5}\n", \ + "ASSERTION FAILED", \ + FILENAME, \ + funcName, \ + __LINE__, \ + #cond, \ + fmt::format(msg, __VA_ARGS__), \ + maxLen + 14, \ + maxLen + 9, \ + maxLen + 5, \ + maxLen + 9, \ + maxLen + 4); \ + if (librapid::global::throwOnAssert) { \ + throw std::runtime_error(formatted); \ + } else { \ + fmt::print(fmt::fg(fmt::color::red), formatted); \ + psnip_trap(); \ + } \ + } \ + } while (0) #if defined(LIBRAPID_ENABLE_ASSERT) -# define LIBRAPID_STATUS(msg, ...) 
\ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::green), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "STATUS", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg, __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - } while (0) +# define LIBRAPID_STATUS(msg, ...) \ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::green), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "STATUS", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg, __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + } while (0) -# define LIBRAPID_WARN(msg, ...) \ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::yellow), \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "WARNING", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg, __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - } while (0) +# define LIBRAPID_WARN(msg, ...) 
\ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::yellow), \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "WARNING", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg, __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + } while (0) -# define LIBRAPID_ERROR(msg, ...) \ - do { \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - std::string formatted = fmt::format( \ - "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ - "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ - "ERROR", \ - FILENAME, \ - funcName, \ - __LINE__, \ - fmt::format(msg, __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen); \ - if (librapid::global::throwOnAssert) { \ - throw std::runtime_error(formatted); \ - } else { \ - fmt::print(fmt::fg(fmt::color::red), formatted); \ - psnip_trap(); \ - } \ - } while (0) +# define LIBRAPID_ERROR(msg, ...) 
\ + do { \ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + std::string formatted = fmt::format( \ + "[{0:-^{5}}]\n[File {1:>{6}}]\n[Function " \ + "{2:>{7}}]\n[Line {3:>{8}}]\n{4}\n", \ + "ERROR", \ + FILENAME, \ + funcName, \ + __LINE__, \ + fmt::format(msg, __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen); \ + if (librapid::global::throwOnAssert) { \ + throw std::runtime_error(formatted); \ + } else { \ + fmt::print(fmt::fg(fmt::color::red), formatted); \ + psnip_trap(); \ + } \ + } while (0) -# define LIBRAPID_WASSERT(cond, msg, ...) \ - std::string funcName = FUNCTION; \ - if (funcName.length() > 75) funcName = ""; \ - do { \ - if (!(cond)) { \ - int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ - (int)strlen(FILENAME) + 6, \ - (int)funcName.length() + 6, \ - (int)strlen(#cond) + 6, \ - (int)strlen("WARN ASSERTION FAILED")); \ - fmt::print(fmt::fg(fmt::color::yellow), \ - "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ - "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ - "{4:>{10}}]\n{5}\n", \ - "WARN ASSERTION FAILED", \ - FILENAME, \ - funcName, \ - __LINE__, \ - #cond, \ - fmt::format(msg, __VA_ARGS__), \ - maxLen + 5, \ - maxLen + 0, \ - maxLen - 4, \ - maxLen + 0, \ - maxLen - 5); \ - } \ - } while (0) +# define LIBRAPID_WASSERT(cond, msg, ...) 
\ + std::string funcName = FUNCTION; \ + if (funcName.length() > 75) funcName = ""; \ + do { \ + if (!(cond)) { \ + int maxLen = librapid::detail::internalMax((int)std::ceil(std::log(__LINE__)) + 6, \ + (int)strlen(FILENAME) + 6, \ + (int)funcName.length() + 6, \ + (int)strlen(#cond) + 6, \ + (int)strlen("WARN ASSERTION FAILED")); \ + fmt::print(fmt::fg(fmt::color::yellow), \ + "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function " \ + "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition " \ + "{4:>{10}}]\n{5}\n", \ + "WARN ASSERTION FAILED", \ + FILENAME, \ + funcName, \ + __LINE__, \ + #cond, \ + fmt::format(msg, __VA_ARGS__), \ + maxLen + 5, \ + maxLen + 0, \ + maxLen - 4, \ + maxLen + 0, \ + maxLen - 5); \ + } \ + } while (0) -# define LIBRAPID_ASSERT(cond, msg, ...) LIBRAPID_ASSERT_ALWAYS(cond, msg, __VA_ARGS__) +# define LIBRAPID_ASSERT(cond, msg, ...) LIBRAPID_ASSERT_ALWAYS(cond, msg, __VA_ARGS__) #else -# define LIBRAPID_WARN_ONCE(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_STATUS(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_WARN(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_ERROR(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_LOG(msg, ...) \ - do { \ - } while (0) -# define LIBRAPID_WASSERT(cond, ...) \ - do { \ - } while (0) -# define LIBRAPID_ASSERT(cond, ...) \ - do { \ - } while (0) +# define LIBRAPID_WARN_ONCE(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_STATUS(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_WARN(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_ERROR(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_LOG(msg, ...) \ + do { \ + } while (0) +# define LIBRAPID_WASSERT(cond, ...) \ + do { \ + } while (0) +# define LIBRAPID_ASSERT(cond, ...) 
\ + do { \ + } while (0) #endif // LIBRAPID_ENABLE_ASSERT #define PURE_FUNCTION [[nodiscard]] constexpr diff --git a/librapid/include/librapid/core/openclConfig.hpp b/librapid/include/librapid/core/openclConfig.hpp index e1ddf615..d45a7139 100644 --- a/librapid/include/librapid/core/openclConfig.hpp +++ b/librapid/include/librapid/core/openclConfig.hpp @@ -3,20 +3,20 @@ #if defined(LIBRAPID_HAS_OPENCL) -# if defined(LIBRAPID_APPLE) -# include -# else -# include -# endif // LIBRAPID_APPLE +# if defined(LIBRAPID_APPLE) +# include +# else +# include +# endif // LIBRAPID_APPLE -#include +# include #else // LIBRAPID_HAS_OPENCL namespace librapid::typetraits { - template - struct IsOpenCLStorage : std::false_type {}; -} + template + struct IsOpenCLStorage : std::false_type {}; +} // namespace librapid::typetraits -#endif // LIBRAPID_HAS_OPENCL -#endif // LIBRAPID_CORE_OPENCL_CONFIG_HPP \ No newline at end of file +#endif // LIBRAPID_HAS_OPENCL +#endif // LIBRAPID_CORE_OPENCL_CONFIG_HPP \ No newline at end of file diff --git a/librapid/include/librapid/core/preMain.hpp b/librapid/include/librapid/core/preMain.hpp index 3857f923..9233edc9 100644 --- a/librapid/include/librapid/core/preMain.hpp +++ b/librapid/include/librapid/core/preMain.hpp @@ -6,26 +6,27 @@ */ namespace librapid::detail { - class PreMain { - public: - PreMain(); - private: - }; + class PreMain { + public: + PreMain(); - // These must be declared here for use in ASSERT functions - template - T internalMax(T val) { - return val; - } + private: + }; - template - T internalMax(T val, Tn... vals) { - auto maxOther = internalMax(vals...); - return val < maxOther ? maxOther : val; - } + // These must be declared here for use in ASSERT functions + template + T internalMax(T val) { + return val; + } - extern bool preMainRun; - static inline PreMain preMain = PreMain(); + template + T internalMax(T val, Tn... vals) { + auto maxOther = internalMax(vals...); + return val < maxOther ? 
maxOther : val; + } + + extern bool preMainRun; + static inline PreMain preMain = PreMain(); } // namespace librapid::detail #endif // LIBRAPID_CORE_PREMAIN \ No newline at end of file diff --git a/librapid/include/librapid/core/traits.hpp b/librapid/include/librapid/core/traits.hpp index a0d0b674..45aff56e 100644 --- a/librapid/include/librapid/core/traits.hpp +++ b/librapid/include/librapid/core/traits.hpp @@ -16,777 +16,777 @@ */ #define LIMIT_IMPL_CONSTEXPR(NAME_) LIBRAPID_ALWAYS_INLINE static constexpr auto NAME_() noexcept -#define LIMIT_IMPL(NAME_) LIBRAPID_ALWAYS_INLINE static auto NAME_() noexcept -#define NUM_LIM(NAME_) std::numeric_limits::NAME_() +#define LIMIT_IMPL(NAME_) LIBRAPID_ALWAYS_INLINE static auto NAME_() noexcept +#define NUM_LIM(NAME_) std::numeric_limits::NAME_() namespace librapid { - namespace detail { - /// An enum class representing different types within LibRapid. Intended mainly for - /// internal use - enum class LibRapidType { - Scalar, - Dual, - Vector, - ArrayContainer, - ArrayFunction, - ArrayView, - }; - - constexpr bool sameType(LibRapidType type1, LibRapidType type2) { return type1 == type2; } - - /* - * Pretty string representations of data types at compile time. This is adapted from - * https://bitwizeshift.github.io/posts/2021/03/09/getting-an-unmangled-type-name-at-compile-time/ - * and I have simply adapted it to work with LibRapid. - */ - - template - constexpr auto substringAsArray(std::string_view str, std::index_sequence) { - return std::array {str[Idxs]...}; - } - - template - constexpr auto typeNameArray() { + namespace detail { + /// An enum class representing different types within LibRapid. Intended mainly for + /// internal use + enum class LibRapidType { + Scalar, + Dual, + Vector, + ArrayContainer, + ArrayFunction, + ArrayView, + }; + + constexpr bool sameType(LibRapidType type1, LibRapidType type2) { return type1 == type2; } + + /* + * Pretty string representations of data types at compile time. 
This is adapted from + * https://bitwizeshift.github.io/posts/2021/03/09/getting-an-unmangled-type-name-at-compile-time/ + * and I have simply adapted it to work with LibRapid. + */ + + template + constexpr auto substringAsArray(std::string_view str, std::index_sequence) { + return std::array {str[Idxs]...}; + } + + template + constexpr auto typeNameArray() { #if defined(__clang__) - constexpr auto prefix = std::string_view {"[T = "}; - constexpr auto suffix = std::string_view {"]"}; - constexpr auto function = std::string_view {__PRETTY_FUNCTION__}; + constexpr auto prefix = std::string_view {"[T = "}; + constexpr auto suffix = std::string_view {"]"}; + constexpr auto function = std::string_view {__PRETTY_FUNCTION__}; #elif defined(__GNUC__) - constexpr auto prefix = std::string_view {"with T = "}; - constexpr auto suffix = std::string_view {"]"}; - constexpr auto function = std::string_view {__PRETTY_FUNCTION__}; + constexpr auto prefix = std::string_view {"with T = "}; + constexpr auto suffix = std::string_view {"]"}; + constexpr auto function = std::string_view {__PRETTY_FUNCTION__}; #elif defined(_MSC_VER) - constexpr auto prefix = std::string_view {"type_name_array<"}; - constexpr auto suffix = std::string_view {">(void)"}; - constexpr auto function = std::string_view {__FUNCSIG__}; + constexpr auto prefix = std::string_view {"type_name_array<"}; + constexpr auto suffix = std::string_view {">(void)"}; + constexpr auto function = std::string_view {__FUNCSIG__}; #else -# define LIBRAPID_NO_TYPE_TO_STRING +# define LIBRAPID_NO_TYPE_TO_STRING #endif #if !defined(LIBRAPID_NO_TYPE_TO_STRING) - constexpr auto start = function.find(prefix) + prefix.size(); - constexpr auto end = function.rfind(suffix); + constexpr auto start = function.find(prefix) + prefix.size(); + constexpr auto end = function.rfind(suffix); - static_assert(start < end); + static_assert(start < end); - constexpr auto name = function.substr(start, (end - start)); - return substringAsArray(name, 
std::make_index_sequence {}); + constexpr auto name = function.substr(start, (end - start)); + return substringAsArray(name, std::make_index_sequence {}); #else - return std::array {}; + return std::array {}; #endif - } - - template - struct TypeNameHolder { - static inline constexpr auto value = typeNameArray(); - }; - } // namespace detail - - namespace typetraits { - template - constexpr auto typeName() -> std::string_view { - constexpr auto &value = detail::TypeNameHolder::value; - return std::string_view {value.data(), value.size()}; - } - - template - struct HasCustomEval : std::false_type {}; - - /// Provides compile-time information about a data type, allowing for easier function - /// switching and compile-time evaluation - /// \tparam T The type to get information about - template - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = T; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "[ NO DEFINED TYPE ]"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = false; + } + + template + struct TypeNameHolder { + static inline constexpr auto value = typeNameArray(); + }; + } // namespace detail + + namespace typetraits { + template + constexpr auto typeName() -> std::string_view { + constexpr auto &value = detail::TypeNameHolder::value; + return std::string_view {value.data(), value.size()}; + } + + template + struct HasCustomEval : std::false_type {}; + + /// Provides compile-time information about a data type, allowing for easier function + /// switching and compile-time evaluation + /// \tparam T The type to get information about + template + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = T; + using 
Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "[ NO DEFINED TYPE ]"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = false; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template - struct TypeInfo : TypeInfo {}; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = bool; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "char"; - static constexpr bool 
supportsArithmetic = false; - static constexpr bool supportsLogical = false; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = false; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template + struct TypeInfo : TypeInfo {}; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = bool; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "char"; + static constexpr bool supportsArithmetic = false; + static constexpr bool supportsLogical = false; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = false; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8I; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8I; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - }; - - template<> - 
struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = char; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "bool"; - static constexpr bool supportsArithmetic = false; - static constexpr bool supportsLogical = false; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = char; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "bool"; + static constexpr bool supportsArithmetic = false; + static constexpr bool supportsLogical = false; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8I; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8I; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = 
detail::LibRapidType::Scalar; - using Scalar = int8_t; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size; - static constexpr char name[] = "int8_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = int8_t; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "int8_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8I; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8I; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) 
{ return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = uint8_t; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size; - static constexpr char name[] = "uint8_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = uint8_t; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "uint8_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static 
constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8U; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_8U; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = int16_t; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size; - static constexpr char name[] = "int16_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = 
detail::LibRapidType::Scalar; + using Scalar = int16_t; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "int16_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_16I; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_16I; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = uint16_t; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size; - static constexpr char name[] = "uint16_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + 
LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = uint16_t; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "uint16_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_16U; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_16U; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = int32_t; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t 
packetWidth = Packet::size; - static constexpr char name[] = "int32_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = int32_t; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "int32_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32I; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32I; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return 
NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = uint32_t; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size; - static constexpr char name[] = "uint32_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = uint32_t; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "uint32_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32U; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr 
cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32U; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = int64_t; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "int64_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = false; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = int64_t; + using Packet = std::false_type; + using Backend = backend::CPU; + static 
constexpr int64_t packetWidth = 1; + static constexpr char name[] = "int64_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = false; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64I; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64I; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = uint64_t; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "uint64_t"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = false; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { 
return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = uint64_t; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "uint64_t"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = false; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64U; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64U; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = float; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size; - static constexpr char name[] = "float"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static 
constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = float; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "float"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - 
LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = double; - using Packet = xsimd::batch; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = Packet::size; - static constexpr char name[] = "double"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = double; + using Packet = xsimd::batch; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = Packet::size; + static constexpr char name[] = "double"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = true; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr 
bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; - using Scalar = T; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "Vector"; - static constexpr bool supportsArithmetic = TypeInfo::supportsArithmetic; - static constexpr bool supportsLogical = TypeInfo::supportsLogical; - static constexpr bool supportsBinary = TypeInfo::supportsBinary; - static constexpr bool allowVectorisation = TypeInfo::allowVectorisation; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; + using Scalar = T; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "Vector"; + static constexpr bool 
supportsArithmetic = TypeInfo::supportsArithmetic; + static constexpr bool supportsLogical = TypeInfo::supportsLogical; + static constexpr bool supportsBinary = TypeInfo::supportsBinary; + static constexpr bool allowVectorisation = TypeInfo::allowVectorisation; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; - static constexpr int64_t cudaPacketWidth = TypeInfo::cudaPacketWidth; + static constexpr cudaDataType_t CudaType = TypeInfo::CudaType; + static constexpr int64_t cudaPacketWidth = TypeInfo::cudaPacketWidth; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; #if defined(LIBRAPID_HAS_CUDA) - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = float; - using Packet = std::false_type; - using Backend = backend::CUDA; - static 
constexpr int64_t packetWidth = 4; - static constexpr char name[] = "float2"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = true; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; - static constexpr int64_t cudaPacketWidth = 1; -# endif - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = float; - using Packet = std::false_type; - using Backend = backend::CUDA; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "float3"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = true; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; - static constexpr int64_t cudaPacketWidth = 3; -# endif - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return 
NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = float; - using Packet = std::false_type; - using Backend = backend::CUDA; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "float4"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = true; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; - static constexpr int64_t cudaPacketWidth = 4; -# endif - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = double; - using Packet = std::false_type; - using Backend = backend::CUDA; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "double2"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr 
bool allowVectorisation = true; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; - static constexpr int64_t cudaPacketWidth = 2; -# endif - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = double; - using Packet = std::false_type; - using Backend = backend::CUDA; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "double3"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = true; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; - static constexpr int64_t cudaPacketWidth = 3; -# endif - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { 
return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = double; - using Packet = std::false_type; - using Backend = backend::CUDA; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "double4"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = true; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; - static constexpr int64_t cudaPacketWidth = 4; -# endif - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = float; + using Packet = std::false_type; + using Backend = backend::CUDA; + static constexpr int64_t packetWidth = 4; + static constexpr char name[] = "float2"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = true; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; + static constexpr int64_t cudaPacketWidth = 1; +# endif + + static constexpr bool canAlign = true; + static constexpr 
bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = float; + using Packet = std::false_type; + using Backend = backend::CUDA; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "float3"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = true; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; + static constexpr int64_t cudaPacketWidth = 3; +# endif + + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = float; + using Packet = std::false_type; + using Backend = backend::CUDA; + static constexpr 
int64_t packetWidth = 1; + static constexpr char name[] = "float4"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = true; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_32F; + static constexpr int64_t cudaPacketWidth = 4; +# endif + + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = double; + using Packet = std::false_type; + using Backend = backend::CUDA; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "double2"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = true; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; + static constexpr int64_t cudaPacketWidth = 2; +# endif + + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } 
+ LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = double; + using Packet = std::false_type; + using Backend = backend::CUDA; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "double3"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = true; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; + static constexpr int64_t cudaPacketWidth = 3; +# endif + + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = double; + using Packet = std::false_type; + using Backend = backend::CUDA; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "double4"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation 
= true; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; + static constexpr int64_t cudaPacketWidth = 4; +# endif + + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; #endif - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = typename xsimd::batch_element_reference::Scalar; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "xsimd::batch_element_reference"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = false; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = false; - - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = false; - - LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } - LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } - LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr char name[] = "CPU"; - using 
Backend = backend::CPU; - }; + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = typename xsimd::batch_element_reference::Scalar; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "xsimd::batch_element_reference"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = false; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = false; + + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = false; + + LIMIT_IMPL_CONSTEXPR(min) { return NUM_LIM(min); } + LIMIT_IMPL_CONSTEXPR(max) { return NUM_LIM(max); } + LIMIT_IMPL_CONSTEXPR(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL_CONSTEXPR(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL_CONSTEXPR(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL_CONSTEXPR(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL_CONSTEXPR(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL_CONSTEXPR(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr char name[] = "CPU"; + using Backend = backend::CPU; + }; #if defined(LIBRAPID_HAS_OPENCL) - template<> - struct TypeInfo { - static constexpr char name[] = "OpenCL"; - using Backend = backend::OpenCL; - }; + template<> + struct TypeInfo { + static constexpr char name[] = "OpenCL"; + using Backend = backend::OpenCL; + }; #endif #if defined(LIBRAPID_HAS_CUDA) - template<> - struct TypeInfo { - static constexpr char name[] = "CUDA"; - using Backend = backend::CUDA; - }; + template<> + struct TypeInfo { + static constexpr char name[] = "CUDA"; + using Backend = backend::CUDA; + }; #endif - template - using ScalarReturnType = typename TypeInfo::Scalar; - }; // namespace typetraits + template + using ScalarReturnType = typename TypeInfo::Scalar; + }; // namespace 
typetraits } // namespace librapid #endif // LIBRAPID_CORE_TRAITS_HPP \ No newline at end of file diff --git a/librapid/include/librapid/core/typetraits.hpp b/librapid/include/librapid/core/typetraits.hpp index 2ca61fbb..c4cd8962 100644 --- a/librapid/include/librapid/core/typetraits.hpp +++ b/librapid/include/librapid/core/typetraits.hpp @@ -7,62 +7,62 @@ */ namespace librapid::typetraits { - template - using EnableIf = std::enable_if_t; + template + using EnableIf = std::enable_if_t; - template - constexpr bool IsSame = std::is_same::value; + template + constexpr bool IsSame = std::is_same::value; - namespace impl { - /* - * These functions test for the presence of certain features of a type - * by providing two valid function overloads, but the preferred one - * (the one taking an integer) is only valid if the requested feature - * exists. The return type of both functions differ, and can be evaluated - * as "true" and "false" depending on the presence of the feature. - * - * This is really cool :) - */ + namespace impl { + /* + * These functions test for the presence of certain features of a type + * by providing two valid function overloads, but the preferred one + * (the one taking an integer) is only valid if the requested feature + * exists. The return type of both functions differ, and can be evaluated + * as "true" and "false" depending on the presence of the feature. 
+ * + * This is really cool :) + */ - template()[std::declval()])> - std::true_type testSubscript(int); - template - std::false_type testSubscript(float); + template()[std::declval()])> + std::true_type testSubscript(int); + template + std::false_type testSubscript(float); - template() + std::declval())> - std::true_type testAddition(int); - template - std::false_type testAddition(float); + template() + std::declval())> + std::true_type testAddition(int); + template + std::false_type testAddition(float); - template() * std::declval())> - std::true_type testMultiplication(int); - template - std::false_type testMultiplication(float); + template() * std::declval())> + std::true_type testMultiplication(int); + template + std::false_type testMultiplication(float); - template())> - std::true_type testCast(int); - template - std::false_type testCast(float); - } // namespace impl + template())> + std::true_type testCast(int); + template + std::false_type testCast(float); + } // namespace impl - template - struct HasSubscript : public decltype(impl::testSubscript(1)) {}; + template + struct HasSubscript : public decltype(impl::testSubscript(1)) {}; - template - struct HasAddition : public decltype(impl::testAddition(1)) {}; + template + struct HasAddition : public decltype(impl::testAddition(1)) {}; - template - struct HasMultiplication : public decltype(impl::testMultiplication(1)) {}; + template + struct HasMultiplication : public decltype(impl::testMultiplication(1)) {}; - template - struct CanCast : public decltype(impl::testCast(1)) {}; + template + struct CanCast : public decltype(impl::testCast(1)) {}; - // Detect whether a class can be default constructed - template - using TriviallyDefaultConstructible = std::is_trivially_default_constructible; + // Detect whether a class can be default constructed + template + using TriviallyDefaultConstructible = std::is_trivially_default_constructible; } // namespace librapid::typetraits #endif // LIBRAPID_CORE_TYPETRAITS_HPP \ 
No newline at end of file diff --git a/librapid/include/librapid/core/warningSuppress.hpp b/librapid/include/librapid/core/warningSuppress.hpp index fed3c1dc..be77c22e 100644 --- a/librapid/include/librapid/core/warningSuppress.hpp +++ b/librapid/include/librapid/core/warningSuppress.hpp @@ -2,17 +2,17 @@ #define LIBRAPID_WARNING_SUPPRESS #ifdef _MSC_VER -# define LIBRAPID_MSVC_SUPPRESS(WARNING_) __pragma(warning(suppress : WARNING_)) +# define LIBRAPID_MSVC_SUPPRESS(WARNING_) __pragma(warning(suppress : WARNING_)) #else -# define LIBRAPID_MSVC_SUPPRESS(WARNING_) +# define LIBRAPID_MSVC_SUPPRESS(WARNING_) #endif // Disable warnings for GCC/Clang #ifdef __GNUC__ -# define LIBRAPID_GCC_SUPPRESS(WARNING_) \ - _Pragma("GCC diagnostic push") _Pragma("GCC diagnostic ignored \"-W" #WARNING_ "\"") +# define LIBRAPID_GCC_SUPPRESS(WARNING_) \ + _Pragma("GCC diagnostic push") _Pragma("GCC diagnostic ignored \"-W" #WARNING_ "\"") #else -# define LIBRAPID_GCC_SUPPRESS(WARNING_) +# define LIBRAPID_GCC_SUPPRESS(WARNING_) #endif LIBRAPID_MSVC_SUPPRESS(4996) // Disable warnings about unsafe classes diff --git a/librapid/include/librapid/cuda/cudaKernelProcesor.hpp b/librapid/include/librapid/cuda/cudaKernelProcesor.hpp index 9c41f64c..ff8cf468 100644 --- a/librapid/include/librapid/cuda/cudaKernelProcesor.hpp +++ b/librapid/include/librapid/cuda/cudaKernelProcesor.hpp @@ -4,94 +4,94 @@ #if defined(LIBRAPID_HAS_CUDA) namespace librapid::cuda { - /// Load a CUDA kernel from a file and return the string representation of it. 
- /// \param relPath File path relative to LibRapid's "cuda/kernels" directory - /// \return String representation of the kernel - const std::string &loadKernel(const std::string &path, bool relative = true); - - jitify::Program generateCudaProgram(const std::string &kernel); - - /// Run a kernel string on the GPU with the specified arguments - /// \tparam Templates Instantiation types passed to Jitify - /// \tparam Args Argument types passed to Jitify - /// \param kernel Kernel string to run - /// \param kernelName Name of the kernel - /// \param elements Number of elements to process - /// \param arguments Arguments to pass to the kernel - template - void runKernelString(const std::string &kernel, const std::string &kernelName, size_t elements, - Args... arguments) { - jitify::Program program = generateCudaProgram(kernel); - unsigned int threadsPerBlock, blocksPerGrid; - - // Use 1 to 512 threads per block - if (elements < 512) { - threadsPerBlock = static_cast(elements); - blocksPerGrid = 1; - } else { - threadsPerBlock = 512; - blocksPerGrid = static_cast( - ceil(static_cast(elements) / static_cast(threadsPerBlock))); - } - - dim3 grid(blocksPerGrid); - dim3 block(threadsPerBlock); - -# if defined(LIBRAPID_DEBUG) - try { -# endif // LIBRAPID_DEBUG - - jitifyCall(program.kernel(kernelName) - .instantiate(jitify::reflection::Type()...) 
- .configure(grid, block, 0, global::cudaStream) - .launch(arguments...)); - -# if defined(LIBRAPID_DEBUG) - } catch (const std::exception &e) { - auto format = fmt::emphasis::bold | fmt::fg(fmt::color::red); - fmt::print(format, "Error : {}\n", e.what()); - fmt::print(format, "Kernel name : {}\n", kernelName); - fmt::print(format, "Elements : {}\n", elements); - fmt::print(format, "Threads per block: {}\n", threadsPerBlock); - fmt::print(format, "Blocks per grid : {}\n", blocksPerGrid); - fmt::print(format, "Arguments : {}\n", sizeof...(Args)); - - // Print all arguments - auto printer = [](auto x, auto format) { - fmt::print(fmt::fg(fmt::color::purple), "\nArgument:\n"); - - // True if x can be printed with fmt - constexpr bool isPrintable = fmt::is_formattable::value; - - if constexpr (isPrintable) { - fmt::print(format, "\tValue: {}\n", x); - } else { - fmt::print(format, "\tValue: [ CANNOT PRINT ]\n"); - } - fmt::print(format, "\tType : {}\n", typeid(x).name()); - }; - - (printer(arguments, fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange)), ...); - (printer(typeid(Templates).name(), fmt::emphasis::bold | fmt::fg(fmt::color::plum)), - ...); - - throw; - } -# endif // LIBRAPID_DEBUG - } - - /// Run a kernel from a filename and kernel name with the specified arguments - /// \tparam Templates Instantiation types passed to Jitify - /// \tparam Args Argument types passed to Jitify - /// \param name Filename of the kernel - /// \param kernelName Name of the kernel - /// \param elements Number of elements to process - /// \param arguments Arguments to pass to the kernel - template - void runKernel(const std::string &name, const std::string &kernelName, size_t elements, - Args... arguments) { - runKernelString(loadKernel(name), kernelName, elements, arguments...); - } + /// Load a CUDA kernel from a file and return the string representation of it. 
+ /// \param relPath File path relative to LibRapid's "cuda/kernels" directory + /// \return String representation of the kernel + const std::string &loadKernel(const std::string &path, bool relative = true); + + jitify::Program generateCudaProgram(const std::string &kernel); + + /// Run a kernel string on the GPU with the specified arguments + /// \tparam Templates Instantiation types passed to Jitify + /// \tparam Args Argument types passed to Jitify + /// \param kernel Kernel string to run + /// \param kernelName Name of the kernel + /// \param elements Number of elements to process + /// \param arguments Arguments to pass to the kernel + template + void runKernelString(const std::string &kernel, const std::string &kernelName, size_t elements, + Args... arguments) { + jitify::Program program = generateCudaProgram(kernel); + unsigned int threadsPerBlock, blocksPerGrid; + + // Use 1 to 512 threads per block + if (elements < 512) { + threadsPerBlock = static_cast(elements); + blocksPerGrid = 1; + } else { + threadsPerBlock = 512; + blocksPerGrid = static_cast( + ceil(static_cast(elements) / static_cast(threadsPerBlock))); + } + + dim3 grid(blocksPerGrid); + dim3 block(threadsPerBlock); + +# if defined(LIBRAPID_DEBUG) + try { +# endif // LIBRAPID_DEBUG + + jitifyCall(program.kernel(kernelName) + .instantiate(jitify::reflection::Type()...) 
+ .configure(grid, block, 0, global::cudaStream) + .launch(arguments...)); + +# if defined(LIBRAPID_DEBUG) + } catch (const std::exception &e) { + auto format = fmt::emphasis::bold | fmt::fg(fmt::color::red); + fmt::print(format, "Error : {}\n", e.what()); + fmt::print(format, "Kernel name : {}\n", kernelName); + fmt::print(format, "Elements : {}\n", elements); + fmt::print(format, "Threads per block: {}\n", threadsPerBlock); + fmt::print(format, "Blocks per grid : {}\n", blocksPerGrid); + fmt::print(format, "Arguments : {}\n", sizeof...(Args)); + + // Print all arguments + auto printer = [](auto x, auto format) { + fmt::print(fmt::fg(fmt::color::purple), "\nArgument:\n"); + + // True if x can be printed with fmt + constexpr bool isPrintable = fmt::is_formattable::value; + + if constexpr (isPrintable) { + fmt::print(format, "\tValue: {}\n", x); + } else { + fmt::print(format, "\tValue: [ CANNOT PRINT ]\n"); + } + fmt::print(format, "\tType : {}\n", typeid(x).name()); + }; + + (printer(arguments, fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange)), ...); + (printer(typeid(Templates).name(), fmt::emphasis::bold | fmt::fg(fmt::color::plum)), + ...); + + throw; + } +# endif // LIBRAPID_DEBUG + } + + /// Run a kernel from a filename and kernel name with the specified arguments + /// \tparam Templates Instantiation types passed to Jitify + /// \tparam Args Argument types passed to Jitify + /// \param name Filename of the kernel + /// \param kernelName Name of the kernel + /// \param elements Number of elements to process + /// \param arguments Arguments to pass to the kernel + template + void runKernel(const std::string &name, const std::string &kernelName, size_t elements, + Args... 
arguments) { + runKernelString(loadKernel(name), kernelName, elements, arguments...); + } } // namespace librapid::cuda #endif // LIBRAPID_HAS_CUDA diff --git a/librapid/include/librapid/cuda/cudaStorage.hpp b/librapid/include/librapid/cuda/cudaStorage.hpp index 67fbfead..70146bb5 100644 --- a/librapid/include/librapid/cuda/cudaStorage.hpp +++ b/librapid/include/librapid/cuda/cudaStorage.hpp @@ -10,568 +10,568 @@ */ namespace librapid { - namespace typetraits { - template - struct TypeInfo> { - static constexpr bool isLibRapidType = true; - using Scalar = Scalar_; - using Backend = backend::CUDA; - }; - - template - struct IsCudaStorage : std::false_type {}; - - template - struct IsCudaStorage> : std::true_type {}; - - LIBRAPID_DEFINE_AS_TYPE(typename Scalar_, CudaStorage); - } // namespace typetraits - - namespace detail { - /// Safely allocate memory for \p size elements of type on the GPU using CUDA. - /// \tparam T Scalar type - /// \param size Number of elements to allocate - /// \return GPU pointer - /// \see safeAllocate - template - T *__restrict cudaSafeAllocate(size_t size); - - /// Safely free memory for \p size elements of type on the GPU using CUDA. 
- /// \tparam T Scalar type - /// \param data The data to deallocate - /// \return GPU pointer - /// \see safeAllocate - template - void cudaSafeDeallocate(T *__restrict data); - - template - std::shared_ptr cudaSharedPtrAllocate(size_t size); - -# define CUDA_REF_OPERATOR(OP) \ - template \ - auto operator OP(const CudaRef &lhs, const RHS &rhs) { \ - return lhs.get() OP rhs; \ - } \ + namespace typetraits { + template + struct TypeInfo> { + static constexpr bool isLibRapidType = true; + using Scalar = Scalar_; + using Backend = backend::CUDA; + }; + + template + struct IsCudaStorage : std::false_type {}; + + template + struct IsCudaStorage> : std::true_type {}; + + LIBRAPID_DEFINE_AS_TYPE(typename Scalar_, CudaStorage); + } // namespace typetraits + + namespace detail { + /// Safely allocate memory for \p size elements of type on the GPU using CUDA. + /// \tparam T Scalar type + /// \param size Number of elements to allocate + /// \return GPU pointer + /// \see safeAllocate + template + T *__restrict cudaSafeAllocate(size_t size); + + /// Safely free memory for \p size elements of type on the GPU using CUDA. 
+ /// \tparam T Scalar type + /// \param data The data to deallocate + /// \return GPU pointer + /// \see safeAllocate + template + void cudaSafeDeallocate(T *__restrict data); + + template + std::shared_ptr cudaSharedPtrAllocate(size_t size); + +# define CUDA_REF_OPERATOR(OP) \ + template \ + auto operator OP(const CudaRef &lhs, const RHS &rhs) { \ + return lhs.get() OP rhs; \ + } \ \ - template \ - auto operator OP(const LHS &lhs, const CudaRef &rhs) { \ - return lhs OP rhs.get(); \ - } \ + template \ + auto operator OP(const LHS &lhs, const CudaRef &rhs) { \ + return lhs OP rhs.get(); \ + } \ \ - template \ - auto operator OP(const CudaRef &lhs, const CudaRef &rhs) { \ - return lhs.get() OP rhs.get(); \ - } \ + template \ + auto operator OP(const CudaRef &lhs, const CudaRef &rhs) { \ + return lhs.get() OP rhs.get(); \ + } \ \ - template \ - auto operator OP##=(CudaRef &lhs, const RHS &rhs) { \ - lhs = lhs.get() OP rhs; \ - } \ + template \ + auto operator OP##=(CudaRef &lhs, const RHS &rhs) { \ + lhs = lhs.get() OP rhs; \ + } \ \ - template \ - auto operator OP##=(CudaRef &lhs, const CudaRef &rhs) { \ - lhs = lhs.get() OP rhs.get(); \ - } - -# define CUDA_REF_OPERATOR_NO_ASSIGN(OP) \ - template \ - auto operator OP(const CudaRef &lhs, const RHS &rhs) { \ - return lhs.get() OP rhs; \ - } \ + template \ + auto operator OP##=(CudaRef &lhs, const CudaRef &rhs) { \ + lhs = lhs.get() OP rhs.get(); \ + } + +# define CUDA_REF_OPERATOR_NO_ASSIGN(OP) \ + template \ + auto operator OP(const CudaRef &lhs, const RHS &rhs) { \ + return lhs.get() OP rhs; \ + } \ \ - template \ - auto operator OP(const LHS &lhs, const CudaRef &rhs) { \ - return lhs OP rhs.get(); \ - } \ + template \ + auto operator OP(const LHS &lhs, const CudaRef &rhs) { \ + return lhs OP rhs.get(); \ + } \ \ - template \ - auto operator OP(const CudaRef &lhs, const CudaRef &rhs) { \ - return lhs.get() OP rhs.get(); \ - } - - template - class CudaRef { - public: - using PtrType = std::shared_ptr; - - 
CudaRef(const PtrType &ptr, size_t offset) : m_ptr(ptr), m_offset(offset) {} - - LIBRAPID_ALWAYS_INLINE CudaRef &operator=(const T &val) { - cudaSafeCall(cudaMemcpyAsync(m_ptr.get() + m_offset, - &val, - sizeof(T), - cudaMemcpyHostToDevice, - global::cudaStream)); - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T get() const { - T tmp; - cudaSafeCall(cudaMemcpyAsync(&tmp, - m_ptr.get() + m_offset, - sizeof(T), - cudaMemcpyDeviceToHost, - global::cudaStream)); - return tmp; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator CAST() const { - return static_cast(get()); - } - - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const { - return fmt::format(format, get()); - } - - private: - std::shared_ptr m_ptr; - size_t m_offset; - }; - - CUDA_REF_OPERATOR(+) - CUDA_REF_OPERATOR(-) - CUDA_REF_OPERATOR(*) - CUDA_REF_OPERATOR(/) - CUDA_REF_OPERATOR(%) - CUDA_REF_OPERATOR(^) - CUDA_REF_OPERATOR(&) - CUDA_REF_OPERATOR(|) - CUDA_REF_OPERATOR(<<) - CUDA_REF_OPERATOR(>>) - CUDA_REF_OPERATOR_NO_ASSIGN(==) - CUDA_REF_OPERATOR_NO_ASSIGN(!=) - CUDA_REF_OPERATOR_NO_ASSIGN(<) - CUDA_REF_OPERATOR_NO_ASSIGN(>) - CUDA_REF_OPERATOR_NO_ASSIGN(<=) - CUDA_REF_OPERATOR_NO_ASSIGN(>=) - } // namespace detail - - template - class CudaStorage { - public: - using Scalar = Scalar_; - using Pointer = std::shared_ptr; // Scalar *__restrict; - using ConstPointer = const std::shared_ptr; // const Scalar *__restrict; - using Reference = Scalar &; - using ConstReference = const Scalar &; - using DifferenceType = std::ptrdiff_t; - using SizeType = std::size_t; - - /// Default constructor -- initializes with nullptr - CudaStorage() = default; - - /// Create a CudaStorage object with \p elements. The data is not - /// initialized. - /// \param size Number of elements - LIBRAPID_ALWAYS_INLINE explicit CudaStorage(SizeType size); - - /// Create a CudaStorage object with \p elements. The data is initialized - /// to \p value. 
- /// \param size Number of elements - /// \param value Value to fill with - LIBRAPID_ALWAYS_INLINE CudaStorage(SizeType size, ConstReference value); - - LIBRAPID_ALWAYS_INLINE CudaStorage(Scalar *begin, SizeType size, bool independent); - - /// Create a new CudaStorage object from an existing one. - /// \param other The CudaStorage to copy - LIBRAPID_ALWAYS_INLINE CudaStorage(const CudaStorage &other); - - /// Create a new CudaStorage object from a temporary one, moving the - /// data - /// \param other The array to move - LIBRAPID_ALWAYS_INLINE CudaStorage(CudaStorage &&other) noexcept; - - /// Create a CudaStorage object from an std::initializer_list - /// \param list Initializer list of elements - LIBRAPID_ALWAYS_INLINE CudaStorage(const std::initializer_list &list); - - /// Create a CudaStorage object from an std::vector of values - /// \param vec The vector to fill with - LIBRAPID_ALWAYS_INLINE explicit CudaStorage(const std::vector &vec); - - template - static ShapeType defaultShape(); - - template - static CudaStorage fromData(const std::initializer_list &vec); - - template - static CudaStorage fromData(const std::vector &vec); - - /// Assignment operator for a CudaStorage object - /// \param other CudaStorage object to copy - /// \return *this - LIBRAPID_ALWAYS_INLINE CudaStorage &operator=(const CudaStorage &other); - - /// Move assignment operator for a CudaStorage object - /// \param other CudaStorage object to move - /// \return *this - LIBRAPID_ALWAYS_INLINE CudaStorage &operator=(CudaStorage &&other) noexcept; - - /// Free a CudaStorage object - ~CudaStorage(); - - /// \brief Set this CudaStorage object to reference the same data as \p other - /// \param other CudaStorage object to reference - void set(const CudaStorage &other); - - /// \brief Create a deep copy of this CudaStorage object - /// \return Deep copy of this CudaStorage object - CudaStorage copy() const; - - /// Resize a CudaStorage object to \p size elements. 
Existing elements are preserved where - /// possible. - /// \param size Number of elements - /// \see resize(SizeType, int) - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); - - /// Resize a CudaStorage object to \p size elements. Existing elements are not preserved. - /// This method of resizing is faster and more efficient than the version which preserves - /// the original data, but of course, this has the drawback that data will be lost. - /// \param size Number of elements - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, int); - - /// Return the number of elements in the CudaStorage object. - /// \return The number of elements - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::CudaRef - operator[](SizeType index) const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::CudaRef - operator[](SizeType index); - - /// Return the underlying pointer to the data - /// \return The underlying pointer to the data - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer data() const noexcept; - - /// Returns the pointer to the first element of the CudaStorage object - /// \return Pointer to the first element of the CudaStorage object - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer begin() const noexcept; - - /// Returns the pointer to the last element of the CudaStorage object - /// \return A pointer to the last element of the CudaStorage - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer end() const noexcept; - - private: - // Copy data from \p begin to \p end into this Storage object - /// \tparam P Pointer type - /// \param begin Beginning of data to copy - /// \param end End of data to copy - template - LIBRAPID_ALWAYS_INLINE void initData(P begin, P end); - - /// Resize the Storage Object to \p newSize elements, retaining existing - /// data. 
- /// \param newSize New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize, int); - - /// Resize the Storage object to \p newSize elements. Note this does not - /// initialize the new elements or maintain existing data. - /// \param newSize New size of the Storage object - LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); - - Pointer m_begin = nullptr; - size_t m_size; - bool m_ownsData = true; - }; - - namespace detail { - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T *__restrict cudaSafeAllocate(size_t size) { - static_assert(typetraits::TriviallyDefaultConstructible::value, - "Data type must be trivially constructable for use with CUDA"); - T *result; - // Round size up to nearest multiple of 32 - size = (size + size_t(31)) & ~size_t(31); - cudaSafeCall(cudaMallocAsync(&result, sizeof(T) * size, global::cudaStream)); - return result; - } - - template - LIBRAPID_ALWAYS_INLINE void cudaSafeDeallocate(T *__restrict data) { - static_assert(std::is_trivially_destructible_v, - "Data type must be trivially constructable for use with CUDA"); - cudaSafeCall(cudaFreeAsync(data, global::cudaStream)); - } - - template - std::shared_ptr cudaSharedPtrAllocate(size_t size) { - return std::shared_ptr(cudaSafeAllocate(size), cudaSafeDeallocate); - } - - template - std::shared_ptr safePointerCopyCuda(T *ptr, bool ownsData = true) { - using RawPointer = T *; - using Pointer = std::shared_ptr; - - if (ownsData) { - return Pointer(ptr, cudaSafeDeallocate); - } else { - return Pointer(ptr, [](RawPointer) {}); - } - } - - template - std::shared_ptr safePointerCopyCuda(std::shared_ptr ptr, bool ownsData = true) { - using RawPointer = T *; - using Pointer = std::shared_ptr; - - if (ownsData) { - return Pointer(ptr.get(), cudaSafeDeallocate); - } else { - return Pointer(ptr.get(), [](RawPointer) {}); - } - } - } // namespace detail - - template - CudaStorage::CudaStorage(SizeType size) : - m_size(size), 
m_begin(detail::cudaSharedPtrAllocate(size)), m_ownsData(true) {} - - template - CudaStorage::CudaStorage(SizeType size, ConstReference value) : - m_size(size), m_begin(detail::cudaSharedPtrAllocate(size)), m_ownsData(true) { - // Fill the data with "value" - cuda::runKernel("fill", "fillArray", size, size, m_begin, value); - } - - template - CudaStorage::CudaStorage(Scalar *begin, SizeType size, bool ownsData) : - m_size(size), m_begin(detail::safePointerCopyCuda(begin, ownsData)), - m_ownsData(ownsData) {} - - template - CudaStorage::CudaStorage(const CudaStorage &other) : - m_size(other.m_size), m_begin(other.m_begin), m_ownsData(other.m_ownsData) {} - - template - CudaStorage::CudaStorage(CudaStorage &&other) noexcept : - m_begin(other.m_begin), m_size(other.m_size), m_ownsData(other.m_ownsData) { - other.m_begin = nullptr; - other.m_size = 0; - other.m_ownsData = false; - } - - template - CudaStorage::CudaStorage(const std::initializer_list &list) : - m_size(list.size()), m_begin(detail::cudaSharedPtrAllocate(list.size())), - m_ownsData(true) { - cudaSafeCall(cudaMemcpyAsync(m_begin.get(), - list.begin(), - sizeof(T) * m_size, - cudaMemcpyHostToDevice, - global::cudaStream)); - } - - template - CudaStorage::CudaStorage(const std::vector &list) : - m_size(list.size()), m_begin(detail::cudaSharedPtrAllocate(list.size())), - m_ownsData(true) { - cudaSafeCall(cudaMemcpyAsync(m_begin.get(), - list.begin(), - sizeof(T) * m_size, - cudaMemcpyHostToDevice, - global::cudaStream)); - } - - template - template - ShapeType CudaStorage::defaultShape() { - return ShapeType({0}); - } - - template - template - auto CudaStorage::fromData(const std::initializer_list &list) -> CudaStorage { - CudaStorage ret; - ret.initData(list.begin(), list.end()); - return ret; - } - - template - template - auto CudaStorage::fromData(const std::vector &vec) -> CudaStorage { - CudaStorage ret; - ret.initData(vec.begin(), vec.end()); - return ret; - } - - template - auto 
CudaStorage::operator=(const CudaStorage &other) -> CudaStorage & { - if (this != &other) { - if (m_ownsData) { - // If we own the data already, we can just copy the pointer since we know it won't - // affect anything else. The shared pointer deals with the reference counting, so - // we don't need to worry about other arrays that might be using the same data. - m_begin = other.m_begin; - m_size = other.m_size; - } else { - LIBRAPID_ASSERT(m_size == other.m_size, - "Cannot copy storage with {} elements to dependent storage with " - "{} elements", - other.m_size, - m_size); - - cudaSafeCall(cudaMemcpyAsync(m_begin.get(), - other.begin().get(), - sizeof(T) * m_size, - cudaMemcpyDeviceToDevice, - global::cudaStream)); - } - } - return *this; - } - - template - auto CudaStorage::operator=(CudaStorage &&other) noexcept -> CudaStorage & { - if (this != &other) { - if (m_ownsData) { - std::swap(m_begin, other.m_begin); - std::swap(m_size, other.m_size); - std::swap(m_ownsData, other.m_ownsData); - } else { - LIBRAPID_ASSERT( - size() == other.size(), - "Mismatched storage sizes. Cannot assign CUDA storage with {} elements to " - "dependent CUDA storage with {} elements", - other.size(), - size()); - - cudaSafeCall(cudaMemcpyAsync(m_begin.get(), - other.begin().get(), - sizeof(T) * m_size, - cudaMemcpyDeviceToDevice, - global::cudaStream)); - } - } - return *this; - } - - template - CudaStorage::~CudaStorage() { - // Data is freed automatically by the shared_ptr. A custom deleter is used to ensure that - // nothing happens if the storage is dependent. 
- } - - template - void CudaStorage::set(const CudaStorage &other) { - m_begin = other.m_begin; - m_size = other.m_size; - m_ownsData = other.m_ownsData; - } - - template - auto CudaStorage::copy() const -> CudaStorage { - CudaStorage ret(m_size); - - cudaSafeCall(cudaMemcpyAsync(ret.begin().get(), - m_begin.get(), - sizeof(T) * m_size, - cudaMemcpyDeviceToDevice, - global::cudaStream)); - - return ret; - } - - template - template - void CudaStorage::initData(P begin, P end) { - auto size = std::distance(begin, end); - m_begin = detail::cudaSharedPtrAllocate(size); - m_size = size; - auto tmpBegin = [begin]() { - if constexpr (std::is_pointer_v

) - return begin; - else - return &(*begin); - }(); - cudaSafeCall(cudaMemcpyAsync( - m_begin.get(), tmpBegin, sizeof(T) * size, cudaMemcpyDefault, global::cudaStream)); - } - - template - void CudaStorage::resize(SizeType newSize) { - resizeImpl(newSize); - } - - template - void CudaStorage::resize(SizeType newSize, int) { - resizeImpl(newSize, 0); - } - - template - void CudaStorage::resizeImpl(SizeType newSize) { - if (newSize == size()) { return; } - LIBRAPID_ASSERT(m_ownsData, "Dependent CUDA storage cannot be resized"); - - Pointer oldBegin = m_begin; - SizeType oldSize = m_size; - - // Reallocate - m_begin = detail::cudaSharedPtrAllocate(newSize); - m_size = newSize; - - // Copy old data - cudaSafeCall(cudaMemcpyAsync(m_begin.get(), - oldBegin.get(), - sizeof(T) * std::min(oldSize, newSize), - cudaMemcpyDeviceToDevice, - global::cudaStream)); - - m_size = newSize; - } - - template - void CudaStorage::resizeImpl(SizeType newSize, int) { - if (newSize == size()) return; - LIBRAPID_ASSERT(m_ownsData, "Dependent CUDA storage cannot be resized"); - m_begin = detail::cudaSharedPtrAllocate(newSize); - m_size = newSize; - } - - template - auto CudaStorage::size() const noexcept -> SizeType { - return m_size; - } - - template - auto CudaStorage::operator[](SizeType index) const -> detail::CudaRef { - return {m_begin, index}; - } - - template - auto CudaStorage::operator[](SizeType index) -> detail::CudaRef { - return {m_begin, index}; - } - - template - auto CudaStorage::data() const noexcept -> Pointer { - return m_begin; - } - - template - auto CudaStorage::begin() const noexcept -> Pointer { - return m_begin; - } - - template - auto CudaStorage::end() const noexcept -> Pointer { - return m_begin + m_size; - } + template \ + auto operator OP(const CudaRef &lhs, const CudaRef &rhs) { \ + return lhs.get() OP rhs.get(); \ + } + + template + class CudaRef { + public: + using PtrType = std::shared_ptr; + + CudaRef(const PtrType &ptr, size_t offset) : m_ptr(ptr), 
m_offset(offset) {} + + LIBRAPID_ALWAYS_INLINE CudaRef &operator=(const T &val) { + cudaSafeCall(cudaMemcpyAsync(m_ptr.get() + m_offset, + &val, + sizeof(T), + cudaMemcpyHostToDevice, + global::cudaStream)); + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T get() const { + T tmp; + cudaSafeCall(cudaMemcpyAsync(&tmp, + m_ptr.get() + m_offset, + sizeof(T), + cudaMemcpyDeviceToHost, + global::cudaStream)); + return tmp; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator CAST() const { + return static_cast(get()); + } + + LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const { + return fmt::format(format, get()); + } + + private: + std::shared_ptr m_ptr; + size_t m_offset; + }; + + CUDA_REF_OPERATOR(+) + CUDA_REF_OPERATOR(-) + CUDA_REF_OPERATOR(*) + CUDA_REF_OPERATOR(/) + CUDA_REF_OPERATOR(%) + CUDA_REF_OPERATOR(^) + CUDA_REF_OPERATOR(&) + CUDA_REF_OPERATOR(|) + CUDA_REF_OPERATOR(<<) + CUDA_REF_OPERATOR(>>) + CUDA_REF_OPERATOR_NO_ASSIGN(==) + CUDA_REF_OPERATOR_NO_ASSIGN(!=) + CUDA_REF_OPERATOR_NO_ASSIGN(<) + CUDA_REF_OPERATOR_NO_ASSIGN(>) + CUDA_REF_OPERATOR_NO_ASSIGN(<=) + CUDA_REF_OPERATOR_NO_ASSIGN(>=) + } // namespace detail + + template + class CudaStorage { + public: + using Scalar = Scalar_; + using Pointer = std::shared_ptr; // Scalar *__restrict; + using ConstPointer = const std::shared_ptr; // const Scalar *__restrict; + using Reference = Scalar &; + using ConstReference = const Scalar &; + using DifferenceType = std::ptrdiff_t; + using SizeType = std::size_t; + + /// Default constructor -- initializes with nullptr + CudaStorage() = default; + + /// Create a CudaStorage object with \p elements. The data is not + /// initialized. + /// \param size Number of elements + LIBRAPID_ALWAYS_INLINE explicit CudaStorage(SizeType size); + + /// Create a CudaStorage object with \p elements. The data is initialized + /// to \p value. 
+ /// \param size Number of elements + /// \param value Value to fill with + LIBRAPID_ALWAYS_INLINE CudaStorage(SizeType size, ConstReference value); + + LIBRAPID_ALWAYS_INLINE CudaStorage(Scalar *begin, SizeType size, bool independent); + + /// Create a new CudaStorage object from an existing one. + /// \param other The CudaStorage to copy + LIBRAPID_ALWAYS_INLINE CudaStorage(const CudaStorage &other); + + /// Create a new CudaStorage object from a temporary one, moving the + /// data + /// \param other The array to move + LIBRAPID_ALWAYS_INLINE CudaStorage(CudaStorage &&other) noexcept; + + /// Create a CudaStorage object from an std::initializer_list + /// \param list Initializer list of elements + LIBRAPID_ALWAYS_INLINE CudaStorage(const std::initializer_list &list); + + /// Create a CudaStorage object from an std::vector of values + /// \param vec The vector to fill with + LIBRAPID_ALWAYS_INLINE explicit CudaStorage(const std::vector &vec); + + template + static ShapeType defaultShape(); + + template + static CudaStorage fromData(const std::initializer_list &vec); + + template + static CudaStorage fromData(const std::vector &vec); + + /// Assignment operator for a CudaStorage object + /// \param other CudaStorage object to copy + /// \return *this + LIBRAPID_ALWAYS_INLINE CudaStorage &operator=(const CudaStorage &other); + + /// Move assignment operator for a CudaStorage object + /// \param other CudaStorage object to move + /// \return *this + LIBRAPID_ALWAYS_INLINE CudaStorage &operator=(CudaStorage &&other) noexcept; + + /// Free a CudaStorage object + ~CudaStorage(); + + /// \brief Set this CudaStorage object to reference the same data as \p other + /// \param other CudaStorage object to reference + void set(const CudaStorage &other); + + /// \brief Create a deep copy of this CudaStorage object + /// \return Deep copy of this CudaStorage object + CudaStorage copy() const; + + /// Resize a CudaStorage object to \p size elements. 
Existing elements are preserved where + /// possible. + /// \param size Number of elements + /// \see resize(SizeType, int) + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); + + /// Resize a CudaStorage object to \p size elements. Existing elements are not preserved. + /// This method of resizing is faster and more efficient than the version which preserves + /// the original data, but of course, this has the drawback that data will be lost. + /// \param size Number of elements + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, int); + + /// Return the number of elements in the CudaStorage object. + /// \return The number of elements + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::CudaRef + operator[](SizeType index) const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::CudaRef + operator[](SizeType index); + + /// Return the underlying pointer to the data + /// \return The underlying pointer to the data + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer data() const noexcept; + + /// Returns the pointer to the first element of the CudaStorage object + /// \return Pointer to the first element of the CudaStorage object + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer begin() const noexcept; + + /// Returns the pointer to the last element of the CudaStorage object + /// \return A pointer to the last element of the CudaStorage + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer end() const noexcept; + + private: + // Copy data from \p begin to \p end into this Storage object + /// \tparam P Pointer type + /// \param begin Beginning of data to copy + /// \param end End of data to copy + template + LIBRAPID_ALWAYS_INLINE void initData(P begin, P end); + + /// Resize the Storage Object to \p newSize elements, retaining existing + /// data. 
+ /// \param newSize New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize, int); + + /// Resize the Storage object to \p newSize elements. Note this does not + /// initialize the new elements or maintain existing data. + /// \param newSize New size of the Storage object + LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); + + Pointer m_begin = nullptr; + size_t m_size; + bool m_ownsData = true; + }; + + namespace detail { + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T *__restrict cudaSafeAllocate(size_t size) { + static_assert(typetraits::TriviallyDefaultConstructible::value, + "Data type must be trivially constructable for use with CUDA"); + T *result; + // Round size up to nearest multiple of 32 + size = (size + size_t(31)) & ~size_t(31); + cudaSafeCall(cudaMallocAsync(&result, sizeof(T) * size, global::cudaStream)); + return result; + } + + template + LIBRAPID_ALWAYS_INLINE void cudaSafeDeallocate(T *__restrict data) { + static_assert(std::is_trivially_destructible_v, + "Data type must be trivially constructable for use with CUDA"); + cudaSafeCall(cudaFreeAsync(data, global::cudaStream)); + } + + template + std::shared_ptr cudaSharedPtrAllocate(size_t size) { + return std::shared_ptr(cudaSafeAllocate(size), cudaSafeDeallocate); + } + + template + std::shared_ptr safePointerCopyCuda(T *ptr, bool ownsData = true) { + using RawPointer = T *; + using Pointer = std::shared_ptr; + + if (ownsData) { + return Pointer(ptr, cudaSafeDeallocate); + } else { + return Pointer(ptr, [](RawPointer) {}); + } + } + + template + std::shared_ptr safePointerCopyCuda(std::shared_ptr ptr, bool ownsData = true) { + using RawPointer = T *; + using Pointer = std::shared_ptr; + + if (ownsData) { + return Pointer(ptr.get(), cudaSafeDeallocate); + } else { + return Pointer(ptr.get(), [](RawPointer) {}); + } + } + } // namespace detail + + template + CudaStorage::CudaStorage(SizeType size) : + m_size(size), 
m_begin(detail::cudaSharedPtrAllocate(size)), m_ownsData(true) {} + + template + CudaStorage::CudaStorage(SizeType size, ConstReference value) : + m_size(size), m_begin(detail::cudaSharedPtrAllocate(size)), m_ownsData(true) { + // Fill the data with "value" + cuda::runKernel("fill", "fillArray", size, size, m_begin, value); + } + + template + CudaStorage::CudaStorage(Scalar *begin, SizeType size, bool ownsData) : + m_size(size), m_begin(detail::safePointerCopyCuda(begin, ownsData)), + m_ownsData(ownsData) {} + + template + CudaStorage::CudaStorage(const CudaStorage &other) : + m_size(other.m_size), m_begin(other.m_begin), m_ownsData(other.m_ownsData) {} + + template + CudaStorage::CudaStorage(CudaStorage &&other) noexcept : + m_begin(other.m_begin), m_size(other.m_size), m_ownsData(other.m_ownsData) { + other.m_begin = nullptr; + other.m_size = 0; + other.m_ownsData = false; + } + + template + CudaStorage::CudaStorage(const std::initializer_list &list) : + m_size(list.size()), m_begin(detail::cudaSharedPtrAllocate(list.size())), + m_ownsData(true) { + cudaSafeCall(cudaMemcpyAsync(m_begin.get(), + list.begin(), + sizeof(T) * m_size, + cudaMemcpyHostToDevice, + global::cudaStream)); + } + + template + CudaStorage::CudaStorage(const std::vector &list) : + m_size(list.size()), m_begin(detail::cudaSharedPtrAllocate(list.size())), + m_ownsData(true) { + cudaSafeCall(cudaMemcpyAsync(m_begin.get(), + list.begin(), + sizeof(T) * m_size, + cudaMemcpyHostToDevice, + global::cudaStream)); + } + + template + template + ShapeType CudaStorage::defaultShape() { + return ShapeType({0}); + } + + template + template + auto CudaStorage::fromData(const std::initializer_list &list) -> CudaStorage { + CudaStorage ret; + ret.initData(list.begin(), list.end()); + return ret; + } + + template + template + auto CudaStorage::fromData(const std::vector &vec) -> CudaStorage { + CudaStorage ret; + ret.initData(vec.begin(), vec.end()); + return ret; + } + + template + auto 
CudaStorage::operator=(const CudaStorage &other) -> CudaStorage & { + if (this != &other) { + if (m_ownsData) { + // If we own the data already, we can just copy the pointer since we know it won't + // affect anything else. The shared pointer deals with the reference counting, so + // we don't need to worry about other arrays that might be using the same data. + m_begin = other.m_begin; + m_size = other.m_size; + } else { + LIBRAPID_ASSERT(m_size == other.m_size, + "Cannot copy storage with {} elements to dependent storage with " + "{} elements", + other.m_size, + m_size); + + cudaSafeCall(cudaMemcpyAsync(m_begin.get(), + other.begin().get(), + sizeof(T) * m_size, + cudaMemcpyDeviceToDevice, + global::cudaStream)); + } + } + return *this; + } + + template + auto CudaStorage::operator=(CudaStorage &&other) noexcept -> CudaStorage & { + if (this != &other) { + if (m_ownsData) { + std::swap(m_begin, other.m_begin); + std::swap(m_size, other.m_size); + std::swap(m_ownsData, other.m_ownsData); + } else { + LIBRAPID_ASSERT( + size() == other.size(), + "Mismatched storage sizes. Cannot assign CUDA storage with {} elements to " + "dependent CUDA storage with {} elements", + other.size(), + size()); + + cudaSafeCall(cudaMemcpyAsync(m_begin.get(), + other.begin().get(), + sizeof(T) * m_size, + cudaMemcpyDeviceToDevice, + global::cudaStream)); + } + } + return *this; + } + + template + CudaStorage::~CudaStorage() { + // Data is freed automatically by the shared_ptr. A custom deleter is used to ensure that + // nothing happens if the storage is dependent. 
+ } + + template + void CudaStorage::set(const CudaStorage &other) { + m_begin = other.m_begin; + m_size = other.m_size; + m_ownsData = other.m_ownsData; + } + + template + auto CudaStorage::copy() const -> CudaStorage { + CudaStorage ret(m_size); + + cudaSafeCall(cudaMemcpyAsync(ret.begin().get(), + m_begin.get(), + sizeof(T) * m_size, + cudaMemcpyDeviceToDevice, + global::cudaStream)); + + return ret; + } + + template + template + void CudaStorage::initData(P begin, P end) { + auto size = std::distance(begin, end); + m_begin = detail::cudaSharedPtrAllocate(size); + m_size = size; + auto tmpBegin = [begin]() { + if constexpr (std::is_pointer_v

) + return begin; + else + return &(*begin); + }(); + cudaSafeCall(cudaMemcpyAsync( + m_begin.get(), tmpBegin, sizeof(T) * size, cudaMemcpyDefault, global::cudaStream)); + } + + template + void CudaStorage::resize(SizeType newSize) { + resizeImpl(newSize); + } + + template + void CudaStorage::resize(SizeType newSize, int) { + resizeImpl(newSize, 0); + } + + template + void CudaStorage::resizeImpl(SizeType newSize) { + if (newSize == size()) { return; } + LIBRAPID_ASSERT(m_ownsData, "Dependent CUDA storage cannot be resized"); + + Pointer oldBegin = m_begin; + SizeType oldSize = m_size; + + // Reallocate + m_begin = detail::cudaSharedPtrAllocate(newSize); + m_size = newSize; + + // Copy old data + cudaSafeCall(cudaMemcpyAsync(m_begin.get(), + oldBegin.get(), + sizeof(T) * std::min(oldSize, newSize), + cudaMemcpyDeviceToDevice, + global::cudaStream)); + + m_size = newSize; + } + + template + void CudaStorage::resizeImpl(SizeType newSize, int) { + if (newSize == size()) return; + LIBRAPID_ASSERT(m_ownsData, "Dependent CUDA storage cannot be resized"); + m_begin = detail::cudaSharedPtrAllocate(newSize); + m_size = newSize; + } + + template + auto CudaStorage::size() const noexcept -> SizeType { + return m_size; + } + + template + auto CudaStorage::operator[](SizeType index) const -> detail::CudaRef { + return {m_begin, index}; + } + + template + auto CudaStorage::operator[](SizeType index) -> detail::CudaRef { + return {m_begin, index}; + } + + template + auto CudaStorage::data() const noexcept -> Pointer { + return m_begin; + } + + template + auto CudaStorage::begin() const noexcept -> Pointer { + return m_begin; + } + + template + auto CudaStorage::end() const noexcept -> Pointer { + return m_begin + m_size; + } } // namespace librapid -# if defined(FMT_API) +# if defined(FMT_API) LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::detail::CudaRef) -# endif // FM_API +# endif // FM_API #else // Trait implementations namespace librapid::typetraits { - // Define this so 
things still work correctly - template - struct IsCudaStorage : std::false_type {}; + // Define this so things still work correctly + template + struct IsCudaStorage : std::false_type {}; } // namespace librapid::typetraits #endif // LIBRAPID_HAS_CUDA #endif // LIBRAPID_ARRAY_CUDA_STORAGE_HPP diff --git a/librapid/include/librapid/cuda/exception.h b/librapid/include/librapid/cuda/exception.h index 956d9e66..a0271ff6 100644 --- a/librapid/include/librapid/cuda/exception.h +++ b/librapid/include/librapid/cuda/exception.h @@ -41,30 +41,30 @@ template class Exception : public Std_Exception { public: - //! @brief Static construction interface - //! @return Alwayss throws ( Located_Exception) - //! @param file file in which the Exception occurs - //! @param line line in which the Exception occurs - //! @param detailed details on the code fragment causing the Exception - static void throw_it(const char *file, const int line, const char *detailed = "-"); - - //! Static construction interface - //! @return Alwayss throws ( Located_Exception) - //! @param file file in which the Exception occurs - //! @param line line in which the Exception occurs - //! @param detailed details on the code fragment causing the Exception - static void throw_it(const char *file, const int line, const std::string &detailed); - - //! Destructor - virtual ~Exception() throw(); + //! @brief Static construction interface + //! @return Alwayss throws ( Located_Exception) + //! @param file file in which the Exception occurs + //! @param line line in which the Exception occurs + //! @param detailed details on the code fragment causing the Exception + static void throw_it(const char *file, const int line, const char *detailed = "-"); + + //! Static construction interface + //! @return Alwayss throws ( Located_Exception) + //! @param file file in which the Exception occurs + //! @param line line in which the Exception occurs + //! 
@param detailed details on the code fragment causing the Exception + static void throw_it(const char *file, const int line, const std::string &detailed); + + //! Destructor + virtual ~Exception() throw(); private: - //! Constructor, default (private) - Exception(); + //! Constructor, default (private) + Exception(); - //! Constructor, standard - //! @param str string returned by what() - explicit Exception(const std::string &str); + //! Constructor, standard + //! @param str string returned by what() + explicit Exception(const std::string &str); }; //////////////////////////////////////////////////////////////////////////////// @@ -73,9 +73,9 @@ class Exception : public Std_Exception { //////////////////////////////////////////////////////////////////////////////// template inline void handleException(const Exception_Typ &ex) { - std::cerr << ex.what() << std::endl; + std::cerr << ex.what() << std::endl; - exit(EXIT_FAILURE); + exit(EXIT_FAILURE); } //! Convenience macros @@ -101,14 +101,14 @@ inline void handleException(const Exception_Typ &ex) { //////////////////////////////////////////////////////////////////////////////// /*static*/ template void Exception::throw_it(const char *file, const int line, const char *detailed) { - std::stringstream s; + std::stringstream s; - // Quiet heavy-weight but exceptions are not for - // performance / release versions - s << "Exception in file '" << file << "' in line " << line << "\n" - << "Detailed description: " << detailed << "\n"; + // Quiet heavy-weight but exceptions are not for + // performance / release versions + s << "Exception in file '" << file << "' in line " << line << "\n" + << "Detailed description: " << detailed << "\n"; - throw Exception(s.str()); + throw Exception(s.str()); } //////////////////////////////////////////////////////////////////////////////// @@ -117,7 +117,7 @@ void Exception::throw_it(const char *file, const int line, const 
//////////////////////////////////////////////////////////////////////////////// /*static*/ template void Exception::throw_it(const char *file, const int line, const std::string &msg) { - throw_it(file, line, msg.c_str()); + throw_it(file, line, msg.c_str()); } //////////////////////////////////////////////////////////////////////////////// diff --git a/librapid/include/librapid/cuda/helper_cuda.h b/librapid/include/librapid/cuda/helper_cuda.h index b20ce627..5cc56f25 100644 --- a/librapid/include/librapid/cuda/helper_cuda.h +++ b/librapid/include/librapid/cuda/helper_cuda.h @@ -40,7 +40,7 @@ #include #ifndef EXIT_WAIVED -# define EXIT_WAIVED 2 +# define EXIT_WAIVED 2 #endif // Note, it is required that your SDK sample to include the proper header @@ -94,165 +94,165 @@ const char *_cudaGetErrorEnum(NppStatus error); template void check(T result, char const *const func, const char *const file, int const line) { - if (result) { - fprintf(stderr, - "CUDA error at %s:%d code=%d(%s) \"%s\" \n", - file, - line, - static_cast(result), - _cudaGetErrorEnum(result), - func); - exit(EXIT_FAILURE); - } + if (result) { + fprintf(stderr, + "CUDA error at %s:%d code=%d(%s) \"%s\" \n", + file, + line, + static_cast(result), + _cudaGetErrorEnum(result), + func); + exit(EXIT_FAILURE); + } } #ifdef __DRIVER_TYPES_H__ // This will output the proper CUDA error strings in the event // that a CUDA host call returns an error -# define checkCudaErrors(val) check((val), # val, __FILE__, __LINE__) +# define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) // This will output the proper error string when calling cudaGetLastError -# define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__) +# define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__) inline void __getLastCudaError(const char *errorMessage, const char *file, const int line) { - cudaError_t err = cudaGetLastError(); - - if (cudaSuccess != err) { - fprintf(stderr, - "%s(%i) : 
getLastCudaError() CUDA error :" - " %s : (%d) %s.\n", - file, - line, - errorMessage, - static_cast(err), - cudaGetErrorString(err)); - exit(EXIT_FAILURE); - } + cudaError_t err = cudaGetLastError(); + + if (cudaSuccess != err) { + fprintf(stderr, + "%s(%i) : getLastCudaError() CUDA error :" + " %s : (%d) %s.\n", + file, + line, + errorMessage, + static_cast(err), + cudaGetErrorString(err)); + exit(EXIT_FAILURE); + } } // This will only print the proper error string when calling cudaGetLastError // but not exit program incase error detected. -# define printLastCudaError(msg) __printLastCudaError(msg, __FILE__, __LINE__) +# define printLastCudaError(msg) __printLastCudaError(msg, __FILE__, __LINE__) inline void __printLastCudaError(const char *errorMessage, const char *file, const int line) { - cudaError_t err = cudaGetLastError(); - - if (cudaSuccess != err) { - fprintf(stderr, - "%s(%i) : getLastCudaError() CUDA error :" - " %s : (%d) %s.\n", - file, - line, - errorMessage, - static_cast(err), - cudaGetErrorString(err)); - } + cudaError_t err = cudaGetLastError(); + + if (cudaSuccess != err) { + fprintf(stderr, + "%s(%i) : getLastCudaError() CUDA error :" + " %s : (%d) %s.\n", + file, + line, + errorMessage, + static_cast(err), + cudaGetErrorString(err)); + } } #endif #ifndef MAX -# define MAX(a, b) (a > b ? a : b) +# define MAX(a, b) (a > b ? a : b) #endif // Float To Int conversion inline int ftoi(float value) { - return (value >= 0 ? static_cast(value + 0.5) : static_cast(value - 0.5)); + return (value >= 0 ? 
static_cast(value + 0.5) : static_cast(value - 0.5)); } // Beginning of GPU Architecture definitions inline int _ConvertSMVer2Cores(int major, int minor) { - // Defines for GPU Architecture types (using the SM version to determine - // the # of cores per SM - typedef struct { - int SM; // 0xMm (hexidecimal notation), M = SM Major version, - // and m = SM minor version - int Cores; - } sSMtoCores; - - sSMtoCores nGpuArchCoresPerSM[] = {{0x30, 192}, - {0x32, 192}, - {0x35, 192}, - {0x37, 192}, - {0x50, 128}, - {0x52, 128}, - {0x53, 128}, - {0x60, 64}, - {0x61, 128}, - {0x62, 128}, - {0x70, 64}, - {0x72, 64}, - {0x75, 64}, - {0x80, 64}, - {0x86, 128}, - {-1, -1}}; - - int index = 0; - - while (nGpuArchCoresPerSM[index].SM != -1) { - if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { - return nGpuArchCoresPerSM[index].Cores; - } - - index++; - } - - // If we don't find the values, we default use the previous one - // to run properly - printf( - "MapSMtoCores for SM %d.%d is undefined." - " Default to use %d Cores/SM\n", - major, - minor, - nGpuArchCoresPerSM[index - 1].Cores); - return nGpuArchCoresPerSM[index - 1].Cores; + // Defines for GPU Architecture types (using the SM version to determine + // the # of cores per SM + typedef struct { + int SM; // 0xMm (hexidecimal notation), M = SM Major version, + // and m = SM minor version + int Cores; + } sSMtoCores; + + sSMtoCores nGpuArchCoresPerSM[] = {{0x30, 192}, + {0x32, 192}, + {0x35, 192}, + {0x37, 192}, + {0x50, 128}, + {0x52, 128}, + {0x53, 128}, + {0x60, 64}, + {0x61, 128}, + {0x62, 128}, + {0x70, 64}, + {0x72, 64}, + {0x75, 64}, + {0x80, 64}, + {0x86, 128}, + {-1, -1}}; + + int index = 0; + + while (nGpuArchCoresPerSM[index].SM != -1) { + if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { + return nGpuArchCoresPerSM[index].Cores; + } + + index++; + } + + // If we don't find the values, we default use the previous one + // to run properly + printf( + "MapSMtoCores for SM %d.%d is undefined." 
+ " Default to use %d Cores/SM\n", + major, + minor, + nGpuArchCoresPerSM[index - 1].Cores); + return nGpuArchCoresPerSM[index - 1].Cores; } inline const char *_ConvertSMVer2ArchName(int major, int minor) { - // Defines for GPU Architecture types (using the SM version to determine - // the GPU Arch name) - typedef struct { - int SM; // 0xMm (hexidecimal notation), M = SM Major version, - // and m = SM minor version - const char *name; - } sSMtoArchName; - - sSMtoArchName nGpuArchNameSM[] = {{0x30, "Kepler"}, - {0x32, "Kepler"}, - {0x35, "Kepler"}, - {0x37, "Kepler"}, - {0x50, "Maxwell"}, - {0x52, "Maxwell"}, - {0x53, "Maxwell"}, - {0x60, "Pascal"}, - {0x61, "Pascal"}, - {0x62, "Pascal"}, - {0x70, "Volta"}, - {0x72, "Xavier"}, - {0x75, "Turing"}, - {0x80, "Ampere"}, - {0x86, "Ampere"}, - {-1, "Graphics Device"}}; - - int index = 0; - - while (nGpuArchNameSM[index].SM != -1) { - if (nGpuArchNameSM[index].SM == ((major << 4) + minor)) { - return nGpuArchNameSM[index].name; - } - - index++; - } - - // If we don't find the values, we default use the previous one - // to run properly - printf( - "MapSMtoArchName for SM %d.%d is undefined." 
- " Default to use %s\n", - major, - minor, - nGpuArchNameSM[index - 1].name); - return nGpuArchNameSM[index - 1].name; + // Defines for GPU Architecture types (using the SM version to determine + // the GPU Arch name) + typedef struct { + int SM; // 0xMm (hexidecimal notation), M = SM Major version, + // and m = SM minor version + const char *name; + } sSMtoArchName; + + sSMtoArchName nGpuArchNameSM[] = {{0x30, "Kepler"}, + {0x32, "Kepler"}, + {0x35, "Kepler"}, + {0x37, "Kepler"}, + {0x50, "Maxwell"}, + {0x52, "Maxwell"}, + {0x53, "Maxwell"}, + {0x60, "Pascal"}, + {0x61, "Pascal"}, + {0x62, "Pascal"}, + {0x70, "Volta"}, + {0x72, "Xavier"}, + {0x75, "Turing"}, + {0x80, "Ampere"}, + {0x86, "Ampere"}, + {-1, "Graphics Device"}}; + + int index = 0; + + while (nGpuArchNameSM[index].SM != -1) { + if (nGpuArchNameSM[index].SM == ((major << 4) + minor)) { + return nGpuArchNameSM[index].name; + } + + index++; + } + + // If we don't find the values, we default use the previous one + // to run properly + printf( + "MapSMtoArchName for SM %d.%d is undefined." + " Default to use %s\n", + major, + minor, + nGpuArchNameSM[index - 1].name); + return nGpuArchNameSM[index - 1].name; } // end of GPU Architecture definitions @@ -260,244 +260,244 @@ inline const char *_ConvertSMVer2ArchName(int major, int minor) { // General GPU Device CUDA Initialization inline int gpuDeviceInit(int devID) { - int device_count; - checkCudaErrors(cudaGetDeviceCount(&device_count)); - - if (device_count == 0) { - fprintf(stderr, - "gpuDeviceInit() CUDA error: " - "no devices supporting CUDA.\n"); - exit(EXIT_FAILURE); - } - - if (devID < 0) { devID = 0; } - - if (devID > device_count - 1) { - fprintf(stderr, "\n"); - fprintf(stderr, ">> %d CUDA capable GPU device(s) detected. <<\n", device_count); - fprintf(stderr, - ">> gpuDeviceInit (-device=%d) is not a valid" - " GPU device. 
<<\n", - devID); - fprintf(stderr, "\n"); - return -devID; - } - - int computeMode = -1, major = 0, minor = 0; - checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, devID)); - checkCudaErrors(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); - checkCudaErrors(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); - if (computeMode == cudaComputeModeProhibited) { - fprintf(stderr, - "Error: device is running in , no threads can use cudaSetDevice().\n"); - return -1; - } - - if (major < 1) { - fprintf(stderr, "gpuDeviceInit(): GPU device does not support CUDA.\n"); - exit(EXIT_FAILURE); - } - - checkCudaErrors(cudaSetDevice(devID)); - printf("gpuDeviceInit() CUDA Device [%d]: \"%s\n", devID, _ConvertSMVer2ArchName(major, minor)); - - return devID; + int device_count; + checkCudaErrors(cudaGetDeviceCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, + "gpuDeviceInit() CUDA error: " + "no devices supporting CUDA.\n"); + exit(EXIT_FAILURE); + } + + if (devID < 0) { devID = 0; } + + if (devID > device_count - 1) { + fprintf(stderr, "\n"); + fprintf(stderr, ">> %d CUDA capable GPU device(s) detected. <<\n", device_count); + fprintf(stderr, + ">> gpuDeviceInit (-device=%d) is not a valid" + " GPU device. 
<<\n", + devID); + fprintf(stderr, "\n"); + return -devID; + } + + int computeMode = -1, major = 0, minor = 0; + checkCudaErrors(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, devID)); + checkCudaErrors(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); + checkCudaErrors(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); + if (computeMode == cudaComputeModeProhibited) { + fprintf(stderr, + "Error: device is running in , no threads can use cudaSetDevice().\n"); + return -1; + } + + if (major < 1) { + fprintf(stderr, "gpuDeviceInit(): GPU device does not support CUDA.\n"); + exit(EXIT_FAILURE); + } + + checkCudaErrors(cudaSetDevice(devID)); + printf("gpuDeviceInit() CUDA Device [%d]: \"%s\n", devID, _ConvertSMVer2ArchName(major, minor)); + + return devID; } // This function returns the best GPU (with maximum GFLOPS) inline int gpuGetMaxGflopsDeviceId() { - int current_device = 0, sm_per_multiproc = 0; - int max_perf_device = 0; - int device_count = 0; - int devices_prohibited = 0; - - uint64_t max_compute_perf = 0; - checkCudaErrors(cudaGetDeviceCount(&device_count)); - - if (device_count == 0) { - fprintf(stderr, - "gpuGetMaxGflopsDeviceId() CUDA error:" - " no devices supporting CUDA.\n"); - exit(EXIT_FAILURE); - } - - // Find the best CUDA capable GPU device - current_device = 0; - - while (current_device < device_count) { - int computeMode = -1, major = 0, minor = 0; - checkCudaErrors( - cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, current_device)); - checkCudaErrors( - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, current_device)); - checkCudaErrors( - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, current_device)); - - // If this GPU is not running on Compute Mode prohibited, - // then we can add it to the list - if (computeMode != cudaComputeModeProhibited) { - if (major == 9999 && minor == 9999) { - sm_per_multiproc = 1; - } else { - 
sm_per_multiproc = _ConvertSMVer2Cores(major, minor); - } - int multiProcessorCount = 0, clockRate = 0; - checkCudaErrors(cudaDeviceGetAttribute( - &multiProcessorCount, cudaDevAttrMultiProcessorCount, current_device)); - cudaError_t result = - cudaDeviceGetAttribute(&clockRate, cudaDevAttrClockRate, current_device); - if (result != cudaSuccess) { - // If cudaDevAttrClockRate attribute is not supported we - // set clockRate as 1, to consider GPU with most SMs and CUDA - // Cores. - if (result == cudaErrorInvalidValue) { - clockRate = 1; - } else { - fprintf(stderr, - "CUDA error at %s:%d code=%d(%s) \n", - __FILE__, - __LINE__, - static_cast(result), - _cudaGetErrorEnum(result)); - exit(EXIT_FAILURE); - } - } - uint64_t compute_perf = (uint64_t)multiProcessorCount * sm_per_multiproc * clockRate; - - if (compute_perf > max_compute_perf) { - max_compute_perf = compute_perf; - max_perf_device = current_device; - } - } else { - devices_prohibited++; - } - - ++current_device; - } - - if (devices_prohibited == device_count) { - fprintf(stderr, - "gpuGetMaxGflopsDeviceId() CUDA error:" - " all devices have compute mode prohibited.\n"); - exit(EXIT_FAILURE); - } - - return max_perf_device; + int current_device = 0, sm_per_multiproc = 0; + int max_perf_device = 0; + int device_count = 0; + int devices_prohibited = 0; + + uint64_t max_compute_perf = 0; + checkCudaErrors(cudaGetDeviceCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, + "gpuGetMaxGflopsDeviceId() CUDA error:" + " no devices supporting CUDA.\n"); + exit(EXIT_FAILURE); + } + + // Find the best CUDA capable GPU device + current_device = 0; + + while (current_device < device_count) { + int computeMode = -1, major = 0, minor = 0; + checkCudaErrors( + cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, current_device)); + checkCudaErrors( + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, current_device)); + checkCudaErrors( + cudaDeviceGetAttribute(&minor, 
cudaDevAttrComputeCapabilityMinor, current_device)); + + // If this GPU is not running on Compute Mode prohibited, + // then we can add it to the list + if (computeMode != cudaComputeModeProhibited) { + if (major == 9999 && minor == 9999) { + sm_per_multiproc = 1; + } else { + sm_per_multiproc = _ConvertSMVer2Cores(major, minor); + } + int multiProcessorCount = 0, clockRate = 0; + checkCudaErrors(cudaDeviceGetAttribute( + &multiProcessorCount, cudaDevAttrMultiProcessorCount, current_device)); + cudaError_t result = + cudaDeviceGetAttribute(&clockRate, cudaDevAttrClockRate, current_device); + if (result != cudaSuccess) { + // If cudaDevAttrClockRate attribute is not supported we + // set clockRate as 1, to consider GPU with most SMs and CUDA + // Cores. + if (result == cudaErrorInvalidValue) { + clockRate = 1; + } else { + fprintf(stderr, + "CUDA error at %s:%d code=%d(%s) \n", + __FILE__, + __LINE__, + static_cast(result), + _cudaGetErrorEnum(result)); + exit(EXIT_FAILURE); + } + } + uint64_t compute_perf = (uint64_t)multiProcessorCount * sm_per_multiproc * clockRate; + + if (compute_perf > max_compute_perf) { + max_compute_perf = compute_perf; + max_perf_device = current_device; + } + } else { + devices_prohibited++; + } + + ++current_device; + } + + if (devices_prohibited == device_count) { + fprintf(stderr, + "gpuGetMaxGflopsDeviceId() CUDA error:" + " all devices have compute mode prohibited.\n"); + exit(EXIT_FAILURE); + } + + return max_perf_device; } // Initialization code to find the best CUDA Device inline int findCudaDevice(int argc, const char **argv) { - int devID = 0; - - // If the command-line has a device number specified, use it - if (checkCmdLineFlag(argc, argv, "device")) { - devID = getCmdLineArgumentInt(argc, argv, "device="); - - if (devID < 0) { - printf("Invalid command line parameter\n "); - exit(EXIT_FAILURE); - } else { - devID = gpuDeviceInit(devID); - - if (devID < 0) { - printf("exiting...\n"); - exit(EXIT_FAILURE); - } - } - } else { - 
// Otherwise pick the device with highest Gflops/s - devID = gpuGetMaxGflopsDeviceId(); - checkCudaErrors(cudaSetDevice(devID)); - int major = 0, minor = 0; - checkCudaErrors(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); - checkCudaErrors(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); - printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", - devID, - _ConvertSMVer2ArchName(major, minor), - major, - minor); - } - - return devID; + int devID = 0; + + // If the command-line has a device number specified, use it + if (checkCmdLineFlag(argc, argv, "device")) { + devID = getCmdLineArgumentInt(argc, argv, "device="); + + if (devID < 0) { + printf("Invalid command line parameter\n "); + exit(EXIT_FAILURE); + } else { + devID = gpuDeviceInit(devID); + + if (devID < 0) { + printf("exiting...\n"); + exit(EXIT_FAILURE); + } + } + } else { + // Otherwise pick the device with highest Gflops/s + devID = gpuGetMaxGflopsDeviceId(); + checkCudaErrors(cudaSetDevice(devID)); + int major = 0, minor = 0; + checkCudaErrors(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); + checkCudaErrors(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); + printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", + devID, + _ConvertSMVer2ArchName(major, minor), + major, + minor); + } + + return devID; } inline int findIntegratedGPU() { - int current_device = 0; - int device_count = 0; - int devices_prohibited = 0; - - checkCudaErrors(cudaGetDeviceCount(&device_count)); - - if (device_count == 0) { - fprintf(stderr, "CUDA error: no devices supporting CUDA.\n"); - exit(EXIT_FAILURE); - } - - // Find the integrated GPU which is compute capable - while (current_device < device_count) { - int computeMode = -1, integrated = -1; - checkCudaErrors( - cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, current_device)); - checkCudaErrors(cudaDeviceGetAttribute(&integrated, 
cudaDevAttrIntegrated, current_device)); - // If GPU is integrated and is not running on Compute Mode prohibited, - // then cuda can map to GLES resource - if (integrated && (computeMode != cudaComputeModeProhibited)) { - checkCudaErrors(cudaSetDevice(current_device)); - - int major = 0, minor = 0; - checkCudaErrors( - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, current_device)); - checkCudaErrors( - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, current_device)); - printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", - current_device, - _ConvertSMVer2ArchName(major, minor), - major, - minor); - - return current_device; - } else { - devices_prohibited++; - } - - current_device++; - } - - if (devices_prohibited == device_count) { - fprintf(stderr, - "CUDA error:" - " No GLES-CUDA Interop capable GPU found.\n"); - exit(EXIT_FAILURE); - } - - return -1; + int current_device = 0; + int device_count = 0; + int devices_prohibited = 0; + + checkCudaErrors(cudaGetDeviceCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, "CUDA error: no devices supporting CUDA.\n"); + exit(EXIT_FAILURE); + } + + // Find the integrated GPU which is compute capable + while (current_device < device_count) { + int computeMode = -1, integrated = -1; + checkCudaErrors( + cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, current_device)); + checkCudaErrors(cudaDeviceGetAttribute(&integrated, cudaDevAttrIntegrated, current_device)); + // If GPU is integrated and is not running on Compute Mode prohibited, + // then cuda can map to GLES resource + if (integrated && (computeMode != cudaComputeModeProhibited)) { + checkCudaErrors(cudaSetDevice(current_device)); + + int major = 0, minor = 0; + checkCudaErrors( + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, current_device)); + checkCudaErrors( + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, current_device)); + printf("GPU Device %d: 
\"%s\" with compute capability %d.%d\n\n", + current_device, + _ConvertSMVer2ArchName(major, minor), + major, + minor); + + return current_device; + } else { + devices_prohibited++; + } + + current_device++; + } + + if (devices_prohibited == device_count) { + fprintf(stderr, + "CUDA error:" + " No GLES-CUDA Interop capable GPU found.\n"); + exit(EXIT_FAILURE); + } + + return -1; } // General check for CUDA GPU SM Capabilities inline bool checkCudaCapabilities(int major_version, int minor_version) { - int dev; - int major = 0, minor = 0; - - checkCudaErrors(cudaGetDevice(&dev)); - checkCudaErrors(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev)); - checkCudaErrors(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev)); - - if ((major > major_version) || (major == major_version && minor >= minor_version)) { - printf(" Device %d: <%16s >, Compute SM %d.%d detected\n", - dev, - _ConvertSMVer2ArchName(major, minor), - major, - minor); - return true; - } else { - printf( - " No GPU device was found that can support " - "CUDA compute capability %d.%d.\n", - major_version, - minor_version); - return false; - } + int dev; + int major = 0, minor = 0; + + checkCudaErrors(cudaGetDevice(&dev)); + checkCudaErrors(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev)); + checkCudaErrors(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev)); + + if ((major > major_version) || (major == major_version && minor >= minor_version)) { + printf(" Device %d: <%16s >, Compute SM %d.%d detected\n", + dev, + _ConvertSMVer2ArchName(major, minor), + major, + minor); + return true; + } else { + printf( + " No GPU device was found that can support " + "CUDA compute capability %d.%d.\n", + major_version, + minor_version); + return false; + } } #endif diff --git a/librapid/include/librapid/cuda/helper_cuda_drvapi.h b/librapid/include/librapid/cuda/helper_cuda_drvapi.h index 2f61571e..73b83508 100644 --- 
a/librapid/include/librapid/cuda/helper_cuda_drvapi.h +++ b/librapid/include/librapid/cuda/helper_cuda_drvapi.h @@ -39,19 +39,19 @@ #include #ifndef MAX -# define MAX(a, b) (a > b ? a : b) +# define MAX(a, b) (a > b ? a : b) #endif #ifndef COMMON_HELPER_CUDA_H_ inline int ftoi(float value) { - return (value >= 0 ? static_cast(value + 0.5) : static_cast(value - 0.5)); + return (value >= 0 ? static_cast(value + 0.5) : static_cast(value - 0.5)); } #endif #ifndef EXIT_WAIVED -# define EXIT_WAIVED 2 +# define EXIT_WAIVED 2 #endif //////////////////////////////////////////////////////////////////////////////// @@ -62,327 +62,327 @@ inline int ftoi(float value) { #ifdef __cuda_cuda_h__ // This will output the proper CUDA error strings in the event that a CUDA host // call returns an error -# ifndef checkCudaErrors -# define checkCudaErrors(err) __checkCudaErrors(err, __FILE__, __LINE__) +# ifndef checkCudaErrors +# define checkCudaErrors(err) __checkCudaErrors(err, __FILE__, __LINE__) // These are the inline versions for all of the SDK helper functions inline void __checkCudaErrors(CUresult err, const char *file, const int line) { - if (CUDA_SUCCESS != err) { - const char *errorStr = NULL; - cuGetErrorString(err, &errorStr); - fprintf(stderr, - "checkCudaErrors() Driver API error = %04d \"%s\" from file <%s>, " - "line %i.\n", - err, - errorStr, - file, - line); - exit(EXIT_FAILURE); - } + if (CUDA_SUCCESS != err) { + const char *errorStr = NULL; + cuGetErrorString(err, &errorStr); + fprintf(stderr, + "checkCudaErrors() Driver API error = %04d \"%s\" from file <%s>, " + "line %i.\n", + err, + errorStr, + file, + line); + exit(EXIT_FAILURE); + } } -# endif +# endif // This function wraps the CUDA Driver API into a template function template inline void getCudaAttribute(T *attribute, CUdevice_attribute device_attribute, int device) { - checkCudaErrors(cuDeviceGetAttribute(attribute, device_attribute, device)); + checkCudaErrors(cuDeviceGetAttribute(attribute, 
device_attribute, device)); } #endif // Beginning of GPU Architecture definitions inline int _ConvertSMVer2CoresDRV(int major, int minor) { - // Defines for GPU Architecture types (using the SM version to determine the - // # of cores per SM - typedef struct { - int SM; // 0xMm (hexidecimal notation), M = SM Major version, and m = SM - // minor version - int Cores; - } sSMtoCores; - - sSMtoCores nGpuArchCoresPerSM[] = {{0x30, 192}, - {0x32, 192}, - {0x35, 192}, - {0x37, 192}, - {0x50, 128}, - {0x52, 128}, - {0x53, 128}, - {0x60, 64}, - {0x61, 128}, - {0x62, 128}, - {0x70, 64}, - {0x72, 64}, - {0x75, 64}, - {0x80, 64}, - {0x86, 128}, - {-1, -1}}; - - int index = 0; - - while (nGpuArchCoresPerSM[index].SM != -1) { - if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { - return nGpuArchCoresPerSM[index].Cores; - } - - index++; - } - - // If we don't find the values, we default use the previous one to run - // properly - printf("MapSMtoCores for SM %d.%d is undefined. Default to use %d Cores/SM\n", - major, - minor, - nGpuArchCoresPerSM[index - 1].Cores); - return nGpuArchCoresPerSM[index - 1].Cores; + // Defines for GPU Architecture types (using the SM version to determine the + // # of cores per SM + typedef struct { + int SM; // 0xMm (hexidecimal notation), M = SM Major version, and m = SM + // minor version + int Cores; + } sSMtoCores; + + sSMtoCores nGpuArchCoresPerSM[] = {{0x30, 192}, + {0x32, 192}, + {0x35, 192}, + {0x37, 192}, + {0x50, 128}, + {0x52, 128}, + {0x53, 128}, + {0x60, 64}, + {0x61, 128}, + {0x62, 128}, + {0x70, 64}, + {0x72, 64}, + {0x75, 64}, + {0x80, 64}, + {0x86, 128}, + {-1, -1}}; + + int index = 0; + + while (nGpuArchCoresPerSM[index].SM != -1) { + if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { + return nGpuArchCoresPerSM[index].Cores; + } + + index++; + } + + // If we don't find the values, we default use the previous one to run + // properly + printf("MapSMtoCores for SM %d.%d is undefined. 
Default to use %d Cores/SM\n", + major, + minor, + nGpuArchCoresPerSM[index - 1].Cores); + return nGpuArchCoresPerSM[index - 1].Cores; } // end of GPU Architecture definitions #ifdef __cuda_cuda_h__ // General GPU Device CUDA Initialization inline int gpuDeviceInitDRV(int ARGC, const char **ARGV) { - int cuDevice = 0; - int deviceCount = 0; - checkCudaErrors(cuInit(0)); + int cuDevice = 0; + int deviceCount = 0; + checkCudaErrors(cuInit(0)); - checkCudaErrors(cuDeviceGetCount(&deviceCount)); + checkCudaErrors(cuDeviceGetCount(&deviceCount)); - if (deviceCount == 0) { - fprintf(stderr, "cudaDeviceInit error: no devices supporting CUDA\n"); - exit(EXIT_FAILURE); - } + if (deviceCount == 0) { + fprintf(stderr, "cudaDeviceInit error: no devices supporting CUDA\n"); + exit(EXIT_FAILURE); + } - int dev = 0; - dev = getCmdLineArgumentInt(ARGC, (const char **)ARGV, "device="); + int dev = 0; + dev = getCmdLineArgumentInt(ARGC, (const char **)ARGV, "device="); - if (dev < 0) { dev = 0; } + if (dev < 0) { dev = 0; } - if (dev > deviceCount - 1) { - fprintf(stderr, "\n"); - fprintf(stderr, ">> %d CUDA capable GPU device(s) detected. <<\n", deviceCount); - fprintf(stderr, ">> cudaDeviceInit (-device=%d) is not a valid GPU device. <<\n", dev); - fprintf(stderr, "\n"); - return -dev; - } + if (dev > deviceCount - 1) { + fprintf(stderr, "\n"); + fprintf(stderr, ">> %d CUDA capable GPU device(s) detected. <<\n", deviceCount); + fprintf(stderr, ">> cudaDeviceInit (-device=%d) is not a valid GPU device. 
<<\n", dev); + fprintf(stderr, "\n"); + return -dev; + } - checkCudaErrors(cuDeviceGet(&cuDevice, dev)); - char name[100]; - checkCudaErrors(cuDeviceGetName(name, 100, cuDevice)); + checkCudaErrors(cuDeviceGet(&cuDevice, dev)); + char name[100]; + checkCudaErrors(cuDeviceGetName(name, 100, cuDevice)); - int computeMode; - getCudaAttribute(&computeMode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, dev); + int computeMode; + getCudaAttribute(&computeMode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, dev); - if (computeMode == CU_COMPUTEMODE_PROHIBITED) { - fprintf(stderr, - "Error: device is running in , no " - "threads can use this CUDA Device.\n"); - return -1; - } + if (computeMode == CU_COMPUTEMODE_PROHIBITED) { + fprintf(stderr, + "Error: device is running in , no " + "threads can use this CUDA Device.\n"); + return -1; + } - if (checkCmdLineFlag(ARGC, (const char **)ARGV, "quiet") == false) { - printf("gpuDeviceInitDRV() Using CUDA Device [%d]: %s\n", dev, name); - } + if (checkCmdLineFlag(ARGC, (const char **)ARGV, "quiet") == false) { + printf("gpuDeviceInitDRV() Using CUDA Device [%d]: %s\n", dev, name); + } - return dev; + return dev; } // This function returns the best GPU based on performance inline int gpuGetMaxGflopsDeviceIdDRV() { - CUdevice current_device = 0; - CUdevice max_perf_device = 0; - int device_count = 0; - int sm_per_multiproc = 0; - unsigned long long max_compute_perf = 0; - int major = 0; - int minor = 0; - int multiProcessorCount; - int clockRate; - int devices_prohibited = 0; - - cuInit(0); - checkCudaErrors(cuDeviceGetCount(&device_count)); - - if (device_count == 0) { - fprintf(stderr, "gpuGetMaxGflopsDeviceIdDRV error: no devices supporting CUDA\n"); - exit(EXIT_FAILURE); - } - - // Find the best CUDA capable GPU device - current_device = 0; - - while (current_device < device_count) { - checkCudaErrors(cuDeviceGetAttribute( - &multiProcessorCount, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, current_device)); - checkCudaErrors( - 
cuDeviceGetAttribute(&clockRate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, current_device)); - checkCudaErrors(cuDeviceGetAttribute( - &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, current_device)); - checkCudaErrors(cuDeviceGetAttribute( - &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, current_device)); - - int computeMode; - getCudaAttribute(&computeMode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, current_device); - - if (computeMode != CU_COMPUTEMODE_PROHIBITED) { - if (major == 9999 && minor == 9999) { - sm_per_multiproc = 1; - } else { - sm_per_multiproc = _ConvertSMVer2CoresDRV(major, minor); - } - - unsigned long long compute_perf = - (unsigned long long)(multiProcessorCount * sm_per_multiproc * clockRate); - - if (compute_perf > max_compute_perf) { - max_compute_perf = compute_perf; - max_perf_device = current_device; - } - } else { - devices_prohibited++; - } - - ++current_device; - } - - if (devices_prohibited == device_count) { - fprintf(stderr, - "gpuGetMaxGflopsDeviceIdDRV error: all devices have compute mode " - "prohibited.\n"); - exit(EXIT_FAILURE); - } - - return max_perf_device; + CUdevice current_device = 0; + CUdevice max_perf_device = 0; + int device_count = 0; + int sm_per_multiproc = 0; + unsigned long long max_compute_perf = 0; + int major = 0; + int minor = 0; + int multiProcessorCount; + int clockRate; + int devices_prohibited = 0; + + cuInit(0); + checkCudaErrors(cuDeviceGetCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, "gpuGetMaxGflopsDeviceIdDRV error: no devices supporting CUDA\n"); + exit(EXIT_FAILURE); + } + + // Find the best CUDA capable GPU device + current_device = 0; + + while (current_device < device_count) { + checkCudaErrors(cuDeviceGetAttribute( + &multiProcessorCount, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, current_device)); + checkCudaErrors( + cuDeviceGetAttribute(&clockRate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, current_device)); + checkCudaErrors(cuDeviceGetAttribute( + &major, 
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, current_device)); + checkCudaErrors(cuDeviceGetAttribute( + &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, current_device)); + + int computeMode; + getCudaAttribute(&computeMode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, current_device); + + if (computeMode != CU_COMPUTEMODE_PROHIBITED) { + if (major == 9999 && minor == 9999) { + sm_per_multiproc = 1; + } else { + sm_per_multiproc = _ConvertSMVer2CoresDRV(major, minor); + } + + unsigned long long compute_perf = + (unsigned long long)(multiProcessorCount * sm_per_multiproc * clockRate); + + if (compute_perf > max_compute_perf) { + max_compute_perf = compute_perf; + max_perf_device = current_device; + } + } else { + devices_prohibited++; + } + + ++current_device; + } + + if (devices_prohibited == device_count) { + fprintf(stderr, + "gpuGetMaxGflopsDeviceIdDRV error: all devices have compute mode " + "prohibited.\n"); + exit(EXIT_FAILURE); + } + + return max_perf_device; } // General initialization call to pick the best CUDA Device inline CUdevice findCudaDeviceDRV(int argc, const char **argv) { - CUdevice cuDevice; - int devID = 0; - - // If the command-line has a device number specified, use it - if (checkCmdLineFlag(argc, (const char **)argv, "device")) { - devID = gpuDeviceInitDRV(argc, argv); - - if (devID < 0) { - printf("exiting...\n"); - exit(EXIT_SUCCESS); - } - } else { - // Otherwise pick the device with highest Gflops/s - char name[100]; - devID = gpuGetMaxGflopsDeviceIdDRV(); - checkCudaErrors(cuDeviceGet(&cuDevice, devID)); - cuDeviceGetName(name, 100, cuDevice); - printf("> Using CUDA Device [%d]: %s\n", devID, name); - } - - cuDeviceGet(&cuDevice, devID); - - return cuDevice; + CUdevice cuDevice; + int devID = 0; + + // If the command-line has a device number specified, use it + if (checkCmdLineFlag(argc, (const char **)argv, "device")) { + devID = gpuDeviceInitDRV(argc, argv); + + if (devID < 0) { + printf("exiting...\n"); + exit(EXIT_SUCCESS); + } + } else 
{ + // Otherwise pick the device with highest Gflops/s + char name[100]; + devID = gpuGetMaxGflopsDeviceIdDRV(); + checkCudaErrors(cuDeviceGet(&cuDevice, devID)); + cuDeviceGetName(name, 100, cuDevice); + printf("> Using CUDA Device [%d]: %s\n", devID, name); + } + + cuDeviceGet(&cuDevice, devID); + + return cuDevice; } inline CUdevice findIntegratedGPUDrv() { - CUdevice current_device = 0; - int device_count = 0; - int devices_prohibited = 0; - int isIntegrated; - - cuInit(0); - checkCudaErrors(cuDeviceGetCount(&device_count)); - - if (device_count == 0) { - fprintf(stderr, "CUDA error: no devices supporting CUDA.\n"); - exit(EXIT_FAILURE); - } - - // Find the integrated GPU which is compute capable - while (current_device < device_count) { - int computeMode = -1; - checkCudaErrors( - cuDeviceGetAttribute(&isIntegrated, CU_DEVICE_ATTRIBUTE_INTEGRATED, current_device)); - checkCudaErrors( - cuDeviceGetAttribute(&computeMode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, current_device)); - - // If GPU is integrated and is not running on Compute Mode prohibited - // use that - if (isIntegrated && (computeMode != CU_COMPUTEMODE_PROHIBITED)) { - int major = 0, minor = 0; - char deviceName[256]; - checkCudaErrors(cuDeviceGetAttribute( - &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, current_device)); - checkCudaErrors(cuDeviceGetAttribute( - &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, current_device)); - checkCudaErrors(cuDeviceGetName(deviceName, 256, current_device)); - printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", - current_device, - deviceName, - major, - minor); - - return current_device; - } else { - devices_prohibited++; - } - - current_device++; - } - - if (devices_prohibited == device_count) { - fprintf(stderr, "CUDA error: No Integrated CUDA capable GPU found.\n"); - exit(EXIT_FAILURE); - } - - return -1; + CUdevice current_device = 0; + int device_count = 0; + int devices_prohibited = 0; + int isIntegrated; + + cuInit(0); + 
checkCudaErrors(cuDeviceGetCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, "CUDA error: no devices supporting CUDA.\n"); + exit(EXIT_FAILURE); + } + + // Find the integrated GPU which is compute capable + while (current_device < device_count) { + int computeMode = -1; + checkCudaErrors( + cuDeviceGetAttribute(&isIntegrated, CU_DEVICE_ATTRIBUTE_INTEGRATED, current_device)); + checkCudaErrors( + cuDeviceGetAttribute(&computeMode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, current_device)); + + // If GPU is integrated and is not running on Compute Mode prohibited + // use that + if (isIntegrated && (computeMode != CU_COMPUTEMODE_PROHIBITED)) { + int major = 0, minor = 0; + char deviceName[256]; + checkCudaErrors(cuDeviceGetAttribute( + &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, current_device)); + checkCudaErrors(cuDeviceGetAttribute( + &minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, current_device)); + checkCudaErrors(cuDeviceGetName(deviceName, 256, current_device)); + printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", + current_device, + deviceName, + major, + minor); + + return current_device; + } else { + devices_prohibited++; + } + + current_device++; + } + + if (devices_prohibited == device_count) { + fprintf(stderr, "CUDA error: No Integrated CUDA capable GPU found.\n"); + exit(EXIT_FAILURE); + } + + return -1; } // General check for CUDA GPU SM Capabilities inline bool checkCudaCapabilitiesDRV(int major_version, int minor_version, int devID) { - CUdevice cuDevice; - char name[256]; - int major = 0, minor = 0; - - checkCudaErrors(cuDeviceGet(&cuDevice, devID)); - checkCudaErrors(cuDeviceGetName(name, 100, cuDevice)); - checkCudaErrors( - cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice)); - checkCudaErrors( - cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice)); - - if ((major > major_version) || (major == major_version && minor >= minor_version)) { 
- printf("> Device %d: <%16s >, Compute SM %d.%d detected\n", devID, name, major, minor); - return true; - } else { - printf( - "No GPU device was found that can support CUDA compute capability " - "%d.%d.\n", - major_version, - minor_version); - return false; - } + CUdevice cuDevice; + char name[256]; + int major = 0, minor = 0; + + checkCudaErrors(cuDeviceGet(&cuDevice, devID)); + checkCudaErrors(cuDeviceGetName(name, 100, cuDevice)); + checkCudaErrors( + cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice)); + checkCudaErrors( + cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice)); + + if ((major > major_version) || (major == major_version && minor >= minor_version)) { + printf("> Device %d: <%16s >, Compute SM %d.%d detected\n", devID, name, major, minor); + return true; + } else { + printf( + "No GPU device was found that can support CUDA compute capability " + "%d.%d.\n", + major_version, + minor_version); + return false; + } } #endif bool inline findFatbinPath(const char *module_file, std::string &module_path, char **argv, - std::ostringstream &ostrm) { - char *actual_path = sdkFindFilePath(module_file, argv[0]); - - if (actual_path) { - module_path = actual_path; - } else { - printf("> findModulePath file not found: <%s> \n", module_file); - return false; - } - - if (module_path.empty()) { - printf("> findModulePath could not find file: <%s> \n", module_file); - return false; - } else { - printf("> findModulePath found file at <%s>\n", module_path.c_str()); - if (module_path.rfind("fatbin") != std::string::npos) { - std::ifstream fileIn(module_path.c_str(), std::ios::binary); - ostrm << fileIn.rdbuf(); - fileIn.close(); - } - return true; - } + std::ostringstream &ostrm) { + char *actual_path = sdkFindFilePath(module_file, argv[0]); + + if (actual_path) { + module_path = actual_path; + } else { + printf("> findModulePath file not found: <%s> \n", module_file); + return false; + } + + if 
(module_path.empty()) { + printf("> findModulePath could not find file: <%s> \n", module_file); + return false; + } else { + printf("> findModulePath found file at <%s>\n", module_path.c_str()); + if (module_path.rfind("fatbin") != std::string::npos) { + std::ifstream fileIn(module_path.c_str(), std::ios::binary); + ostrm << fileIn.rdbuf(); + fileIn.close(); + } + return true; + } } // end of CUDA Helper Functions diff --git a/librapid/include/librapid/cuda/helper_cusolver.h b/librapid/include/librapid/cuda/helper_cusolver.h index 5293805d..a35d31d8 100644 --- a/librapid/include/librapid/cuda/helper_cusolver.h +++ b/librapid/include/librapid/cuda/helper_cusolver.h @@ -39,74 +39,74 @@ #define SWITCH_CHAR '-' struct testOpts { - char *sparse_mat_filename; // by switch -F - const char *testFunc; // by switch -R - const char *reorder; // by switch -P - int lda; // by switch -lda + char *sparse_mat_filename; // by switch -F + const char *testFunc; // by switch -R + const char *reorder; // by switch -P + int lda; // by switch -lda }; double vec_norminf(int n, const double *x) { - double norminf = 0; - for (int j = 0; j < n; j++) { - double x_abs = fabs(x[j]); - norminf = (norminf > x_abs) ? norminf : x_abs; - } - return norminf; + double norminf = 0; + for (int j = 0; j < n; j++) { + double x_abs = fabs(x[j]); + norminf = (norminf > x_abs) ? norminf : x_abs; + } + return norminf; } /* * |A| = max { |A|*ones(m,1) } */ double mat_norminf(int m, int n, const double *A, int lda) { - double norminf = 0; - for (int i = 0; i < m; i++) { - double sum = 0.0; - for (int j = 0; j < n; j++) { - double A_abs = fabs(A[i + j * lda]); - sum += A_abs; - } - norminf = (norminf > sum) ? norminf : sum; - } - return norminf; + double norminf = 0; + for (int i = 0; i < m; i++) { + double sum = 0.0; + for (int j = 0; j < n; j++) { + double A_abs = fabs(A[i + j * lda]); + sum += A_abs; + } + norminf = (norminf > sum) ? 
norminf : sum; + } + return norminf; } /* * |A| = max { |A|*ones(m,1) } */ double csr_mat_norminf(int m, int n, int nnzA, const cusparseMatDescr_t descrA, - const double *csrValA, const int *csrRowPtrA, const int *csrColIndA) { - const int baseA = (CUSPARSE_INDEX_BASE_ONE == cusparseGetMatIndexBase(descrA)) ? 1 : 0; + const double *csrValA, const int *csrRowPtrA, const int *csrColIndA) { + const int baseA = (CUSPARSE_INDEX_BASE_ONE == cusparseGetMatIndexBase(descrA)) ? 1 : 0; - double norminf = 0; - for (int i = 0; i < m; i++) { - double sum = 0.0; - const int start = csrRowPtrA[i] - baseA; - const int end = csrRowPtrA[i + 1] - baseA; - for (int colidx = start; colidx < end; colidx++) { - // const int j = csrColIndA[colidx] - baseA; - double A_abs = fabs(csrValA[colidx]); - sum += A_abs; - } - norminf = (norminf > sum) ? norminf : sum; - } - return norminf; + double norminf = 0; + for (int i = 0; i < m; i++) { + double sum = 0.0; + const int start = csrRowPtrA[i] - baseA; + const int end = csrRowPtrA[i + 1] - baseA; + for (int colidx = start; colidx < end; colidx++) { + // const int j = csrColIndA[colidx] - baseA; + double A_abs = fabs(csrValA[colidx]); + sum += A_abs; + } + norminf = (norminf > sum) ? norminf : sum; + } + return norminf; } void display_matrix(int m, int n, int nnzA, const cusparseMatDescr_t descrA, const double *csrValA, - const int *csrRowPtrA, const int *csrColIndA) { - const int baseA = (CUSPARSE_INDEX_BASE_ONE == cusparseGetMatIndexBase(descrA)) ? 1 : 0; + const int *csrRowPtrA, const int *csrColIndA) { + const int baseA = (CUSPARSE_INDEX_BASE_ONE == cusparseGetMatIndexBase(descrA)) ? 
1 : 0; - printf("m = %d, n = %d, nnz = %d, matlab base-1\n", m, n, nnzA); + printf("m = %d, n = %d, nnz = %d, matlab base-1\n", m, n, nnzA); - for (int row = 0; row < m; row++) { - const int start = csrRowPtrA[row] - baseA; - const int end = csrRowPtrA[row + 1] - baseA; - for (int colidx = start; colidx < end; colidx++) { - const int col = csrColIndA[colidx] - baseA; - double Areg = csrValA[colidx]; - printf("A(%d, %d) = %20.16E\n", row + 1, col + 1, Areg); - } - } + for (int row = 0; row < m; row++) { + const int start = csrRowPtrA[row] - baseA; + const int end = csrRowPtrA[row + 1] - baseA; + for (int colidx = start; colidx < end; colidx++) { + const int col = csrColIndA[colidx] - baseA; + double Areg = csrValA[colidx]; + printf("A(%d, %d) = %20.16E\n", row + 1, col + 1, Areg); + } + } } #endif diff --git a/librapid/include/librapid/cuda/helper_functions.h b/librapid/include/librapid/cuda/helper_functions.h index 45692152..be6371ba 100644 --- a/librapid/include/librapid/cuda/helper_functions.h +++ b/librapid/include/librapid/cuda/helper_functions.h @@ -31,7 +31,7 @@ #define COMMON_HELPER_FUNCTIONS_H_ #ifdef WIN32 -# pragma warning(disable : 4996) +# pragma warning(disable : 4996) #endif // includes, project @@ -51,7 +51,7 @@ #include "helper_string.h" // helper functions for string parsing #ifndef EXIT_WAIVED -# define EXIT_WAIVED 2 +# define EXIT_WAIVED 2 #endif #endif // COMMON_HELPER_FUNCTIONS_H_ diff --git a/librapid/include/librapid/cuda/helper_image.h b/librapid/include/librapid/cuda/helper_image.h index f4adfe6e..fbf1dfdf 100644 --- a/librapid/include/librapid/cuda/helper_image.h +++ b/librapid/include/librapid/cuda/helper_image.h @@ -40,284 +40,284 @@ #include #ifndef MIN -# define MIN(a, b) ((a < b) ? a : b) +# define MIN(a, b) ((a < b) ? a : b) #endif #ifndef MAX -# define MAX(a, b) ((a > b) ? a : b) +# define MAX(a, b) ((a > b) ? 
a : b) #endif #ifndef EXIT_WAIVED -# define EXIT_WAIVED 2 +# define EXIT_WAIVED 2 #endif #include "helper_string.h" // namespace unnamed (internal) namespace helper_image_internal { - //! size of PGM file header - const unsigned int PGMHeaderSize = 0x40; - - // types - - //! Data converter from unsigned char / unsigned byte to type T - template - struct ConverterFromUByte; - - //! Data converter from unsigned char / unsigned byte - template<> - struct ConverterFromUByte { - //! Conversion operator - //! @return converted value - //! @param val value to convert - float operator()(const unsigned char &val) { return static_cast(val); } - }; - - //! Data converter from unsigned char / unsigned byte to float - template<> - struct ConverterFromUByte { - //! Conversion operator - //! @return converted value - //! @param val value to convert - float operator()(const unsigned char &val) { return static_cast(val) / 255.0f; } - }; - - //! Data converter from unsigned char / unsigned byte to type T - template - struct ConverterToUByte; - - //! Data converter from unsigned char / unsigned byte to unsigned int - template<> - struct ConverterToUByte { - //! Conversion operator (essentially a passthru - //! @return converted value - //! @param val value to convert - unsigned char operator()(const unsigned char &val) { return val; } - }; - - //! Data converter from unsigned char / unsigned byte to unsigned int - template<> - struct ConverterToUByte { - //! Conversion operator - //! @return converted value - //! @param val value to convert - unsigned char operator()(const float &val) { - return static_cast(val * 255.0f); - } - }; + //! size of PGM file header + const unsigned int PGMHeaderSize = 0x40; + + // types + + //! Data converter from unsigned char / unsigned byte to type T + template + struct ConverterFromUByte; + + //! Data converter from unsigned char / unsigned byte + template<> + struct ConverterFromUByte { + //! Conversion operator + //! @return converted value + //! 
@param val value to convert + float operator()(const unsigned char &val) { return static_cast(val); } + }; + + //! Data converter from unsigned char / unsigned byte to float + template<> + struct ConverterFromUByte { + //! Conversion operator + //! @return converted value + //! @param val value to convert + float operator()(const unsigned char &val) { return static_cast(val) / 255.0f; } + }; + + //! Data converter from unsigned char / unsigned byte to type T + template + struct ConverterToUByte; + + //! Data converter from unsigned char / unsigned byte to unsigned int + template<> + struct ConverterToUByte { + //! Conversion operator (essentially a passthru + //! @return converted value + //! @param val value to convert + unsigned char operator()(const unsigned char &val) { return val; } + }; + + //! Data converter from unsigned char / unsigned byte to unsigned int + template<> + struct ConverterToUByte { + //! Conversion operator + //! @return converted value + //! @param val value to convert + unsigned char operator()(const float &val) { + return static_cast(val * 255.0f); + } + }; } // namespace helper_image_internal #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) -# ifndef FOPEN -# define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode) -# endif -# ifndef FOPEN_FAIL -# define FOPEN_FAIL(result) (result != 0) -# endif -# ifndef SSCANF -# define SSCANF sscanf_s -# endif +# ifndef FOPEN +# define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode) +# endif +# ifndef FOPEN_FAIL +# define FOPEN_FAIL(result) (result != 0) +# endif +# ifndef SSCANF +# define SSCANF sscanf_s +# endif #else -# ifndef FOPEN -# define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode)) -# endif -# ifndef FOPEN_FAIL -# define FOPEN_FAIL(result) (result == NULL) -# endif -# ifndef SSCANF -# define SSCANF sscanf -# endif +# ifndef FOPEN +# define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode)) +# endif +# 
ifndef FOPEN_FAIL +# define FOPEN_FAIL(result) (result == NULL) +# endif +# ifndef SSCANF +# define SSCANF sscanf +# endif #endif inline bool __loadPPM(const char *file, unsigned char **data, unsigned int *w, unsigned int *h, - unsigned int *channels) { - FILE *fp = NULL; - - if (FOPEN_FAIL(FOPEN(fp, file, "rb"))) { - std::cerr << "__LoadPPM() : Failed to open file: " << file << std::endl; - return false; - } - - // check header - char header[helper_image_internal::PGMHeaderSize]; - - if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) { - std::cerr << "__LoadPPM() : reading PGM header returned NULL" << std::endl; - return false; - } - - if (strncmp(header, "P5", 2) == 0) { - *channels = 1; - } else if (strncmp(header, "P6", 2) == 0) { - *channels = 3; - } else { - std::cerr << "__LoadPPM() : File is not a PPM or PGM image" << std::endl; - *channels = 0; - return false; - } - - // parse header, read maxval, width and height - unsigned int width = 0; - unsigned int height = 0; - unsigned int maxval = 0; - unsigned int i = 0; - - while (i < 3) { - if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) { - std::cerr << "__LoadPPM() : reading PGM header returned NULL" << std::endl; - return false; - } - - if (header[0] == '#') { continue; } - - if (i == 0) { - i += SSCANF(header, "%u %u %u", &width, &height, &maxval); - } else if (i == 1) { - i += SSCANF(header, "%u %u", &height, &maxval); - } else if (i == 2) { - i += SSCANF(header, "%u", &maxval); - } - } - - // check if given handle for the data is initialized - if (NULL != *data) { - if (*w != width || *h != height) { - std::cerr << "__LoadPPM() : Invalid image dimensions." << std::endl; - } - } else { - *data = (unsigned char *)malloc(sizeof(unsigned char) * width * height * *channels); - *w = width; - *h = height; - } - - // read and close file - if (fread(*data, sizeof(unsigned char), width * height * *channels, fp) == 0) { - std::cerr << "__LoadPPM() read data returned error." 
<< std::endl; - } - - fclose(fp); - - return true; + unsigned int *channels) { + FILE *fp = NULL; + + if (FOPEN_FAIL(FOPEN(fp, file, "rb"))) { + std::cerr << "__LoadPPM() : Failed to open file: " << file << std::endl; + return false; + } + + // check header + char header[helper_image_internal::PGMHeaderSize]; + + if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) { + std::cerr << "__LoadPPM() : reading PGM header returned NULL" << std::endl; + return false; + } + + if (strncmp(header, "P5", 2) == 0) { + *channels = 1; + } else if (strncmp(header, "P6", 2) == 0) { + *channels = 3; + } else { + std::cerr << "__LoadPPM() : File is not a PPM or PGM image" << std::endl; + *channels = 0; + return false; + } + + // parse header, read maxval, width and height + unsigned int width = 0; + unsigned int height = 0; + unsigned int maxval = 0; + unsigned int i = 0; + + while (i < 3) { + if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) { + std::cerr << "__LoadPPM() : reading PGM header returned NULL" << std::endl; + return false; + } + + if (header[0] == '#') { continue; } + + if (i == 0) { + i += SSCANF(header, "%u %u %u", &width, &height, &maxval); + } else if (i == 1) { + i += SSCANF(header, "%u %u", &height, &maxval); + } else if (i == 2) { + i += SSCANF(header, "%u", &maxval); + } + } + + // check if given handle for the data is initialized + if (NULL != *data) { + if (*w != width || *h != height) { + std::cerr << "__LoadPPM() : Invalid image dimensions." << std::endl; + } + } else { + *data = (unsigned char *)malloc(sizeof(unsigned char) * width * height * *channels); + *w = width; + *h = height; + } + + // read and close file + if (fread(*data, sizeof(unsigned char), width * height * *channels, fp) == 0) { + std::cerr << "__LoadPPM() read data returned error." 
<< std::endl; + } + + fclose(fp); + + return true; } template inline bool sdkLoadPGM(const char *file, T **data, unsigned int *w, unsigned int *h) { - unsigned char *idata = NULL; - unsigned int channels; + unsigned char *idata = NULL; + unsigned int channels; - if (true != __loadPPM(file, &idata, w, h, &channels)) { return false; } + if (true != __loadPPM(file, &idata, w, h, &channels)) { return false; } - unsigned int size = *w * *h * channels; + unsigned int size = *w * *h * channels; - // initialize mem if necessary - // the correct size is checked / set in loadPGMc() - if (NULL == *data) { *data = reinterpret_cast(malloc(sizeof(T) * size)); } + // initialize mem if necessary + // the correct size is checked / set in loadPGMc() + if (NULL == *data) { *data = reinterpret_cast(malloc(sizeof(T) * size)); } - // copy and cast data - std::transform(idata, idata + size, *data, helper_image_internal::ConverterFromUByte()); + // copy and cast data + std::transform(idata, idata + size, *data, helper_image_internal::ConverterFromUByte()); - free(idata); + free(idata); - return true; + return true; } template inline bool sdkLoadPPM4(const char *file, T **data, unsigned int *w, unsigned int *h) { - unsigned char *idata = 0; - unsigned int channels; - - if (__loadPPM(file, &idata, w, h, &channels)) { - // pad 4th component - int size = *w * *h; - // keep the original pointer - unsigned char *idata_orig = idata; - *data = reinterpret_cast(malloc(sizeof(T) * size * 4)); - unsigned char *ptr = *data; - - for (int i = 0; i < size; i++) { - *ptr++ = *idata++; - *ptr++ = *idata++; - *ptr++ = *idata++; - *ptr++ = 0; - } - - free(idata_orig); - return true; - } else { - free(idata); - return false; - } + unsigned char *idata = 0; + unsigned int channels; + + if (__loadPPM(file, &idata, w, h, &channels)) { + // pad 4th component + int size = *w * *h; + // keep the original pointer + unsigned char *idata_orig = idata; + *data = reinterpret_cast(malloc(sizeof(T) * size * 4)); + 
unsigned char *ptr = *data; + + for (int i = 0; i < size; i++) { + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = 0; + } + + free(idata_orig); + return true; + } else { + free(idata); + return false; + } } inline bool __savePPM(const char *file, unsigned char *data, unsigned int w, unsigned int h, - unsigned int channels) { - assert(NULL != data); - assert(w > 0); - assert(h > 0); + unsigned int channels) { + assert(NULL != data); + assert(w > 0); + assert(h > 0); - std::fstream fh(file, std::fstream::out | std::fstream::binary); + std::fstream fh(file, std::fstream::out | std::fstream::binary); - if (fh.bad()) { - std::cerr << "__savePPM() : Opening file failed." << std::endl; - return false; - } + if (fh.bad()) { + std::cerr << "__savePPM() : Opening file failed." << std::endl; + return false; + } - if (channels == 1) { - fh << "P5\n"; - } else if (channels == 3) { - fh << "P6\n"; - } else { - std::cerr << "__savePPM() : Invalid number of channels." << std::endl; - return false; - } + if (channels == 1) { + fh << "P5\n"; + } else if (channels == 3) { + fh << "P6\n"; + } else { + std::cerr << "__savePPM() : Invalid number of channels." << std::endl; + return false; + } - fh << w << "\n" << h << "\n" << 0xff << std::endl; + fh << w << "\n" << h << "\n" << 0xff << std::endl; - for (unsigned int i = 0; (i < (w * h * channels)) && fh.good(); ++i) { fh << data[i]; } + for (unsigned int i = 0; (i < (w * h * channels)) && fh.good(); ++i) { fh << data[i]; } - fh.flush(); + fh.flush(); - if (fh.bad()) { - std::cerr << "__savePPM() : Writing data failed." << std::endl; - return false; - } + if (fh.bad()) { + std::cerr << "__savePPM() : Writing data failed." 
<< std::endl; + return false; + } - fh.close(); + fh.close(); - return true; + return true; } template inline bool sdkSavePGM(const char *file, T *data, unsigned int w, unsigned int h) { - unsigned int size = w * h; - unsigned char *idata = (unsigned char *)malloc(sizeof(unsigned char) * size); + unsigned int size = w * h; + unsigned char *idata = (unsigned char *)malloc(sizeof(unsigned char) * size); - std::transform(data, data + size, idata, helper_image_internal::ConverterToUByte()); + std::transform(data, data + size, idata, helper_image_internal::ConverterToUByte()); - // write file - bool result = __savePPM(file, idata, w, h, 1); + // write file + bool result = __savePPM(file, idata, w, h, 1); - // cleanup - free(idata); + // cleanup + free(idata); - return result; + return result; } inline bool sdkSavePPM4ub(const char *file, unsigned char *data, unsigned int w, unsigned int h) { - // strip 4th component - int size = w * h; - unsigned char *ndata = (unsigned char *)malloc(sizeof(unsigned char) * size * 3); - unsigned char *ptr = ndata; - - for (int i = 0; i < size; i++) { - *ptr++ = *data++; - *ptr++ = *data++; - *ptr++ = *data++; - data++; - } - - bool result = __savePPM(file, ndata, w, h, 3); - free(ndata); - return result; + // strip 4th component + int size = w * h; + unsigned char *ndata = (unsigned char *)malloc(sizeof(unsigned char) * size * 3); + unsigned char *ptr = ndata; + + for (int i = 0; i < size; i++) { + *ptr++ = *data++; + *ptr++ = *data++; + *ptr++ = *data++; + data++; + } + + bool result = __savePPM(file, ndata, w, h, 3); + free(ndata); + return result; } ////////////////////////////////////////////////////////////////////////////// @@ -330,55 +330,55 @@ inline bool sdkSavePPM4ub(const char *file, unsigned char *data, unsigned int w, ////////////////////////////////////////////////////////////////////////////// template inline bool sdkReadFile(const char *filename, T **data, unsigned int *len, bool verbose) { - // check input arguments - 
assert(NULL != filename); - assert(NULL != len); - - // intermediate storage for the data read - std::vector data_read; - - // open file for reading - FILE *fh = NULL; - - // check if filestream is valid - if (FOPEN_FAIL(FOPEN(fh, filename, "r"))) { - printf("Unable to open input file: %s\n", filename); - return false; - } - - // read all data elements - T token; - - while (!feof(fh)) { - fscanf(fh, "%f", &token); - data_read.push_back(token); - } - - // the last element is read twice - data_read.pop_back(); - fclose(fh); - - // check if the given handle is already initialized - if (NULL != *data) { - if (*len != data_read.size()) { - std::cerr << "sdkReadFile() : Initialized memory given but " - << "size mismatch with signal read " - << "(data read / data init = " << (unsigned int)data_read.size() << " / " - << *len << ")" << std::endl; - - return false; - } - } else { - // allocate storage for the data read - *data = reinterpret_cast(malloc(sizeof(T) * data_read.size())); - // store signal size - *len = static_cast(data_read.size()); - } - - // copy data - memcpy(*data, &data_read.front(), sizeof(T) * data_read.size()); - - return true; + // check input arguments + assert(NULL != filename); + assert(NULL != len); + + // intermediate storage for the data read + std::vector data_read; + + // open file for reading + FILE *fh = NULL; + + // check if filestream is valid + if (FOPEN_FAIL(FOPEN(fh, filename, "r"))) { + printf("Unable to open input file: %s\n", filename); + return false; + } + + // read all data elements + T token; + + while (!feof(fh)) { + fscanf(fh, "%f", &token); + data_read.push_back(token); + } + + // the last element is read twice + data_read.pop_back(); + fclose(fh); + + // check if the given handle is already initialized + if (NULL != *data) { + if (*len != data_read.size()) { + std::cerr << "sdkReadFile() : Initialized memory given but " + << "size mismatch with signal read " + << "(data read / data init = " << (unsigned int)data_read.size() << 
" / " + << *len << ")" << std::endl; + + return false; + } + } else { + // allocate storage for the data read + *data = reinterpret_cast(malloc(sizeof(T) * data_read.size())); + // store signal size + *len = static_cast(data_read.size()); + } + + // copy data + memcpy(*data, &data_read.front(), sizeof(T) * data_read.size()); + + return true; } ////////////////////////////////////////////////////////////////////////////// @@ -391,30 +391,30 @@ inline bool sdkReadFile(const char *filename, T **data, unsigned int *len, bool ////////////////////////////////////////////////////////////////////////////// template inline bool sdkReadFileBlocks(const char *filename, T **data, unsigned int *len, - unsigned int block_num, unsigned int block_size, bool verbose) { - // check input arguments - assert(NULL != filename); - assert(NULL != len); + unsigned int block_num, unsigned int block_size, bool verbose) { + // check input arguments + assert(NULL != filename); + assert(NULL != len); - // open file for reading - FILE *fh = fopen(filename, "rb"); + // open file for reading + FILE *fh = fopen(filename, "rb"); - if (fh == NULL && verbose) { - std::cerr << "sdkReadFile() : Opening file failed." << std::endl; - return false; - } + if (fh == NULL && verbose) { + std::cerr << "sdkReadFile() : Opening file failed." 
<< std::endl; + return false; + } - // check if the given handle is already initialized - // allocate storage for the data read - data[block_num] = reinterpret_cast(malloc(block_size)); + // check if the given handle is already initialized + // allocate storage for the data read + data[block_num] = reinterpret_cast(malloc(block_size)); - // read all data elements - fseek(fh, block_num * block_size, SEEK_SET); - *len = fread(data[block_num], sizeof(T), block_size / sizeof(T), fh); + // read all data elements + fseek(fh, block_num * block_size, SEEK_SET); + *len = fread(data[block_num], sizeof(T), block_size / sizeof(T), fh); - fclose(fh); + fclose(fh); - return true; + return true; } ////////////////////////////////////////////////////////////////////////////// @@ -427,51 +427,51 @@ inline bool sdkReadFileBlocks(const char *filename, T **data, unsigned int *len, ////////////////////////////////////////////////////////////////////////////// template inline bool sdkWriteFile(const char *filename, const T *data, unsigned int len, const S epsilon, - bool verbose, bool append = false) { - assert(NULL != filename); - assert(NULL != data); - - // open file for writing - // if (append) { - std::fstream fh(filename, std::fstream::out | std::fstream::ate); - - if (verbose) { - std::cerr << "sdkWriteFile() : Open file " << filename << " for write/append." << std::endl; - } - - /* } else { - std::fstream fh(filename, std::fstream::out); - if (verbose) { - std::cerr << "sdkWriteFile() : Open file " << filename << " for - write." << std::endl; - } - } - */ + bool verbose, bool append = false) { + assert(NULL != filename); + assert(NULL != data); + + // open file for writing + // if (append) { + std::fstream fh(filename, std::fstream::out | std::fstream::ate); - // check if filestream is valid - if (!fh.good()) { - if (verbose) { std::cerr << "sdkWriteFile() : Opening file failed." 
<< std::endl; } + if (verbose) { + std::cerr << "sdkWriteFile() : Open file " << filename << " for write/append." << std::endl; + } - return false; - } + /* } else { + std::fstream fh(filename, std::fstream::out); + if (verbose) { + std::cerr << "sdkWriteFile() : Open file " << filename << " for + write." << std::endl; + } + } + */ - // first write epsilon - fh << "# " << epsilon << "\n"; + // check if filestream is valid + if (!fh.good()) { + if (verbose) { std::cerr << "sdkWriteFile() : Opening file failed." << std::endl; } - // write data - for (unsigned int i = 0; (i < len) && (fh.good()); ++i) { fh << data[i] << ' '; } + return false; + } - // Check if writing succeeded - if (!fh.good()) { - if (verbose) { std::cerr << "sdkWriteFile() : Writing file failed." << std::endl; } + // first write epsilon + fh << "# " << epsilon << "\n"; - return false; - } + // write data + for (unsigned int i = 0; (i < len) && (fh.good()); ++i) { fh << data[i] << ' '; } - // file ends with nl - fh << std::endl; + // Check if writing succeeded + if (!fh.good()) { + if (verbose) { std::cerr << "sdkWriteFile() : Writing file failed." 
<< std::endl; } - return true; + return false; + } + + // file ends with nl + fh << std::endl; + + return true; } ////////////////////////////////////////////////////////////////////////////// @@ -484,18 +484,18 @@ inline bool sdkWriteFile(const char *filename, const T *data, unsigned int len, ////////////////////////////////////////////////////////////////////////////// template inline bool compareData(const T *reference, const T *data, const unsigned int len, const S epsilon, - const float threshold) { - assert(epsilon >= 0); + const float threshold) { + assert(epsilon >= 0); - bool result = true; - unsigned int error_count = 0; + bool result = true; + unsigned int error_count = 0; - for (unsigned int i = 0; i < len; ++i) { - float diff = static_cast(reference[i]) - static_cast(data[i]); - bool comp = (diff <= epsilon) && (diff >= -epsilon); - result &= comp; + for (unsigned int i = 0; i < len; ++i) { + float diff = static_cast(reference[i]) - static_cast(data[i]); + bool comp = (diff <= epsilon) && (diff >= -epsilon); + result &= comp; - error_count += !comp; + error_count += !comp; #if 0 @@ -507,23 +507,23 @@ inline bool compareData(const T *reference, const T *data, const unsigned int le } #endif - } - - if (threshold == 0.0f) { - return (result) ? true : false; - } else { - if (error_count) { - printf("%4.2f(%%) of bytes mismatched (count=%d)\n", - static_cast(error_count) * 100 / static_cast(len), - error_count); - } - - return (len * threshold > error_count) ? true : false; - } + } + + if (threshold == 0.0f) { + return (result) ? true : false; + } else { + if (error_count) { + printf("%4.2f(%%) of bytes mismatched (count=%d)\n", + static_cast(error_count) * 100 / static_cast(len), + error_count); + } + + return (len * threshold > error_count) ? 
true : false; + } } #ifndef __MIN_EPSILON_ERROR -# define __MIN_EPSILON_ERROR 1e-3f +# define __MIN_EPSILON_ERROR 1e-3f #endif ////////////////////////////////////////////////////////////////////////////// @@ -537,387 +537,387 @@ inline bool compareData(const T *reference, const T *data, const unsigned int le ////////////////////////////////////////////////////////////////////////////// template inline bool compareDataAsFloatThreshold(const T *reference, const T *data, const unsigned int len, - const S epsilon, const float threshold) { - assert(epsilon >= 0); - - // If we set epsilon to be 0, let's set a minimum threshold - float max_error = MAX((float)epsilon, __MIN_EPSILON_ERROR); - int error_count = 0; - bool result = true; - - for (unsigned int i = 0; i < len; ++i) { - float diff = fabs(static_cast(reference[i]) - static_cast(data[i])); - bool comp = (diff < max_error); - result &= comp; - - if (!comp) { error_count++; } - } - - if (threshold == 0.0f) { - if (error_count) { printf("total # of errors = %d\n", error_count); } - - return (error_count == 0) ? true : false; - } else { - if (error_count) { - printf("%4.2f(%%) of bytes mismatched (count=%d)\n", - static_cast(error_count) * 100 / static_cast(len), - error_count); - } - - return ((len * threshold > error_count) ? true : false); - } + const S epsilon, const float threshold) { + assert(epsilon >= 0); + + // If we set epsilon to be 0, let's set a minimum threshold + float max_error = MAX((float)epsilon, __MIN_EPSILON_ERROR); + int error_count = 0; + bool result = true; + + for (unsigned int i = 0; i < len; ++i) { + float diff = fabs(static_cast(reference[i]) - static_cast(data[i])); + bool comp = (diff < max_error); + result &= comp; + + if (!comp) { error_count++; } + } + + if (threshold == 0.0f) { + if (error_count) { printf("total # of errors = %d\n", error_count); } + + return (error_count == 0) ? 
true : false; + } else { + if (error_count) { + printf("%4.2f(%%) of bytes mismatched (count=%d)\n", + static_cast(error_count) * 100 / static_cast(len), + error_count); + } + + return ((len * threshold > error_count) ? true : false); + } } inline void sdkDumpBin(void *data, unsigned int bytes, const char *filename) { - printf("sdkDumpBin: <%s>\n", filename); - FILE *fp; - FOPEN(fp, filename, "wb"); - fwrite(data, bytes, 1, fp); - fflush(fp); - fclose(fp); + printf("sdkDumpBin: <%s>\n", filename); + FILE *fp; + FOPEN(fp, filename, "wb"); + fwrite(data, bytes, 1, fp); + fflush(fp); + fclose(fp); } inline bool sdkCompareBin2BinUint(const char *src_file, const char *ref_file, - unsigned int nelements, const float epsilon, - const float threshold, char *exec_path) { - unsigned int *src_buffer, *ref_buffer; - FILE *src_fp = NULL, *ref_fp = NULL; - - uint64_t error_count = 0; - size_t fsize = 0; - - if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) { - printf("compareBin2Bin unable to open src_file: %s\n", src_file); - error_count++; - } - - char *ref_file_path = sdkFindFilePath(ref_file, exec_path); - - if (ref_file_path == NULL) { - printf("compareBin2Bin unable to find <%s> in <%s>\n", ref_file, exec_path); - printf(">>> Check info.xml and [project//data] folder <%s> <<<\n", ref_file); - printf("Aborting comparison!\n"); - printf(" FAILED\n"); - error_count++; - - if (src_fp) { fclose(src_fp); } - - if (ref_fp) { fclose(ref_fp); } - } else { - if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) { - printf( - "compareBin2Bin " - " unable to open ref_file: %s\n", - ref_file_path); - error_count++; - } - - if (src_fp && ref_fp) { - src_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int)); - ref_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int)); - - fsize = fread(src_buffer, nelements, sizeof(unsigned int), src_fp); - fsize = fread(ref_buffer, nelements, sizeof(unsigned int), ref_fp); - - printf( - "> compareBin2Bin nelements=%d," - " 
epsilon=%4.2f, threshold=%4.2f\n", - nelements, - epsilon, - threshold); - printf(" src_file <%s>, size=%d bytes\n", src_file, static_cast(fsize)); - printf(" ref_file <%s>, size=%d bytes\n", ref_file_path, static_cast(fsize)); - - if (!compareData( - ref_buffer, src_buffer, nelements, epsilon, threshold)) { - error_count++; - } - - fclose(src_fp); - fclose(ref_fp); - - free(src_buffer); - free(ref_buffer); - } else { - if (src_fp) { fclose(src_fp); } - - if (ref_fp) { fclose(ref_fp); } - } - } - - if (error_count == 0) { - printf(" OK\n"); - } else { - printf(" FAILURE: %d errors...\n", (unsigned int)error_count); - } - - return (error_count == 0); // returns true if all pixels pass + unsigned int nelements, const float epsilon, + const float threshold, char *exec_path) { + unsigned int *src_buffer, *ref_buffer; + FILE *src_fp = NULL, *ref_fp = NULL; + + uint64_t error_count = 0; + size_t fsize = 0; + + if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) { + printf("compareBin2Bin unable to open src_file: %s\n", src_file); + error_count++; + } + + char *ref_file_path = sdkFindFilePath(ref_file, exec_path); + + if (ref_file_path == NULL) { + printf("compareBin2Bin unable to find <%s> in <%s>\n", ref_file, exec_path); + printf(">>> Check info.xml and [project//data] folder <%s> <<<\n", ref_file); + printf("Aborting comparison!\n"); + printf(" FAILED\n"); + error_count++; + + if (src_fp) { fclose(src_fp); } + + if (ref_fp) { fclose(ref_fp); } + } else { + if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) { + printf( + "compareBin2Bin " + " unable to open ref_file: %s\n", + ref_file_path); + error_count++; + } + + if (src_fp && ref_fp) { + src_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int)); + ref_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int)); + + fsize = fread(src_buffer, nelements, sizeof(unsigned int), src_fp); + fsize = fread(ref_buffer, nelements, sizeof(unsigned int), ref_fp); + + printf( + "> compareBin2Bin nelements=%d," + " 
epsilon=%4.2f, threshold=%4.2f\n", + nelements, + epsilon, + threshold); + printf(" src_file <%s>, size=%d bytes\n", src_file, static_cast(fsize)); + printf(" ref_file <%s>, size=%d bytes\n", ref_file_path, static_cast(fsize)); + + if (!compareData( + ref_buffer, src_buffer, nelements, epsilon, threshold)) { + error_count++; + } + + fclose(src_fp); + fclose(ref_fp); + + free(src_buffer); + free(ref_buffer); + } else { + if (src_fp) { fclose(src_fp); } + + if (ref_fp) { fclose(ref_fp); } + } + } + + if (error_count == 0) { + printf(" OK\n"); + } else { + printf(" FAILURE: %d errors...\n", (unsigned int)error_count); + } + + return (error_count == 0); // returns true if all pixels pass } inline bool sdkCompareBin2BinFloat(const char *src_file, const char *ref_file, - unsigned int nelements, const float epsilon, - const float threshold, char *exec_path) { - float *src_buffer = NULL, *ref_buffer = NULL; - FILE *src_fp = NULL, *ref_fp = NULL; - size_t fsize = 0; - - uint64_t error_count = 0; - - if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) { - printf("compareBin2Bin unable to open src_file: %s\n", src_file); - error_count = 1; - } - - char *ref_file_path = sdkFindFilePath(ref_file, exec_path); - - if (ref_file_path == NULL) { - printf("compareBin2Bin unable to find <%s> in <%s>\n", ref_file, exec_path); - printf(">>> Check info.xml and [project//data] folder <%s> <<<\n", exec_path); - printf("Aborting comparison!\n"); - printf(" FAILED\n"); - error_count++; - - if (src_fp) { fclose(src_fp); } - - if (ref_fp) { fclose(ref_fp); } - } else { - if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) { - printf("compareBin2Bin unable to open ref_file: %s\n", ref_file_path); - error_count = 1; - } - - if (src_fp && ref_fp) { - src_buffer = reinterpret_cast(malloc(nelements * sizeof(float))); - ref_buffer = reinterpret_cast(malloc(nelements * sizeof(float))); - - printf( - "> compareBin2Bin nelements=%d, epsilon=%4.2f," - " threshold=%4.2f\n", - nelements, - epsilon, - 
threshold); - fsize = fread(src_buffer, sizeof(float), nelements, src_fp); - printf(" src_file <%s>, size=%d bytes\n", - src_file, - static_cast(fsize * sizeof(float))); - fsize = fread(ref_buffer, sizeof(float), nelements, ref_fp); - printf(" ref_file <%s>, size=%d bytes\n", - ref_file_path, - static_cast(fsize * sizeof(float))); - - if (!compareDataAsFloatThreshold( - ref_buffer, src_buffer, nelements, epsilon, threshold)) { - error_count++; - } - - fclose(src_fp); - fclose(ref_fp); - - free(src_buffer); - free(ref_buffer); - } else { - if (src_fp) { fclose(src_fp); } - - if (ref_fp) { fclose(ref_fp); } - } - } - - if (error_count == 0) { - printf(" OK\n"); - } else { - printf(" FAILURE: %d errors...\n", (unsigned int)error_count); - } - - return (error_count == 0); // returns true if all pixels pass + unsigned int nelements, const float epsilon, + const float threshold, char *exec_path) { + float *src_buffer = NULL, *ref_buffer = NULL; + FILE *src_fp = NULL, *ref_fp = NULL; + size_t fsize = 0; + + uint64_t error_count = 0; + + if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) { + printf("compareBin2Bin unable to open src_file: %s\n", src_file); + error_count = 1; + } + + char *ref_file_path = sdkFindFilePath(ref_file, exec_path); + + if (ref_file_path == NULL) { + printf("compareBin2Bin unable to find <%s> in <%s>\n", ref_file, exec_path); + printf(">>> Check info.xml and [project//data] folder <%s> <<<\n", exec_path); + printf("Aborting comparison!\n"); + printf(" FAILED\n"); + error_count++; + + if (src_fp) { fclose(src_fp); } + + if (ref_fp) { fclose(ref_fp); } + } else { + if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) { + printf("compareBin2Bin unable to open ref_file: %s\n", ref_file_path); + error_count = 1; + } + + if (src_fp && ref_fp) { + src_buffer = reinterpret_cast(malloc(nelements * sizeof(float))); + ref_buffer = reinterpret_cast(malloc(nelements * sizeof(float))); + + printf( + "> compareBin2Bin nelements=%d, epsilon=%4.2f," + " 
threshold=%4.2f\n", + nelements, + epsilon, + threshold); + fsize = fread(src_buffer, sizeof(float), nelements, src_fp); + printf(" src_file <%s>, size=%d bytes\n", + src_file, + static_cast(fsize * sizeof(float))); + fsize = fread(ref_buffer, sizeof(float), nelements, ref_fp); + printf(" ref_file <%s>, size=%d bytes\n", + ref_file_path, + static_cast(fsize * sizeof(float))); + + if (!compareDataAsFloatThreshold( + ref_buffer, src_buffer, nelements, epsilon, threshold)) { + error_count++; + } + + fclose(src_fp); + fclose(ref_fp); + + free(src_buffer); + free(ref_buffer); + } else { + if (src_fp) { fclose(src_fp); } + + if (ref_fp) { fclose(ref_fp); } + } + } + + if (error_count == 0) { + printf(" OK\n"); + } else { + printf(" FAILURE: %d errors...\n", (unsigned int)error_count); + } + + return (error_count == 0); // returns true if all pixels pass } inline bool sdkCompareL2fe(const float *reference, const float *data, const unsigned int len, - const float epsilon) { - assert(epsilon >= 0); + const float epsilon) { + assert(epsilon >= 0); - float error = 0; - float ref = 0; + float error = 0; + float ref = 0; - for (unsigned int i = 0; i < len; ++i) { - float diff = reference[i] - data[i]; - error += diff * diff; - ref += reference[i] * reference[i]; - } + for (unsigned int i = 0; i < len; ++i) { + float diff = reference[i] - data[i]; + error += diff * diff; + ref += reference[i] * reference[i]; + } - float normRef = sqrtf(ref); + float normRef = sqrtf(ref); - if (fabs(ref) < 1e-7) { + if (fabs(ref) < 1e-7) { #ifdef _DEBUG - std::cerr << "ERROR, reference l2-norm is 0\n"; + std::cerr << "ERROR, reference l2-norm is 0\n"; #endif - return false; - } + return false; + } - float normError = sqrtf(error); - error = normError / normRef; - bool result = error < epsilon; + float normError = sqrtf(error); + error = normError / normRef; + bool result = error < epsilon; #ifdef _DEBUG - if (!result) { - std::cerr << "ERROR, l2-norm error " << error << " is greater than epsilon 
" << epsilon - << "\n"; - } + if (!result) { + std::cerr << "ERROR, l2-norm error " << error << " is greater than epsilon " << epsilon + << "\n"; + } #endif - return result; + return result; } inline bool sdkLoadPPMub(const char *file, unsigned char **data, unsigned int *w, unsigned int *h) { - unsigned int channels; - return __loadPPM(file, data, w, h, &channels); + unsigned int channels; + return __loadPPM(file, data, w, h, &channels); } inline bool sdkLoadPPM4ub(const char *file, unsigned char **data, unsigned int *w, - unsigned int *h) { - unsigned char *idata = 0; - unsigned int channels; - - if (__loadPPM(file, &idata, w, h, &channels)) { - // pad 4th component - int size = *w * *h; - // keep the original pointer - unsigned char *idata_orig = idata; - *data = (unsigned char *)malloc(sizeof(unsigned char) * size * 4); - unsigned char *ptr = *data; - - for (int i = 0; i < size; i++) { - *ptr++ = *idata++; - *ptr++ = *idata++; - *ptr++ = *idata++; - *ptr++ = 0; - } - - free(idata_orig); - return true; - } else { - free(idata); - return false; - } + unsigned int *h) { + unsigned char *idata = 0; + unsigned int channels; + + if (__loadPPM(file, &idata, w, h, &channels)) { + // pad 4th component + int size = *w * *h; + // keep the original pointer + unsigned char *idata_orig = idata; + *data = (unsigned char *)malloc(sizeof(unsigned char) * size * 4); + unsigned char *ptr = *data; + + for (int i = 0; i < size; i++) { + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = 0; + } + + free(idata_orig); + return true; + } else { + free(idata); + return false; + } } inline bool sdkComparePPM(const char *src_file, const char *ref_file, const float epsilon, - const float threshold, bool verboseErrors) { - unsigned char *src_data, *ref_data; - uint64_t error_count = 0; - unsigned int ref_width, ref_height; - unsigned int src_width, src_height; - - if (src_file == NULL || ref_file == NULL) { - if (verboseErrors) { - std::cerr << "PPMvsPPM: src_file or 
ref_file is NULL." - " Aborting comparison\n"; - } - - return false; - } - - if (verboseErrors) { - std::cerr << "> Compare (a)rendered: <" << src_file << ">\n"; - std::cerr << "> (b)reference: <" << ref_file << ">\n"; - } - - if (sdkLoadPPM4ub(ref_file, &ref_data, &ref_width, &ref_height) != true) { - if (verboseErrors) { - std::cerr << "PPMvsPPM: unable to load ref image file: " << ref_file << "\n"; - } - - return false; - } - - if (sdkLoadPPM4ub(src_file, &src_data, &src_width, &src_height) != true) { - std::cerr << "PPMvsPPM: unable to load src image file: " << src_file << "\n"; - return false; - } - - if (src_height != ref_height || src_width != ref_width) { - if (verboseErrors) { - std::cerr << "PPMvsPPM: source and ref size mismatch (" << src_width << "," - << src_height << ")vs(" << ref_width << "," << ref_height << ")\n"; - } - } - - if (verboseErrors) { - std::cerr << "PPMvsPPM: comparing images size (" << src_width << "," << src_height - << ") epsilon(" << epsilon << "), threshold(" << threshold * 100 << "%)\n"; - } - - if (compareData(ref_data, src_data, src_width * src_height * 4, epsilon, threshold) == false) { - error_count = 1; - } - - if (error_count == 0) { - if (verboseErrors) { std::cerr << " OK\n\n"; } - } else { - if (verboseErrors) { std::cerr << " FAILURE! " << error_count << " errors...\n\n"; } - } - - // returns true if all pixels pass - return (error_count == 0) ? true : false; + const float threshold, bool verboseErrors) { + unsigned char *src_data, *ref_data; + uint64_t error_count = 0; + unsigned int ref_width, ref_height; + unsigned int src_width, src_height; + + if (src_file == NULL || ref_file == NULL) { + if (verboseErrors) { + std::cerr << "PPMvsPPM: src_file or ref_file is NULL." 
+ " Aborting comparison\n"; + } + + return false; + } + + if (verboseErrors) { + std::cerr << "> Compare (a)rendered: <" << src_file << ">\n"; + std::cerr << "> (b)reference: <" << ref_file << ">\n"; + } + + if (sdkLoadPPM4ub(ref_file, &ref_data, &ref_width, &ref_height) != true) { + if (verboseErrors) { + std::cerr << "PPMvsPPM: unable to load ref image file: " << ref_file << "\n"; + } + + return false; + } + + if (sdkLoadPPM4ub(src_file, &src_data, &src_width, &src_height) != true) { + std::cerr << "PPMvsPPM: unable to load src image file: " << src_file << "\n"; + return false; + } + + if (src_height != ref_height || src_width != ref_width) { + if (verboseErrors) { + std::cerr << "PPMvsPPM: source and ref size mismatch (" << src_width << "," + << src_height << ")vs(" << ref_width << "," << ref_height << ")\n"; + } + } + + if (verboseErrors) { + std::cerr << "PPMvsPPM: comparing images size (" << src_width << "," << src_height + << ") epsilon(" << epsilon << "), threshold(" << threshold * 100 << "%)\n"; + } + + if (compareData(ref_data, src_data, src_width * src_height * 4, epsilon, threshold) == false) { + error_count = 1; + } + + if (error_count == 0) { + if (verboseErrors) { std::cerr << " OK\n\n"; } + } else { + if (verboseErrors) { std::cerr << " FAILURE! " << error_count << " errors...\n\n"; } + } + + // returns true if all pixels pass + return (error_count == 0) ? true : false; } inline bool sdkComparePGM(const char *src_file, const char *ref_file, const float epsilon, - const float threshold, bool verboseErrors) { - unsigned char *src_data = 0, *ref_data = 0; - uint64_t error_count = 0; - unsigned int ref_width, ref_height; - unsigned int src_width, src_height; - - if (src_file == NULL || ref_file == NULL) { - if (verboseErrors) { - std::cerr << "PGMvsPGM: src_file or ref_file is NULL." 
- " Aborting comparison\n"; - } - - return false; - } - - if (verboseErrors) { - std::cerr << "> Compare (a)rendered: <" << src_file << ">\n"; - std::cerr << "> (b)reference: <" << ref_file << ">\n"; - } - - if (sdkLoadPPMub(ref_file, &ref_data, &ref_width, &ref_height) != true) { - if (verboseErrors) { - std::cerr << "PGMvsPGM: unable to load ref image file: " << ref_file << "\n"; - } - - return false; - } - - if (sdkLoadPPMub(src_file, &src_data, &src_width, &src_height) != true) { - std::cerr << "PGMvsPGM: unable to load src image file: " << src_file << "\n"; - return false; - } - - if (src_height != ref_height || src_width != ref_width) { - if (verboseErrors) { - std::cerr << "PGMvsPGM: source and ref size mismatch (" << src_width << "," - << src_height << ")vs(" << ref_width << "," << ref_height << ")\n"; - } - } - - if (verboseErrors) - std::cerr << "PGMvsPGM: comparing images size (" << src_width << "," << src_height - << ") epsilon(" << epsilon << "), threshold(" << threshold * 100 << "%)\n"; - - if (compareData(ref_data, src_data, src_width * src_height, epsilon, threshold) == false) { - error_count = 1; - } - - if (error_count == 0) { - if (verboseErrors) { std::cerr << " OK\n\n"; } - } else { - if (verboseErrors) { std::cerr << " FAILURE! " << error_count << " errors...\n\n"; } - } - - // returns true if all pixels pass - return (error_count == 0) ? true : false; + const float threshold, bool verboseErrors) { + unsigned char *src_data = 0, *ref_data = 0; + uint64_t error_count = 0; + unsigned int ref_width, ref_height; + unsigned int src_width, src_height; + + if (src_file == NULL || ref_file == NULL) { + if (verboseErrors) { + std::cerr << "PGMvsPGM: src_file or ref_file is NULL." 
+ " Aborting comparison\n"; + } + + return false; + } + + if (verboseErrors) { + std::cerr << "> Compare (a)rendered: <" << src_file << ">\n"; + std::cerr << "> (b)reference: <" << ref_file << ">\n"; + } + + if (sdkLoadPPMub(ref_file, &ref_data, &ref_width, &ref_height) != true) { + if (verboseErrors) { + std::cerr << "PGMvsPGM: unable to load ref image file: " << ref_file << "\n"; + } + + return false; + } + + if (sdkLoadPPMub(src_file, &src_data, &src_width, &src_height) != true) { + std::cerr << "PGMvsPGM: unable to load src image file: " << src_file << "\n"; + return false; + } + + if (src_height != ref_height || src_width != ref_width) { + if (verboseErrors) { + std::cerr << "PGMvsPGM: source and ref size mismatch (" << src_width << "," + << src_height << ")vs(" << ref_width << "," << ref_height << ")\n"; + } + } + + if (verboseErrors) + std::cerr << "PGMvsPGM: comparing images size (" << src_width << "," << src_height + << ") epsilon(" << epsilon << "), threshold(" << threshold * 100 << "%)\n"; + + if (compareData(ref_data, src_data, src_width * src_height, epsilon, threshold) == false) { + error_count = 1; + } + + if (error_count == 0) { + if (verboseErrors) { std::cerr << " OK\n\n"; } + } else { + if (verboseErrors) { std::cerr << " FAILURE! " << error_count << " errors...\n\n"; } + } + + // returns true if all pixels pass + return (error_count == 0) ? 
true : false; } #endif // COMMON_HELPER_IMAGE_H_ diff --git a/librapid/include/librapid/cuda/helper_math.h b/librapid/include/librapid/cuda/helper_math.h index bcc4885c..5eda2fd3 100644 --- a/librapid/include/librapid/cuda/helper_math.h +++ b/librapid/include/librapid/cuda/helper_math.h @@ -46,12 +46,12 @@ typedef unsigned int uint; typedef unsigned short ushort; #ifndef EXIT_WAIVED -# define EXIT_WAIVED 2 +# define EXIT_WAIVED 2 #endif #ifndef __CUDACC__ -# include +# include //////////////////////////////////////////////////////////////////////////////// // host implementations of CUDA functions @@ -77,287 +77,287 @@ inline __host__ __device__ float2 make_float2(float s) { - return make_float2(s, s); + return make_float2(s, s); } inline __host__ __device__ float2 make_float2(float3 a) { - return make_float2(a.x, a.y); + return make_float2(a.x, a.y); } inline __host__ __device__ float2 make_float2(int2 a) { - return make_float2(float(a.x), float(a.y)); + return make_float2(float(a.x), float(a.y)); } inline __host__ __device__ float2 make_float2(uint2 a) { - return make_float2(float(a.x), float(a.y)); + return make_float2(float(a.x), float(a.y)); } inline __host__ __device__ int2 make_int2(int s) { - return make_int2(s, s); + return make_int2(s, s); } inline __host__ __device__ int2 make_int2(int3 a) { - return make_int2(a.x, a.y); + return make_int2(a.x, a.y); } inline __host__ __device__ int2 make_int2(uint2 a) { - return make_int2(int(a.x), int(a.y)); + return make_int2(int(a.x), int(a.y)); } inline __host__ __device__ int2 make_int2(float2 a) { - return make_int2(int(a.x), int(a.y)); + return make_int2(int(a.x), int(a.y)); } inline __host__ __device__ uint2 make_uint2(uint s) { - return make_uint2(s, s); + return make_uint2(s, s); } inline __host__ __device__ uint2 make_uint2(uint3 a) { - return make_uint2(a.x, a.y); + return make_uint2(a.x, a.y); } inline __host__ __device__ uint2 make_uint2(int2 a) { - return make_uint2(uint(a.x), uint(a.y)); + return 
make_uint2(uint(a.x), uint(a.y)); } inline __host__ __device__ float3 make_float3(float s) { - return make_float3(s, s, s); + return make_float3(s, s, s); } inline __host__ __device__ float3 make_float3(float2 a) { - return make_float3(a.x, a.y, 0.0f); + return make_float3(a.x, a.y, 0.0f); } inline __host__ __device__ float3 make_float3(float2 a, float s) { - return make_float3(a.x, a.y, s); + return make_float3(a.x, a.y, s); } inline __host__ __device__ float3 make_float3(float4 a) { - return make_float3(a.x, a.y, a.z); + return make_float3(a.x, a.y, a.z); } inline __host__ __device__ float3 make_float3(int3 a) { - return make_float3(float(a.x), float(a.y), float(a.z)); + return make_float3(float(a.x), float(a.y), float(a.z)); } inline __host__ __device__ float3 make_float3(uint3 a) { - return make_float3(float(a.x), float(a.y), float(a.z)); + return make_float3(float(a.x), float(a.y), float(a.z)); } inline __host__ __device__ int3 make_int3(int s) { - return make_int3(s, s, s); + return make_int3(s, s, s); } inline __host__ __device__ int3 make_int3(int2 a) { - return make_int3(a.x, a.y, 0); + return make_int3(a.x, a.y, 0); } inline __host__ __device__ int3 make_int3(int2 a, int s) { - return make_int3(a.x, a.y, s); + return make_int3(a.x, a.y, s); } inline __host__ __device__ int3 make_int3(uint3 a) { - return make_int3(int(a.x), int(a.y), int(a.z)); + return make_int3(int(a.x), int(a.y), int(a.z)); } inline __host__ __device__ int3 make_int3(float3 a) { - return make_int3(int(a.x), int(a.y), int(a.z)); + return make_int3(int(a.x), int(a.y), int(a.z)); } inline __host__ __device__ uint3 make_uint3(uint s) { - return make_uint3(s, s, s); + return make_uint3(s, s, s); } inline __host__ __device__ uint3 make_uint3(uint2 a) { - return make_uint3(a.x, a.y, 0); + return make_uint3(a.x, a.y, 0); } inline __host__ __device__ uint3 make_uint3(uint2 a, uint s) { - return make_uint3(a.x, a.y, s); + return make_uint3(a.x, a.y, s); } inline __host__ __device__ uint3 
make_uint3(uint4 a) { - return make_uint3(a.x, a.y, a.z); + return make_uint3(a.x, a.y, a.z); } inline __host__ __device__ uint3 make_uint3(int3 a) { - return make_uint3(uint(a.x), uint(a.y), uint(a.z)); + return make_uint3(uint(a.x), uint(a.y), uint(a.z)); } inline __host__ __device__ float4 make_float4(float s) { - return make_float4(s, s, s, s); + return make_float4(s, s, s, s); } inline __host__ __device__ float4 make_float4(float3 a) { - return make_float4(a.x, a.y, a.z, 0.0f); + return make_float4(a.x, a.y, a.z, 0.0f); } inline __host__ __device__ float4 make_float4(float3 a, float w) { - return make_float4(a.x, a.y, a.z, w); + return make_float4(a.x, a.y, a.z, w); } inline __host__ __device__ float4 make_float4(int4 a) { - return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); } inline __host__ __device__ float4 make_float4(uint4 a) { - return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); + return make_float4(float(a.x), float(a.y), float(a.z), float(a.w)); } inline __host__ __device__ int4 make_int4(int s) { - return make_int4(s, s, s, s); + return make_int4(s, s, s, s); } inline __host__ __device__ int4 make_int4(int3 a) { - return make_int4(a.x, a.y, a.z, 0); + return make_int4(a.x, a.y, a.z, 0); } inline __host__ __device__ int4 make_int4(int3 a, int w) { - return make_int4(a.x, a.y, a.z, w); + return make_int4(a.x, a.y, a.z, w); } inline __host__ __device__ int4 make_int4(uint4 a) { - return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); } inline __host__ __device__ int4 make_int4(float4 a) { - return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); + return make_int4(int(a.x), int(a.y), int(a.z), int(a.w)); } inline __host__ __device__ uint4 make_uint4(uint s) { - return make_uint4(s, s, s, s); + return make_uint4(s, s, s, s); } inline __host__ __device__ uint4 make_uint4(uint3 a) { - return 
make_uint4(a.x, a.y, a.z, 0); + return make_uint4(a.x, a.y, a.z, 0); } inline __host__ __device__ uint4 make_uint4(uint3 a, uint w) { - return make_uint4(a.x, a.y, a.z, w); + return make_uint4(a.x, a.y, a.z, w); } inline __host__ __device__ uint4 make_uint4(int4 a) { - return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w)); + return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -368,42 +368,42 @@ inline __host__ __device__ float2 operator-(float2 &a) { - return make_float2(-a.x, -a.y); + return make_float2(-a.x, -a.y); } inline __host__ __device__ int2 operator-(int2 &a) { - return make_int2(-a.x, -a.y); + return make_int2(-a.x, -a.y); } inline __host__ __device__ float3 operator-(float3 &a) { - return make_float3(-a.x, -a.y, -a.z); + return make_float3(-a.x, -a.y, -a.z); } inline __host__ __device__ int3 operator-(int3 &a) { - return make_int3(-a.x, -a.y, -a.z); + return make_int3(-a.x, -a.y, -a.z); } inline __host__ __device__ float4 operator-(float4 &a) { - return make_float4(-a.x, -a.y, -a.z, -a.w); + return make_float4(-a.x, -a.y, -a.z, -a.w); } inline __host__ __device__ int4 operator-(int4 &a) { - return make_int4(-a.x, -a.y, -a.z, -a.w); + return make_int4(-a.x, -a.y, -a.z, -a.w); } //////////////////////////////////////////////////////////////////////////////// @@ -414,351 +414,351 @@ inline __host__ __device__ float2 operator+(float2 a, float2 b) { - return make_float2(a.x + b.x, a.y + b.y); + return make_float2(a.x + b.x, a.y + b.y); } inline __host__ __device__ void operator+=(float2 &a, float2 b) { - a.x += b.x; - a.y += b.y; + a.x += b.x; + a.y += b.y; } inline __host__ __device__ float2 operator+(float2 a, float b) { - return make_float2(a.x + b, a.y + b); + return make_float2(a.x + b, a.y + b); } inline __host__ __device__ float2 operator+(float b, float2 a) { - return make_float2(a.x + b, a.y + b); + return make_float2(a.x + b, a.y + b); } inline 
__host__ __device__ void operator+=(float2 &a, float b) { - a.x += b; - a.y += b; + a.x += b; + a.y += b; } inline __host__ __device__ int2 operator+(int2 a, int2 b) { - return make_int2(a.x + b.x, a.y + b.y); + return make_int2(a.x + b.x, a.y + b.y); } inline __host__ __device__ void operator+=(int2 &a, int2 b) { - a.x += b.x; - a.y += b.y; + a.x += b.x; + a.y += b.y; } inline __host__ __device__ int2 operator+(int2 a, int b) { - return make_int2(a.x + b, a.y + b); + return make_int2(a.x + b, a.y + b); } inline __host__ __device__ int2 operator+(int b, int2 a) { - return make_int2(a.x + b, a.y + b); + return make_int2(a.x + b, a.y + b); } inline __host__ __device__ void operator+=(int2 &a, int b) { - a.x += b; - a.y += b; + a.x += b; + a.y += b; } inline __host__ __device__ uint2 operator+(uint2 a, uint2 b) { - return make_uint2(a.x + b.x, a.y + b.y); + return make_uint2(a.x + b.x, a.y + b.y); } inline __host__ __device__ void operator+=(uint2 &a, uint2 b) { - a.x += b.x; - a.y += b.y; + a.x += b.x; + a.y += b.y; } inline __host__ __device__ uint2 operator+(uint2 a, uint b) { - return make_uint2(a.x + b, a.y + b); + return make_uint2(a.x + b, a.y + b); } inline __host__ __device__ uint2 operator+(uint b, uint2 a) { - return make_uint2(a.x + b, a.y + b); + return make_uint2(a.x + b, a.y + b); } inline __host__ __device__ void operator+=(uint2 &a, uint b) { - a.x += b; - a.y += b; + a.x += b; + a.y += b; } inline __host__ __device__ float3 operator+(float3 a, float3 b) { - return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); } inline __host__ __device__ void operator+=(float3 &a, float3 b) { - a.x += b.x; - a.y += b.y; - a.z += b.z; + a.x += b.x; + a.y += b.y; + a.z += b.z; } inline __host__ __device__ float3 operator+(float3 a, float b) { - return make_float3(a.x + b, a.y + b, a.z + b); + return make_float3(a.x + b, a.y + b, a.z + b); } inline __host__ __device__ void operator+=(float3 &a, float b) { - a.x += 
b; - a.y += b; - a.z += b; + a.x += b; + a.y += b; + a.z += b; } inline __host__ __device__ int3 operator+(int3 a, int3 b) { - return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); + return make_int3(a.x + b.x, a.y + b.y, a.z + b.z); } inline __host__ __device__ void operator+=(int3 &a, int3 b) { - a.x += b.x; - a.y += b.y; - a.z += b.z; + a.x += b.x; + a.y += b.y; + a.z += b.z; } inline __host__ __device__ int3 operator+(int3 a, int b) { - return make_int3(a.x + b, a.y + b, a.z + b); + return make_int3(a.x + b, a.y + b, a.z + b); } inline __host__ __device__ void operator+=(int3 &a, int b) { - a.x += b; - a.y += b; - a.z += b; + a.x += b; + a.y += b; + a.z += b; } inline __host__ __device__ uint3 operator+(uint3 a, uint3 b) { - return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); + return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z); } inline __host__ __device__ void operator+=(uint3 &a, uint3 b) { - a.x += b.x; - a.y += b.y; - a.z += b.z; + a.x += b.x; + a.y += b.y; + a.z += b.z; } inline __host__ __device__ uint3 operator+(uint3 a, uint b) { - return make_uint3(a.x + b, a.y + b, a.z + b); + return make_uint3(a.x + b, a.y + b, a.z + b); } inline __host__ __device__ void operator+=(uint3 &a, uint b) { - a.x += b; - a.y += b; - a.z += b; + a.x += b; + a.y += b; + a.z += b; } inline __host__ __device__ int3 operator+(int b, int3 a) { - return make_int3(a.x + b, a.y + b, a.z + b); + return make_int3(a.x + b, a.y + b, a.z + b); } inline __host__ __device__ uint3 operator+(uint b, uint3 a) { - return make_uint3(a.x + b, a.y + b, a.z + b); + return make_uint3(a.x + b, a.y + b, a.z + b); } inline __host__ __device__ float3 operator+(float b, float3 a) { - return make_float3(a.x + b, a.y + b, a.z + b); + return make_float3(a.x + b, a.y + b, a.z + b); } inline __host__ __device__ float4 operator+(float4 a, float4 b) { - return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); + return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } inline __host__ __device__ 
void operator+=(float4 &a, float4 b) { - a.x += b.x; - a.y += b.y; - a.z += b.z; - a.w += b.w; + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; } inline __host__ __device__ float4 operator+(float4 a, float b) { - return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); } inline __host__ __device__ float4 operator+(float b, float4 a) { - return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); + return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); } inline __host__ __device__ void operator+=(float4 &a, float b) { - a.x += b; - a.y += b; - a.z += b; - a.w += b; + a.x += b; + a.y += b; + a.z += b; + a.w += b; } inline __host__ __device__ int4 operator+(int4 a, int4 b) { - return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); + return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } inline __host__ __device__ void operator+=(int4 &a, int4 b) { - a.x += b.x; - a.y += b.y; - a.z += b.z; - a.w += b.w; + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; } inline __host__ __device__ int4 operator+(int4 a, int b) { - return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); + return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); } inline __host__ __device__ int4 operator+(int b, int4 a) { - return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); + return make_int4(a.x + b, a.y + b, a.z + b, a.w + b); } inline __host__ __device__ void operator+=(int4 &a, int b) { - a.x += b; - a.y += b; - a.z += b; - a.w += b; + a.x += b; + a.y += b; + a.z += b; + a.w += b; } inline __host__ __device__ uint4 operator+(uint4 a, uint4 b) { - return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); + return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } inline __host__ __device__ void operator+=(uint4 &a, uint4 b) { - a.x += b.x; - a.y += b.y; - a.z += b.z; - a.w += b.w; + a.x += b.x; + a.y += b.y; + a.z += b.z; + a.w += b.w; } inline __host__ __device__ uint4 operator+(uint4 a, uint b) { - return 
make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); } inline __host__ __device__ uint4 operator+(uint b, uint4 a) { - return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); + return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b); } inline __host__ __device__ void operator+=(uint4 &a, uint b) { - a.x += b; - a.y += b; - a.z += b; - a.w += b; + a.x += b; + a.y += b; + a.z += b; + a.w += b; } //////////////////////////////////////////////////////////////////////////////// @@ -769,344 +769,344 @@ inline __host__ __device__ float2 operator-(float2 a, float2 b) { - return make_float2(a.x - b.x, a.y - b.y); + return make_float2(a.x - b.x, a.y - b.y); } inline __host__ __device__ void operator-=(float2 &a, float2 b) { - a.x -= b.x; - a.y -= b.y; + a.x -= b.x; + a.y -= b.y; } inline __host__ __device__ float2 operator-(float2 a, float b) { - return make_float2(a.x - b, a.y - b); + return make_float2(a.x - b, a.y - b); } inline __host__ __device__ float2 operator-(float b, float2 a) { - return make_float2(b - a.x, b - a.y); + return make_float2(b - a.x, b - a.y); } inline __host__ __device__ void operator-=(float2 &a, float b) { - a.x -= b; - a.y -= b; + a.x -= b; + a.y -= b; } inline __host__ __device__ int2 operator-(int2 a, int2 b) { - return make_int2(a.x - b.x, a.y - b.y); + return make_int2(a.x - b.x, a.y - b.y); } inline __host__ __device__ void operator-=(int2 &a, int2 b) { - a.x -= b.x; - a.y -= b.y; + a.x -= b.x; + a.y -= b.y; } inline __host__ __device__ int2 operator-(int2 a, int b) { - return make_int2(a.x - b, a.y - b); + return make_int2(a.x - b, a.y - b); } inline __host__ __device__ int2 operator-(int b, int2 a) { - return make_int2(b - a.x, b - a.y); + return make_int2(b - a.x, b - a.y); } inline __host__ __device__ void operator-=(int2 &a, int b) { - a.x -= b; - a.y -= b; + a.x -= b; + a.y -= b; } inline __host__ __device__ uint2 operator-(uint2 a, uint2 b) { - return make_uint2(a.x - b.x, a.y - b.y); + 
return make_uint2(a.x - b.x, a.y - b.y); } inline __host__ __device__ void operator-=(uint2 &a, uint2 b) { - a.x -= b.x; - a.y -= b.y; + a.x -= b.x; + a.y -= b.y; } inline __host__ __device__ uint2 operator-(uint2 a, uint b) { - return make_uint2(a.x - b, a.y - b); + return make_uint2(a.x - b, a.y - b); } inline __host__ __device__ uint2 operator-(uint b, uint2 a) { - return make_uint2(b - a.x, b - a.y); + return make_uint2(b - a.x, b - a.y); } inline __host__ __device__ void operator-=(uint2 &a, uint b) { - a.x -= b; - a.y -= b; + a.x -= b; + a.y -= b; } inline __host__ __device__ float3 operator-(float3 a, float3 b) { - return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); + return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); } inline __host__ __device__ void operator-=(float3 &a, float3 b) { - a.x -= b.x; - a.y -= b.y; - a.z -= b.z; + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; } inline __host__ __device__ float3 operator-(float3 a, float b) { - return make_float3(a.x - b, a.y - b, a.z - b); + return make_float3(a.x - b, a.y - b, a.z - b); } inline __host__ __device__ float3 operator-(float b, float3 a) { - return make_float3(b - a.x, b - a.y, b - a.z); + return make_float3(b - a.x, b - a.y, b - a.z); } inline __host__ __device__ void operator-=(float3 &a, float b) { - a.x -= b; - a.y -= b; - a.z -= b; + a.x -= b; + a.y -= b; + a.z -= b; } inline __host__ __device__ int3 operator-(int3 a, int3 b) { - return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); + return make_int3(a.x - b.x, a.y - b.y, a.z - b.z); } inline __host__ __device__ void operator-=(int3 &a, int3 b) { - a.x -= b.x; - a.y -= b.y; - a.z -= b.z; + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; } inline __host__ __device__ int3 operator-(int3 a, int b) { - return make_int3(a.x - b, a.y - b, a.z - b); + return make_int3(a.x - b, a.y - b, a.z - b); } inline __host__ __device__ int3 operator-(int b, int3 a) { - return make_int3(b - a.x, b - a.y, b - a.z); + return make_int3(b - a.x, b - a.y, b - a.z); } inline 
__host__ __device__ void operator-=(int3 &a, int b) { - a.x -= b; - a.y -= b; - a.z -= b; + a.x -= b; + a.y -= b; + a.z -= b; } inline __host__ __device__ uint3 operator-(uint3 a, uint3 b) { - return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); + return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z); } inline __host__ __device__ void operator-=(uint3 &a, uint3 b) { - a.x -= b.x; - a.y -= b.y; - a.z -= b.z; + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; } inline __host__ __device__ uint3 operator-(uint3 a, uint b) { - return make_uint3(a.x - b, a.y - b, a.z - b); + return make_uint3(a.x - b, a.y - b, a.z - b); } inline __host__ __device__ uint3 operator-(uint b, uint3 a) { - return make_uint3(b - a.x, b - a.y, b - a.z); + return make_uint3(b - a.x, b - a.y, b - a.z); } inline __host__ __device__ void operator-=(uint3 &a, uint b) { - a.x -= b; - a.y -= b; - a.z -= b; + a.x -= b; + a.y -= b; + a.z -= b; } inline __host__ __device__ float4 operator-(float4 a, float4 b) { - return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); + return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } inline __host__ __device__ void operator-=(float4 &a, float4 b) { - a.x -= b.x; - a.y -= b.y; - a.z -= b.z; - a.w -= b.w; + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; } inline __host__ __device__ float4 operator-(float4 a, float b) { - return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); + return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); } inline __host__ __device__ void operator-=(float4 &a, float b) { - a.x -= b; - a.y -= b; - a.z -= b; - a.w -= b; + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; } inline __host__ __device__ int4 operator-(int4 a, int4 b) { - return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); + return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } inline __host__ __device__ void operator-=(int4 &a, int4 b) { - a.x -= b.x; - a.y -= b.y; - a.z -= b.z; - a.w -= b.w; + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; } 
inline __host__ __device__ int4 operator-(int4 a, int b) { - return make_int4(a.x - b, a.y - b, a.z - b, a.w - b); + return make_int4(a.x - b, a.y - b, a.z - b, a.w - b); } inline __host__ __device__ int4 operator-(int b, int4 a) { - return make_int4(b - a.x, b - a.y, b - a.z, b - a.w); + return make_int4(b - a.x, b - a.y, b - a.z, b - a.w); } inline __host__ __device__ void operator-=(int4 &a, int b) { - a.x -= b; - a.y -= b; - a.z -= b; - a.w -= b; + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; } inline __host__ __device__ uint4 operator-(uint4 a, uint4 b) { - return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); + return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } inline __host__ __device__ void operator-=(uint4 &a, uint4 b) { - a.x -= b.x; - a.y -= b.y; - a.z -= b.z; - a.w -= b.w; + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; + a.w -= b.w; } inline __host__ __device__ uint4 operator-(uint4 a, uint b) { - return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b); + return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b); } inline __host__ __device__ uint4 operator-(uint b, uint4 a) { - return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w); + return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w); } inline __host__ __device__ void operator-=(uint4 &a, uint b) { - a.x -= b; - a.y -= b; - a.z -= b; - a.w -= b; + a.x -= b; + a.y -= b; + a.z -= b; + a.w -= b; } //////////////////////////////////////////////////////////////////////////////// @@ -1117,351 +1117,351 @@ inline __host__ __device__ float2 operator*(float2 a, float2 b) { - return make_float2(a.x * b.x, a.y * b.y); + return make_float2(a.x * b.x, a.y * b.y); } inline __host__ __device__ void operator*=(float2 &a, float2 b) { - a.x *= b.x; - a.y *= b.y; + a.x *= b.x; + a.y *= b.y; } inline __host__ __device__ float2 operator*(float2 a, float b) { - return make_float2(a.x * b, a.y * b); + return make_float2(a.x * b, a.y * b); } inline __host__ __device__ float2 operator*(float b, float2 a) { - 
return make_float2(b * a.x, b * a.y); + return make_float2(b * a.x, b * a.y); } inline __host__ __device__ void operator*=(float2 &a, float b) { - a.x *= b; - a.y *= b; + a.x *= b; + a.y *= b; } inline __host__ __device__ int2 operator*(int2 a, int2 b) { - return make_int2(a.x * b.x, a.y * b.y); + return make_int2(a.x * b.x, a.y * b.y); } inline __host__ __device__ void operator*=(int2 &a, int2 b) { - a.x *= b.x; - a.y *= b.y; + a.x *= b.x; + a.y *= b.y; } inline __host__ __device__ int2 operator*(int2 a, int b) { - return make_int2(a.x * b, a.y * b); + return make_int2(a.x * b, a.y * b); } inline __host__ __device__ int2 operator*(int b, int2 a) { - return make_int2(b * a.x, b * a.y); + return make_int2(b * a.x, b * a.y); } inline __host__ __device__ void operator*=(int2 &a, int b) { - a.x *= b; - a.y *= b; + a.x *= b; + a.y *= b; } inline __host__ __device__ uint2 operator*(uint2 a, uint2 b) { - return make_uint2(a.x * b.x, a.y * b.y); + return make_uint2(a.x * b.x, a.y * b.y); } inline __host__ __device__ void operator*=(uint2 &a, uint2 b) { - a.x *= b.x; - a.y *= b.y; + a.x *= b.x; + a.y *= b.y; } inline __host__ __device__ uint2 operator*(uint2 a, uint b) { - return make_uint2(a.x * b, a.y * b); + return make_uint2(a.x * b, a.y * b); } inline __host__ __device__ uint2 operator*(uint b, uint2 a) { - return make_uint2(b * a.x, b * a.y); + return make_uint2(b * a.x, b * a.y); } inline __host__ __device__ void operator*=(uint2 &a, uint b) { - a.x *= b; - a.y *= b; + a.x *= b; + a.y *= b; } inline __host__ __device__ float3 operator*(float3 a, float3 b) { - return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); + return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); } inline __host__ __device__ void operator*=(float3 &a, float3 b) { - a.x *= b.x; - a.y *= b.y; - a.z *= b.z; + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; } inline __host__ __device__ float3 operator*(float3 a, float b) { - return make_float3(a.x * b, a.y * b, a.z * b); + return make_float3(a.x * b, a.y * b, 
a.z * b); } inline __host__ __device__ float3 operator*(float b, float3 a) { - return make_float3(b * a.x, b * a.y, b * a.z); + return make_float3(b * a.x, b * a.y, b * a.z); } inline __host__ __device__ void operator*=(float3 &a, float b) { - a.x *= b; - a.y *= b; - a.z *= b; + a.x *= b; + a.y *= b; + a.z *= b; } inline __host__ __device__ int3 operator*(int3 a, int3 b) { - return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); + return make_int3(a.x * b.x, a.y * b.y, a.z * b.z); } inline __host__ __device__ void operator*=(int3 &a, int3 b) { - a.x *= b.x; - a.y *= b.y; - a.z *= b.z; + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; } inline __host__ __device__ int3 operator*(int3 a, int b) { - return make_int3(a.x * b, a.y * b, a.z * b); + return make_int3(a.x * b, a.y * b, a.z * b); } inline __host__ __device__ int3 operator*(int b, int3 a) { - return make_int3(b * a.x, b * a.y, b * a.z); + return make_int3(b * a.x, b * a.y, b * a.z); } inline __host__ __device__ void operator*=(int3 &a, int b) { - a.x *= b; - a.y *= b; - a.z *= b; + a.x *= b; + a.y *= b; + a.z *= b; } inline __host__ __device__ uint3 operator*(uint3 a, uint3 b) { - return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); + return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z); } inline __host__ __device__ void operator*=(uint3 &a, uint3 b) { - a.x *= b.x; - a.y *= b.y; - a.z *= b.z; + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; } inline __host__ __device__ uint3 operator*(uint3 a, uint b) { - return make_uint3(a.x * b, a.y * b, a.z * b); + return make_uint3(a.x * b, a.y * b, a.z * b); } inline __host__ __device__ uint3 operator*(uint b, uint3 a) { - return make_uint3(b * a.x, b * a.y, b * a.z); + return make_uint3(b * a.x, b * a.y, b * a.z); } inline __host__ __device__ void operator*=(uint3 &a, uint b) { - a.x *= b; - a.y *= b; - a.z *= b; + a.x *= b; + a.y *= b; + a.z *= b; } inline __host__ __device__ float4 operator*(float4 a, float4 b) { - return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); + return 
make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } inline __host__ __device__ void operator*=(float4 &a, float4 b) { - a.x *= b.x; - a.y *= b.y; - a.z *= b.z; - a.w *= b.w; + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; } inline __host__ __device__ float4 operator*(float4 a, float b) { - return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); + return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); } inline __host__ __device__ float4 operator*(float b, float4 a) { - return make_float4(b * a.x, b * a.y, b * a.z, b * a.w); + return make_float4(b * a.x, b * a.y, b * a.z, b * a.w); } inline __host__ __device__ void operator*=(float4 &a, float b) { - a.x *= b; - a.y *= b; - a.z *= b; - a.w *= b; + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; } inline __host__ __device__ int4 operator*(int4 a, int4 b) { - return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); + return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } inline __host__ __device__ void operator*=(int4 &a, int4 b) { - a.x *= b.x; - a.y *= b.y; - a.z *= b.z; - a.w *= b.w; + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + a.w *= b.w; } inline __host__ __device__ int4 operator*(int4 a, int b) { - return make_int4(a.x * b, a.y * b, a.z * b, a.w * b); + return make_int4(a.x * b, a.y * b, a.z * b, a.w * b); } inline __host__ __device__ int4 operator*(int b, int4 a) { - return make_int4(b * a.x, b * a.y, b * a.z, b * a.w); + return make_int4(b * a.x, b * a.y, b * a.z, b * a.w); } inline __host__ __device__ void operator*=(int4 &a, int b) { - a.x *= b; - a.y *= b; - a.z *= b; - a.w *= b; + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; } inline __host__ __device__ uint4 operator*(uint4 a, uint4 b) { - return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); + return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } inline __host__ __device__ void operator*=(uint4 &a, uint4 b) { - a.x *= b.x; - a.y *= b.y; - a.z *= b.z; - a.w *= b.w; + a.x *= b.x; + a.y *= b.y; + a.z *= b.z; + 
a.w *= b.w; } inline __host__ __device__ uint4 operator*(uint4 a, uint b) { - return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b); + return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b); } inline __host__ __device__ uint4 operator*(uint b, uint4 a) { - return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w); + return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w); } inline __host__ __device__ void operator*=(uint4 &a, uint b) { - a.x *= b; - a.y *= b; - a.z *= b; - a.w *= b; + a.x *= b; + a.y *= b; + a.z *= b; + a.w *= b; } //////////////////////////////////////////////////////////////////////////////// @@ -1472,117 +1472,117 @@ inline __host__ __device__ float2 operator/(float2 a, float2 b) { - return make_float2(a.x / b.x, a.y / b.y); + return make_float2(a.x / b.x, a.y / b.y); } inline __host__ __device__ void operator/=(float2 &a, float2 b) { - a.x /= b.x; - a.y /= b.y; + a.x /= b.x; + a.y /= b.y; } inline __host__ __device__ float2 operator/(float2 a, float b) { - return make_float2(a.x / b, a.y / b); + return make_float2(a.x / b, a.y / b); } inline __host__ __device__ void operator/=(float2 &a, float b) { - a.x /= b; - a.y /= b; + a.x /= b; + a.y /= b; } inline __host__ __device__ float2 operator/(float b, float2 a) { - return make_float2(b / a.x, b / a.y); + return make_float2(b / a.x, b / a.y); } inline __host__ __device__ float3 operator/(float3 a, float3 b) { - return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); + return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); } inline __host__ __device__ void operator/=(float3 &a, float3 b) { - a.x /= b.x; - a.y /= b.y; - a.z /= b.z; + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; } inline __host__ __device__ float3 operator/(float3 a, float b) { - return make_float3(a.x / b, a.y / b, a.z / b); + return make_float3(a.x / b, a.y / b, a.z / b); } inline __host__ __device__ void operator/=(float3 &a, float b) { - a.x /= b; - a.y /= b; - a.z /= b; + a.x /= b; + a.y /= b; + a.z /= b; } inline __host__ __device__ float3 
operator/(float b, float3 a) { - return make_float3(b / a.x, b / a.y, b / a.z); + return make_float3(b / a.x, b / a.y, b / a.z); } inline __host__ __device__ float4 operator/(float4 a, float4 b) { - return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); + return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); } inline __host__ __device__ void operator/=(float4 &a, float4 b) { - a.x /= b.x; - a.y /= b.y; - a.z /= b.z; - a.w /= b.w; + a.x /= b.x; + a.y /= b.y; + a.z /= b.z; + a.w /= b.w; } inline __host__ __device__ float4 operator/(float4 a, float b) { - return make_float4(a.x / b, a.y / b, a.z / b, a.w / b); + return make_float4(a.x / b, a.y / b, a.z / b, a.w / b); } inline __host__ __device__ void operator/=(float4 &a, float b) { - a.x /= b; - a.y /= b; - a.z /= b; - a.w /= b; + a.x /= b; + a.y /= b; + a.z /= b; + a.w /= b; } inline __host__ __device__ float4 operator/(float b, float4 a) { - return make_float4(b / a.x, b / a.y, b / a.z, b / a.w); + return make_float4(b / a.x, b / a.y, b / a.z, b / a.w); } //////////////////////////////////////////////////////////////////////////////// @@ -1593,63 +1593,63 @@ inline __host__ __device__ float2 fminf(float2 a, float2 b) { - return make_float2(fminf(a.x, b.x), fminf(a.y, b.y)); + return make_float2(fminf(a.x, b.x), fminf(a.y, b.y)); } inline __host__ __device__ float3 fminf(float3 a, float3 b) { - return make_float3(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z)); + return make_float3(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z)); } inline __host__ __device__ float4 fminf(float4 a, float4 b) { - return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w)); + return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w)); } inline __host__ __device__ int2 min(int2 a, int2 b) { - return make_int2(min(a.x, b.x), min(a.y, b.y)); + return make_int2(min(a.x, b.x), min(a.y, b.y)); } inline __host__ __device__ int3 min(int3 a, int3 b) { - return 
make_int3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); + return make_int3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); } inline __host__ __device__ int4 min(int4 a, int4 b) { - return make_int4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w)); + return make_int4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w)); } inline __host__ __device__ uint2 min(uint2 a, uint2 b) { - return make_uint2(min(a.x, b.x), min(a.y, b.y)); + return make_uint2(min(a.x, b.x), min(a.y, b.y)); } inline __host__ __device__ uint3 min(uint3 a, uint3 b) { - return make_uint3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); + return make_uint3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z)); } inline __host__ __device__ uint4 min(uint4 a, uint4 b) { - return make_uint4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w)); + return make_uint4(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z), min(a.w, b.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -1660,63 +1660,63 @@ inline __host__ __device__ float2 fmaxf(float2 a, float2 b) { - return make_float2(fmaxf(a.x, b.x), fmaxf(a.y, b.y)); + return make_float2(fmaxf(a.x, b.x), fmaxf(a.y, b.y)); } inline __host__ __device__ float3 fmaxf(float3 a, float3 b) { - return make_float3(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z)); + return make_float3(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z)); } inline __host__ __device__ float4 fmaxf(float4 a, float4 b) { - return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w)); + return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w)); } inline __host__ __device__ int2 max(int2 a, int2 b) { - return make_int2(max(a.x, b.x), max(a.y, b.y)); + return make_int2(max(a.x, b.x), max(a.y, b.y)); } inline __host__ __device__ int3 max(int3 a, int3 b) { - return make_int3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); + return make_int3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); } inline 
__host__ __device__ int4 max(int4 a, int4 b) { - return make_int4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w)); + return make_int4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w)); } inline __host__ __device__ uint2 max(uint2 a, uint2 b) { - return make_uint2(max(a.x, b.x), max(a.y, b.y)); + return make_uint2(max(a.x, b.x), max(a.y, b.y)); } inline __host__ __device__ uint3 max(uint3 a, uint3 b) { - return make_uint3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); + return make_uint3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z)); } inline __host__ __device__ uint4 max(uint4 a, uint4 b) { - return make_uint4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w)); + return make_uint4(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z), max(a.w, b.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -1728,28 +1728,28 @@ inline __device__ __host__ float lerp(float a, float b, float t) { - return a + t * (b - a); + return a + t * (b - a); } inline __device__ __host__ float2 lerp(float2 a, float2 b, float t) { - return a + t * (b - a); + return a + t * (b - a); } inline __device__ __host__ float3 lerp(float3 a, float3 b, float t) { - return a + t * (b - a); + return a + t * (b - a); } inline __device__ __host__ float4 lerp(float4 a, float4 b, float t) { - return a + t * (b - a); + return a + t * (b - a); } //////////////////////////////////////////////////////////////////////////////// @@ -1761,150 +1761,150 @@ inline __device__ __host__ float clamp(float f, float a, float b) { - return fmaxf(a, fminf(f, b)); + return fmaxf(a, fminf(f, b)); } inline __device__ __host__ int clamp(int f, int a, int b) { - return max(a, min(f, b)); + return max(a, min(f, b)); } inline __device__ __host__ uint clamp(uint f, uint a, uint b) { - return max(a, min(f, b)); + return max(a, min(f, b)); } inline __device__ __host__ float2 clamp(float2 v, float a, float b) { - return make_float2(clamp(v.x, a, b), clamp(v.y, a, b)); + return 
make_float2(clamp(v.x, a, b), clamp(v.y, a, b)); } inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b) { - return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); + return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); } inline __device__ __host__ float3 clamp(float3 v, float a, float b) { - return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); + return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); } inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b) { - return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); + return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); } inline __device__ __host__ float4 clamp(float4 v, float a, float b) { - return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); + return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); } inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b) { - return make_float4( - clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); + return make_float4( + clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); } inline __device__ __host__ int2 clamp(int2 v, int a, int b) { - return make_int2(clamp(v.x, a, b), clamp(v.y, a, b)); + return make_int2(clamp(v.x, a, b), clamp(v.y, a, b)); } inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b) { - return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); + return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); } inline __device__ __host__ int3 clamp(int3 v, int a, int b) { - return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); + return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); } inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b) { - return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), 
clamp(v.z, a.z, b.z)); + return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); } inline __device__ __host__ int4 clamp(int4 v, int a, int b) { - return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); + return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); } inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b) { - return make_int4( - clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); + return make_int4( + clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); } inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b) { - return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b)); + return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b)); } inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b) { - return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); + return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y)); } inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b) { - return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); + return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b)); } inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b) { - return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); + return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z)); } inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b) { - return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); + return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b)); } inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b) { - return make_uint4( - clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); + return make_uint4( + clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), 
clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -1915,63 +1915,63 @@ inline __host__ __device__ float dot(float2 a, float2 b) { - return a.x * b.x + a.y * b.y; + return a.x * b.x + a.y * b.y; } inline __host__ __device__ float dot(float3 a, float3 b) { - return a.x * b.x + a.y * b.y + a.z * b.z; + return a.x * b.x + a.y * b.y + a.z * b.z; } inline __host__ __device__ float dot(float4 a, float4 b) { - return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; } inline __host__ __device__ int dot(int2 a, int2 b) { - return a.x * b.x + a.y * b.y; + return a.x * b.x + a.y * b.y; } inline __host__ __device__ int dot(int3 a, int3 b) { - return a.x * b.x + a.y * b.y + a.z * b.z; + return a.x * b.x + a.y * b.y + a.z * b.z; } inline __host__ __device__ int dot(int4 a, int4 b) { - return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; } inline __host__ __device__ uint dot(uint2 a, uint2 b) { - return a.x * b.x + a.y * b.y; + return a.x * b.x + a.y * b.y; } inline __host__ __device__ uint dot(uint3 a, uint3 b) { - return a.x * b.x + a.y * b.y + a.z * b.z; + return a.x * b.x + a.y * b.y + a.z * b.z; } inline __host__ __device__ uint dot(uint4 a, uint4 b) { - return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; } //////////////////////////////////////////////////////////////////////////////// @@ -1982,21 +1982,21 @@ inline __host__ __device__ float length(float2 v) { - return sqrtf(dot(v, v)); + return sqrtf(dot(v, v)); } inline __host__ __device__ float length(float3 v) { - return sqrtf(dot(v, v)); + return sqrtf(dot(v, v)); } inline __host__ __device__ float length(float4 v) { - return sqrtf(dot(v, v)); + return sqrtf(dot(v, v)); } //////////////////////////////////////////////////////////////////////////////// @@ -2007,24 +2007,24 @@ 
inline __host__ __device__ float2 normalize(float2 v) { - float invLen = rsqrtf(dot(v, v)); - return v * invLen; + float invLen = rsqrtf(dot(v, v)); + return v * invLen; } inline __host__ __device__ float3 normalize(float3 v) { - float invLen = rsqrtf(dot(v, v)); - return v * invLen; + float invLen = rsqrtf(dot(v, v)); + return v * invLen; } inline __host__ __device__ float4 normalize(float4 v) { - float invLen = rsqrtf(dot(v, v)); - return v * invLen; + float invLen = rsqrtf(dot(v, v)); + return v * invLen; } //////////////////////////////////////////////////////////////////////////////// @@ -2035,21 +2035,21 @@ inline __host__ __device__ float2 floorf(float2 v) { - return make_float2(floorf(v.x), floorf(v.y)); + return make_float2(floorf(v.x), floorf(v.y)); } inline __host__ __device__ float3 floorf(float3 v) { - return make_float3(floorf(v.x), floorf(v.y), floorf(v.z)); + return make_float3(floorf(v.x), floorf(v.y), floorf(v.z)); } inline __host__ __device__ float4 floorf(float4 v) { - return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w)); + return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -2060,28 +2060,28 @@ inline __host__ __device__ float fracf(float v) { - return v - floorf(v); + return v - floorf(v); } inline __host__ __device__ float2 fracf(float2 v) { - return make_float2(fracf(v.x), fracf(v.y)); + return make_float2(fracf(v.x), fracf(v.y)); } inline __host__ __device__ float3 fracf(float3 v) { - return make_float3(fracf(v.x), fracf(v.y), fracf(v.z)); + return make_float3(fracf(v.x), fracf(v.y), fracf(v.z)); } inline __host__ __device__ float4 fracf(float4 v) { - return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w)); + return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -2092,21 +2092,21 @@ inline __host__ __device__ 
float2 fmodf(float2 a, float2 b) { - return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y)); + return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y)); } inline __host__ __device__ float3 fmodf(float3 a, float3 b) { - return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z)); + return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z)); } inline __host__ __device__ float4 fmodf(float4 a, float4 b) { - return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w)); + return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -2117,42 +2117,42 @@ inline __host__ __device__ float2 fabs(float2 v) { - return make_float2(fabs(v.x), fabs(v.y)); + return make_float2(fabs(v.x), fabs(v.y)); } inline __host__ __device__ float3 fabs(float3 v) { - return make_float3(fabs(v.x), fabs(v.y), fabs(v.z)); + return make_float3(fabs(v.x), fabs(v.y), fabs(v.z)); } inline __host__ __device__ float4 fabs(float4 v) { - return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w)); + return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w)); } inline __host__ __device__ int2 abs(int2 v) { - return make_int2(abs(v.x), abs(v.y)); + return make_int2(abs(v.x), abs(v.y)); } inline __host__ __device__ int3 abs(int3 v) { - return make_int3(abs(v.x), abs(v.y), abs(v.z)); + return make_int3(abs(v.x), abs(v.y), abs(v.z)); } inline __host__ __device__ int4 abs(int4 v) { - return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w)); + return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w)); } //////////////////////////////////////////////////////////////////////////////// @@ -2165,7 +2165,7 @@ inline __host__ __device__ float3 reflect(float3 i, float3 n) { - return i - 2.0f * n * dot(n, i); + return i - 2.0f * n * dot(n, i); } //////////////////////////////////////////////////////////////////////////////// @@ -2176,7 +2176,7 @@ inline __host__ 
__device__ float3 cross(float3 a, float3 b) { - return make_float3(a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x); + return make_float3(a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x); } //////////////////////////////////////////////////////////////////////////////// @@ -2190,32 +2190,32 @@ inline __device__ __host__ float smoothstep(float a, float b, float x) { - float y = clamp((x - a) / (b - a), 0.0f, 1.0f); - return (y * y * (3.0f - (2.0f * y))); + float y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y * y * (3.0f - (2.0f * y))); } inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x) { - float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f); - return (y * y * (make_float2(3.0f) - (make_float2(2.0f) * y))); + float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y * y * (make_float2(3.0f) - (make_float2(2.0f) * y))); } inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x) { - float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); - return (y * y * (make_float3(3.0f) - (make_float3(2.0f) * y))); + float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y * y * (make_float3(3.0f) - (make_float3(2.0f) * y))); } inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x) { - float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f); - return (y * y * (make_float4(3.0f) - (make_float4(2.0f) * y))); + float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f); + return (y * y * (make_float4(3.0f) - (make_float4(2.0f) * y))); } #endif diff --git a/librapid/include/librapid/cuda/helper_string.h b/librapid/include/librapid/cuda/helper_string.h index a60a3c25..fa1dd0f0 100644 --- a/librapid/include/librapid/cuda/helper_string.h +++ b/librapid/include/librapid/cuda/helper_string.h @@ -35,225 +35,225 @@ #include #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) -# ifndef _CRT_SECURE_NO_DEPRECATE -# define _CRT_SECURE_NO_DEPRECATE -# endif -# ifndef STRCASECMP -# define 
STRCASECMP _stricmp -# endif -# ifndef STRNCASECMP -# define STRNCASECMP _strnicmp -# endif -# ifndef STRCPY -# define STRCPY(sFilePath, nLength, sPath) strcpy_s(sFilePath, nLength, sPath) -# endif - -# ifndef FOPEN -# define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode) -# endif -# ifndef FOPEN_FAIL -# define FOPEN_FAIL(result) (result != 0) -# endif -# ifndef SSCANF -# define SSCANF sscanf_s -# endif -# ifndef SPRINTF -# define SPRINTF sprintf_s -# endif +# ifndef _CRT_SECURE_NO_DEPRECATE +# define _CRT_SECURE_NO_DEPRECATE +# endif +# ifndef STRCASECMP +# define STRCASECMP _stricmp +# endif +# ifndef STRNCASECMP +# define STRNCASECMP _strnicmp +# endif +# ifndef STRCPY +# define STRCPY(sFilePath, nLength, sPath) strcpy_s(sFilePath, nLength, sPath) +# endif + +# ifndef FOPEN +# define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode) +# endif +# ifndef FOPEN_FAIL +# define FOPEN_FAIL(result) (result != 0) +# endif +# ifndef SSCANF +# define SSCANF sscanf_s +# endif +# ifndef SPRINTF +# define SPRINTF sprintf_s +# endif #else // Linux Includes -# include -# include - -# ifndef STRCASECMP -# define STRCASECMP strcasecmp -# endif -# ifndef STRNCASECMP -# define STRNCASECMP strncasecmp -# endif -# ifndef STRCPY -# define STRCPY(sFilePath, nLength, sPath) strcpy(sFilePath, sPath) -# endif - -# ifndef FOPEN -# define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode)) -# endif -# ifndef FOPEN_FAIL -# define FOPEN_FAIL(result) (result == NULL) -# endif -# ifndef SSCANF -# define SSCANF sscanf -# endif -# ifndef SPRINTF -# define SPRINTF sprintf -# endif +# include +# include + +# ifndef STRCASECMP +# define STRCASECMP strcasecmp +# endif +# ifndef STRNCASECMP +# define STRNCASECMP strncasecmp +# endif +# ifndef STRCPY +# define STRCPY(sFilePath, nLength, sPath) strcpy(sFilePath, sPath) +# endif + +# ifndef FOPEN +# define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode)) +# endif +# ifndef FOPEN_FAIL +# define 
FOPEN_FAIL(result) (result == NULL) +# endif +# ifndef SSCANF +# define SSCANF sscanf +# endif +# ifndef SPRINTF +# define SPRINTF sprintf +# endif #endif #ifndef EXIT_WAIVED -# define EXIT_WAIVED 2 +# define EXIT_WAIVED 2 #endif // CUDA Utility Helper Functions inline int stringRemoveDelimiter(char delimiter, const char *string) { - int string_start = 0; + int string_start = 0; - while (string[string_start] == delimiter) { string_start++; } + while (string[string_start] == delimiter) { string_start++; } - if (string_start >= static_cast(strlen(string) - 1)) { return 0; } + if (string_start >= static_cast(strlen(string) - 1)) { return 0; } - return string_start; + return string_start; } inline int getFileExtension(char *filename, char **extension) { - int string_length = static_cast(strlen(filename)); + int string_length = static_cast(strlen(filename)); - while (filename[string_length--] != '.') { - if (string_length == 0) break; - } + while (filename[string_length--] != '.') { + if (string_length == 0) break; + } - if (string_length > 0) string_length += 2; + if (string_length > 0) string_length += 2; - if (string_length == 0) - *extension = NULL; - else - *extension = &filename[string_length]; + if (string_length == 0) + *extension = NULL; + else + *extension = &filename[string_length]; - return string_length; + return string_length; } inline bool checkCmdLineFlag(const int argc, const char **argv, const char *string_ref) { - bool bFound = false; + bool bFound = false; - if (argc >= 1) { - for (int i = 1; i < argc; i++) { - int string_start = stringRemoveDelimiter('-', argv[i]); - const char *string_argv = &argv[i][string_start]; + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; - const char *equal_pos = strchr(string_argv, '='); - int argv_length = - static_cast(equal_pos == 0 ? 
strlen(string_argv) : equal_pos - string_argv); + const char *equal_pos = strchr(string_argv, '='); + int argv_length = + static_cast(equal_pos == 0 ? strlen(string_argv) : equal_pos - string_argv); - int length = static_cast(strlen(string_ref)); + int length = static_cast(strlen(string_ref)); - if (length == argv_length && !STRNCASECMP(string_argv, string_ref, length)) { - bFound = true; - continue; - } - } - } + if (length == argv_length && !STRNCASECMP(string_argv, string_ref, length)) { + bFound = true; + continue; + } + } + } - return bFound; + return bFound; } // This function wraps the CUDA Driver API into a template function template inline bool getCmdLineArgumentValue(const int argc, const char **argv, const char *string_ref, - T *value) { - bool bFound = false; - - if (argc >= 1) { - for (int i = 1; i < argc; i++) { - int string_start = stringRemoveDelimiter('-', argv[i]); - const char *string_argv = &argv[i][string_start]; - int length = static_cast(strlen(string_ref)); - - if (!STRNCASECMP(string_argv, string_ref, length)) { - if (length + 1 <= static_cast(strlen(string_argv))) { - int auto_inc = (string_argv[length] == '=') ? 1 : 0; - *value = (T)atoi(&string_argv[length + auto_inc]); - } - - bFound = true; - i = argc; - } - } - } - - return bFound; + T *value) { + bool bFound = false; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + if (length + 1 <= static_cast(strlen(string_argv))) { + int auto_inc = (string_argv[length] == '=') ? 
1 : 0; + *value = (T)atoi(&string_argv[length + auto_inc]); + } + + bFound = true; + i = argc; + } + } + } + + return bFound; } inline int getCmdLineArgumentInt(const int argc, const char **argv, const char *string_ref) { - bool bFound = false; - int value = -1; - - if (argc >= 1) { - for (int i = 1; i < argc; i++) { - int string_start = stringRemoveDelimiter('-', argv[i]); - const char *string_argv = &argv[i][string_start]; - int length = static_cast(strlen(string_ref)); - - if (!STRNCASECMP(string_argv, string_ref, length)) { - if (length + 1 <= static_cast(strlen(string_argv))) { - int auto_inc = (string_argv[length] == '=') ? 1 : 0; - value = atoi(&string_argv[length + auto_inc]); - } else { - value = 0; - } - - bFound = true; - continue; - } - } - } - - if (bFound) { - return value; - } else { - return 0; - } + bool bFound = false; + int value = -1; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + if (length + 1 <= static_cast(strlen(string_argv))) { + int auto_inc = (string_argv[length] == '=') ? 1 : 0; + value = atoi(&string_argv[length + auto_inc]); + } else { + value = 0; + } + + bFound = true; + continue; + } + } + } + + if (bFound) { + return value; + } else { + return 0; + } } inline float getCmdLineArgumentFloat(const int argc, const char **argv, const char *string_ref) { - bool bFound = false; - float value = -1; - - if (argc >= 1) { - for (int i = 1; i < argc; i++) { - int string_start = stringRemoveDelimiter('-', argv[i]); - const char *string_argv = &argv[i][string_start]; - int length = static_cast(strlen(string_ref)); - - if (!STRNCASECMP(string_argv, string_ref, length)) { - if (length + 1 <= static_cast(strlen(string_argv))) { - int auto_inc = (string_argv[length] == '=') ? 
1 : 0; - value = static_cast(atof(&string_argv[length + auto_inc])); - } else { - value = 0.f; - } - - bFound = true; - continue; - } - } - } - - if (bFound) { - return value; - } else { - return 0; - } + bool bFound = false; + float value = -1; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + if (length + 1 <= static_cast(strlen(string_argv))) { + int auto_inc = (string_argv[length] == '=') ? 1 : 0; + value = static_cast(atof(&string_argv[length + auto_inc])); + } else { + value = 0.f; + } + + bFound = true; + continue; + } + } + } + + if (bFound) { + return value; + } else { + return 0; + } } inline bool getCmdLineArgumentString(const int argc, const char **argv, const char *string_ref, - char **string_retval) { - bool bFound = false; - - if (argc >= 1) { - for (int i = 1; i < argc; i++) { - int string_start = stringRemoveDelimiter('-', argv[i]); - char *string_argv = const_cast(&argv[i][string_start]); - int length = static_cast(strlen(string_ref)); - - if (!STRNCASECMP(string_argv, string_ref, length)) { - *string_retval = &string_argv[length + 1]; - bFound = true; - continue; - } - } - } - - if (!bFound) { *string_retval = NULL; } - - return bFound; + char **string_retval) { + bool bFound = false; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + char *string_argv = const_cast(&argv[i][string_start]); + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + *string_retval = &string_argv[length + 1]; + bFound = true; + continue; + } + } + } + + if (!bFound) { *string_retval = NULL; } + + return bFound; } ////////////////////////////////////////////////////////////////////////////// @@ -265,90 +265,90 @@ inline bool 
getCmdLineArgumentString(const int argc, const char **argv, const ch //! @param executable_path optional absolute path of the executable ////////////////////////////////////////////////////////////////////////////// inline char *sdkFindFilePath(const char *filename, const char *executable_path) { - // defines a variable that is replaced with the name of - // the executable - - // Typical relative search paths to locate needed companion files (e.g. - // sample input data, or JIT source files) The origin for the relative - // search may be the .exe file, a .bat file launching an .exe, a browser - // .exe launching the .exe or .bat, etc - const char *searchPath[] = { - "./", // same dir - "./data/", // same dir - "../../../../Samples//", // up 4 in tree - "../../../Samples//", // up 3 in tree - "../../Samples//", // up 2 in tree - "../../../../Samples//data/", // up 4 in tree - "../../../Samples//data/", // up 3 in tree - "../../Samples//data/", // up 2 in tree - "../../../../Common/data/", // up 4 in tree - "../../../Common/data/", // up 3 in tree - "../../Common/data/" // up 2 in tree - }; - - // Extract the executable name - std::string executable_name; - - if (executable_path != 0) { - executable_name = std::string(executable_path); + // defines a variable that is replaced with the name of + // the executable + + // Typical relative search paths to locate needed companion files (e.g. 
+ // sample input data, or JIT source files) The origin for the relative + // search may be the .exe file, a .bat file launching an .exe, a browser + // .exe launching the .exe or .bat, etc + const char *searchPath[] = { + "./", // same dir + "./data/", // same dir + "../../../../Samples//", // up 4 in tree + "../../../Samples//", // up 3 in tree + "../../Samples//", // up 2 in tree + "../../../../Samples//data/", // up 4 in tree + "../../../Samples//data/", // up 3 in tree + "../../Samples//data/", // up 2 in tree + "../../../../Common/data/", // up 4 in tree + "../../../Common/data/", // up 3 in tree + "../../Common/data/" // up 2 in tree + }; + + // Extract the executable name + std::string executable_name; + + if (executable_path != 0) { + executable_name = std::string(executable_path); #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) - // Windows path delimiter - size_t delimiter_pos = executable_name.find_last_of('\\'); - executable_name.erase(0, delimiter_pos + 1); + // Windows path delimiter + size_t delimiter_pos = executable_name.find_last_of('\\'); + executable_name.erase(0, delimiter_pos + 1); - if (executable_name.rfind(".exe") != std::string::npos) { - // we strip .exe, only if the .exe is found - executable_name.resize(executable_name.size() - 4); - } + if (executable_name.rfind(".exe") != std::string::npos) { + // we strip .exe, only if the .exe is found + executable_name.resize(executable_name.size() - 4); + } #else - // Linux & OSX path delimiter - size_t delimiter_pos = executable_name.find_last_of('/'); - executable_name.erase(0, delimiter_pos + 1); + // Linux & OSX path delimiter + size_t delimiter_pos = executable_name.find_last_of('/'); + executable_name.erase(0, delimiter_pos + 1); #endif - } - - // Loop over all search paths and return the first hit - for (unsigned int i = 0; i < sizeof(searchPath) / sizeof(char *); ++i) { - std::string path(searchPath[i]); - size_t executable_name_pos = path.find(""); - - // If 
there is executable_name variable in the searchPath - // replace it with the value - if (executable_name_pos != std::string::npos) { - if (executable_path != 0) { - path.replace(executable_name_pos, strlen(""), executable_name); - } else { - // Skip this path entry if no executable argument is given - continue; - } - } + } + + // Loop over all search paths and return the first hit + for (unsigned int i = 0; i < sizeof(searchPath) / sizeof(char *); ++i) { + std::string path(searchPath[i]); + size_t executable_name_pos = path.find(""); + + // If there is executable_name variable in the searchPath + // replace it with the value + if (executable_name_pos != std::string::npos) { + if (executable_path != 0) { + path.replace(executable_name_pos, strlen(""), executable_name); + } else { + // Skip this path entry if no executable argument is given + continue; + } + } #ifdef _DEBUG - printf("sdkFindFilePath <%s> in %s\n", filename, path.c_str()); + printf("sdkFindFilePath <%s> in %s\n", filename, path.c_str()); #endif - // Test if the file exists - path.append(filename); - FILE *fp; - FOPEN(fp, path.c_str(), "rb"); - - if (fp != NULL) { - fclose(fp); - // File found - // returning an allocated array here for backwards compatibility - // reasons - char *file_path = reinterpret_cast(malloc(path.length() + 1)); - STRCPY(file_path, path.length() + 1, path.c_str()); - return file_path; - } - - if (fp) { fclose(fp); } - } - - // File not found - return 0; + // Test if the file exists + path.append(filename); + FILE *fp; + FOPEN(fp, path.c_str(), "rb"); + + if (fp != NULL) { + fclose(fp); + // File found + // returning an allocated array here for backwards compatibility + // reasons + char *file_path = reinterpret_cast(malloc(path.length() + 1)); + STRCPY(file_path, path.length() + 1, path.c_str()); + return file_path; + } + + if (fp) { fclose(fp); } + } + + // File not found + return 0; } #endif // COMMON_HELPER_STRING_H_ diff --git a/librapid/include/librapid/cuda/kernel_header.h 
b/librapid/include/librapid/cuda/kernel_header.h index 8c65771e..b0d060b0 100644 --- a/librapid/include/librapid/cuda/kernel_header.h +++ b/librapid/include/librapid/cuda/kernel_header.h @@ -3,19 +3,19 @@ #include namespace librapid::imp { - inline const jitify::detail::vector cudaHeaders = { // CUDA_INCLUDE_DIRS, - CUDA_INCLUDE_DIRS + std::string("/curand.h"), - CUDA_INCLUDE_DIRS + std::string("/curand_kernel.h"), - CUDA_INCLUDE_DIRS + std::string("/cublas_v2.h"), - CUDA_INCLUDE_DIRS + std::string("/cublas_api.h"), - CUDA_INCLUDE_DIRS + std::string("/cuda_fp16.h"), - CUDA_INCLUDE_DIRS + std::string("/cuda_bf16.h")}; - - inline const std::vector cudaParams = { - "--disable-warnings", "-std=c++17", std::string("-I") + CUDA_INCLUDE_DIRS}; - - inline std::string genKernelHeader() { - return fmt::format(R"V0G0N( + inline const jitify::detail::vector cudaHeaders = { // CUDA_INCLUDE_DIRS, + CUDA_INCLUDE_DIRS + std::string("/curand.h"), + CUDA_INCLUDE_DIRS + std::string("/curand_kernel.h"), + CUDA_INCLUDE_DIRS + std::string("/cublas_v2.h"), + CUDA_INCLUDE_DIRS + std::string("/cublas_api.h"), + CUDA_INCLUDE_DIRS + std::string("/cuda_fp16.h"), + CUDA_INCLUDE_DIRS + std::string("/cuda_bf16.h")}; + + inline const std::vector cudaParams = { + "--disable-warnings", "-std=c++17", std::string("-I") + CUDA_INCLUDE_DIRS}; + + inline std::string genKernelHeader() { + return fmt::format(R"V0G0N( #include <"{0}/curand_kernel.h> #include <"{0}"/curand.h> #include @@ -283,6 +283,6 @@ namespace librapid {{ #endif // LIBRAPID_CUSTOM_COMPLEX )V0G0N", - CUDA_INCLUDE_DIRS); - } + CUDA_INCLUDE_DIRS); + } } // namespace librapid::imp diff --git a/librapid/include/librapid/cuda/kernels/abs.cu b/librapid/include/librapid/cuda/kernels/abs.cu index c52b65a9..48da4e85 100644 --- a/librapid/include/librapid/cuda/kernels/abs.cu +++ b/librapid/include/librapid/cuda/kernels/abs.cu @@ -1,5 +1,5 @@ template __global__ void absArrays(size_t elements, Destination *dst, Data *src) { - const size_t 
kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = abs(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = abs(src[kernelIndex]); } } diff --git a/librapid/include/librapid/cuda/kernels/activations.cu b/librapid/include/librapid/cuda/kernels/activations.cu index 50cbdc54..520f3f9e 100644 --- a/librapid/include/librapid/cuda/kernels/activations.cu +++ b/librapid/include/librapid/cuda/kernels/activations.cu @@ -1,23 +1,23 @@ template __global__ void sigmoidActivationForward(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = 1 / (1 + exp((float)-src[kernelIndex])); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = 1 / (1 + exp((float)-src[kernelIndex])); } } template<> __global__ void sigmoidActivationForward(size_t elements, float *dst, float *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = 1 / (1 + exp(-src[kernelIndex])); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = 1 / (1 + exp(-src[kernelIndex])); } } template<> __global__ void sigmoidActivationForward(size_t elements, double *dst, double *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = 1 / (1 + exp(-src[kernelIndex])); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = 1 / (1 + exp(-src[kernelIndex])); } } template __global__ void sigmoidActivationBackward(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < 
elements) { dst[kernelIndex] = src[kernelIndex] * (1 - src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = src[kernelIndex] * (1 - src[kernelIndex]); } } diff --git a/librapid/include/librapid/cuda/kernels/arithmetic.cu b/librapid/include/librapid/cuda/kernels/arithmetic.cu index fedda165..d8975b39 100644 --- a/librapid/include/librapid/cuda/kernels/arithmetic.cu +++ b/librapid/include/librapid/cuda/kernels/arithmetic.cu @@ -3,188 +3,189 @@ template __global__ void addArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] + rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] + rhs[kernelIndex]; } } template __global__ void addArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] + rhs; - // printf("%d + %d = %d\n", lhs[kernelIndex], rhs, dst[kernelIndex]); - } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { + dst[kernelIndex] = lhs[kernelIndex] + rhs; + // printf("%d + %d = %d\n", lhs[kernelIndex], rhs, dst[kernelIndex]); + } } template __global__ void addArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs + rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs + rhs[kernelIndex]; } } template __global__ void subArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = 
blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] - rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] - rhs[kernelIndex]; } } template __global__ void subArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs - rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs - rhs[kernelIndex]; } } template __global__ void subArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] - rhs; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] - rhs; } } template __global__ void mulArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] * rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] * rhs[kernelIndex]; } } template __global__ void mulArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs * rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs * rhs[kernelIndex]; } } template __global__ void mulArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, RHS rhs) { - const size_t 
kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] * rhs; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] * rhs; } } template __global__ void divArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] / rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] / rhs[kernelIndex]; } } template __global__ void divArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs / rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs / rhs[kernelIndex]; } } template __global__ void divArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] / rhs; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] / rhs; } } template __global__ void lessThanArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] < rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] < rhs[kernelIndex]; } } template __global__ void lessThanArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, RHS *rhs) { - const 
size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs < rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs < rhs[kernelIndex]; } } template __global__ void lessThanArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] < rhs; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] < rhs; } } template __global__ void greaterThanArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] > rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] > rhs[kernelIndex]; } } template __global__ void greaterThanArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs > rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs > rhs[kernelIndex]; } } template __global__ void greaterThanArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] > rhs; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] > rhs; } } template __global__ void lessThanEqualArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) 
{ - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] <= rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] <= rhs[kernelIndex]; } } template __global__ void lessThanEqualArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] <= rhs; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] <= rhs; } } template __global__ void lessThanEqualArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs <= rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs <= rhs[kernelIndex]; } } template __global__ void greaterThanEqualArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] >= rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] >= rhs[kernelIndex]; } } template __global__ void greaterThanEqualArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, - RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs >= rhs[kernelIndex]; } + RHS *rhs) { + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs >= rhs[kernelIndex]; } } template __global__ 
void greaterThanEqualArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, - RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] >= rhs; } + RHS rhs) { + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] >= rhs; } } template __global__ void elementWiseEqualArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] == rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] == rhs[kernelIndex]; } } template __global__ void elementWiseEqualArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, - RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs == rhs[kernelIndex]; } + RHS *rhs) { + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs == rhs[kernelIndex]; } } template __global__ void elementWiseEqualArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, - RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] == rhs; } + RHS rhs) { + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] == rhs; } } template __global__ void elementWiseNotEqualArrays(size_t elements, Destination *dst, LHS *lhs, RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] != rhs[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + 
threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] != rhs[kernelIndex]; } } template __global__ void elementWiseNotEqualArraysScalarLhs(size_t elements, Destination *dst, LHS lhs, - RHS *rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs != rhs[kernelIndex]; } + RHS *rhs) { + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs != rhs[kernelIndex]; } } template __global__ void elementWiseNotEqualArraysScalarRhs(size_t elements, Destination *dst, LHS *lhs, - RHS rhs) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] != rhs; } + RHS rhs) { + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = lhs[kernelIndex] != rhs; } } diff --git a/librapid/include/librapid/cuda/kernels/expLogPow.cu b/librapid/include/librapid/cuda/kernels/expLogPow.cu index 85bde94f..0b075478 100644 --- a/librapid/include/librapid/cuda/kernels/expLogPow.cu +++ b/librapid/include/librapid/cuda/kernels/expLogPow.cu @@ -1,53 +1,53 @@ template __global__ void expArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = exp(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = exp(src[kernelIndex]); } } template __global__ void logArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = log(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = log(src[kernelIndex]); } } template 
__global__ void log2Arrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = log2(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = log2(src[kernelIndex]); } } template __global__ void log10Arrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = log10(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = log10(src[kernelIndex]); } } template __global__ void sqrtArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = sqrt(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = sqrt(src[kernelIndex]); } } template __global__ void cbrtArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = cbrt(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = cbrt(src[kernelIndex]); } } template __global__ void absArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = abs(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = abs(src[kernelIndex]); } } template __global__ void floorArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * 
blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = floor(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = floor(src[kernelIndex]); } } template __global__ void ceilArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = ceil(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = ceil(src[kernelIndex]); } } diff --git a/librapid/include/librapid/cuda/kernels/fill.cu b/librapid/include/librapid/cuda/kernels/fill.cu index 169aa812..eebc573f 100644 --- a/librapid/include/librapid/cuda/kernels/fill.cu +++ b/librapid/include/librapid/cuda/kernels/fill.cu @@ -8,16 +8,16 @@ Implements Mersenne twister generator. M. Matsumoto, T. Nishimura, Mersenne twister: a 623-dimensionally equidistributed uniform pseudo-random number generator, ACM Transactions on Modeling and Computer Simulation (TOMACS) 8 (1) (1998) 3–30. - */ + */ #define RNG32 -#define MT19937_FLOAT_MULTI 2.3283064365386962890625e-10f +#define MT19937_FLOAT_MULTI 2.3283064365386962890625e-10f #define MT19937_DOUBLE2_MULTI 2.3283064365386962890625e-10 #define MT19937_DOUBLE_MULTI 5.4210108624275221700372640e-20 -#define MT19937_N 624 -#define MT19937_M 397 +#define MT19937_N 624 +#define MT19937_M 397 #define MT19937_MATRIX_A 0x9908b0df /* constant vector a */ #define MT19937_UPPER_MASK 0x80000000 /* most significant w-r bits */ #define MT19937_LOWER_MASK 0x7fffffff /* least significant r bits */ @@ -26,8 +26,8 @@ pseudo-random number generator, ACM Transactions on Modeling and Computer Simula State of MT19937 RNG. 
*/ typedef struct { - uint32_t mt[MT19937_N]; /* the array for the state vector */ - int mti; + uint32_t mt[MT19937_N]; /* the array for the state vector */ + int mti; } mt19937_state; /** @@ -37,33 +37,33 @@ Generates a random 32-bit unsigned integer using MT19937 RNG. */ #define mt19937_uint(state) _mt19937_uint(&state) uint32_t _mt19937_uint(mt19937_state *state) { - uint32_t y; - uint32_t mag01[2] = {0x0, MT19937_MATRIX_A}; - /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ - - if (state->mti < MT19937_N - MT19937_M) { - y = (state->mt[state->mti] & MT19937_UPPER_MASK) | - (state->mt[state->mti + 1] & MT19937_LOWER_MASK); - state->mt[state->mti] = state->mt[state->mti + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; - } else if (state->mti < MT19937_N - 1) { - y = (state->mt[state->mti] & MT19937_UPPER_MASK) | - (state->mt[state->mti + 1] & MT19937_LOWER_MASK); - state->mt[state->mti] = - state->mt[state->mti + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; - } else { - y = (state->mt[MT19937_N - 1] & MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); - state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; - state->mti = 0; - } - y = state->mt[state->mti++]; - - /* Tempering */ - y ^= (y >> 11); - y ^= (y << 7) & 0x9d2c5680; - y ^= (y << 15) & 0xefc60000; - y ^= (y >> 18); - - return y; + uint32_t y; + uint32_t mag01[2] = {0x0, MT19937_MATRIX_A}; + /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ + + if (state->mti < MT19937_N - MT19937_M) { + y = (state->mt[state->mti] & MT19937_UPPER_MASK) | + (state->mt[state->mti + 1] & MT19937_LOWER_MASK); + state->mt[state->mti] = state->mt[state->mti + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; + } else if (state->mti < MT19937_N - 1) { + y = (state->mt[state->mti] & MT19937_UPPER_MASK) | + (state->mt[state->mti + 1] & MT19937_LOWER_MASK); + state->mt[state->mti] = + state->mt[state->mti + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; + } else { + y = (state->mt[MT19937_N - 1] & 
MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); + state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; + state->mti = 0; + } + y = state->mt[state->mti++]; + + /* Tempering */ + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680; + y ^= (y << 15) & 0xefc60000; + y ^= (y >> 18); + + return y; } /** Generates a random 32-bit unsigned integer using MT19937 RNG. @@ -74,36 +74,36 @@ This is alternative implementation of MT19937 RNG, that generates 32 values in s */ #define mt19937_loop_uint(state) _mt19937_loop_uint(&state) uint32_t _mt19937_loop_uint(mt19937_state *state) { - uint32_t y; - uint32_t mag01[2] = {0x0, MT19937_MATRIX_A}; - /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ - - if (state->mti >= MT19937_N) { - int kk; - - for (kk = 0; kk < MT19937_N - MT19937_M; kk++) { - y = (state->mt[kk] & MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); - state->mt[kk] = state->mt[kk + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; - } - for (; kk < MT19937_N - 1; kk++) { - y = (state->mt[kk] & MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); - state->mt[kk] = state->mt[kk + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; - } - y = (state->mt[MT19937_N - 1] & MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); - state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; - - state->mti = 0; - } - - y = state->mt[state->mti++]; - - /* Tempering */ - y ^= (y >> 11); - y ^= (y << 7) & 0x9d2c5680; - y ^= (y << 15) & 0xefc60000; - y ^= (y >> 18); - - return y; + uint32_t y; + uint32_t mag01[2] = {0x0, MT19937_MATRIX_A}; + /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ + + if (state->mti >= MT19937_N) { + int kk; + + for (kk = 0; kk < MT19937_N - MT19937_M; kk++) { + y = (state->mt[kk] & MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); + state->mt[kk] = state->mt[kk + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; + } + for (; kk < MT19937_N - 1; kk++) { + y = (state->mt[kk] & 
MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); + state->mt[kk] = state->mt[kk + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; + } + y = (state->mt[MT19937_N - 1] & MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); + state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; + + state->mti = 0; + } + + y = state->mt[state->mti++]; + + /* Tempering */ + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680; + y ^= (y << 15) & 0xefc60000; + y ^= (y >> 18); + + return y; } /** @@ -114,17 +114,17 @@ Seeds MT19937 RNG. (thread). */ void mt19937_seed(mt19937_state *state, uint32_t s) { - state->mt[0] = s; - uint32_t mti; - for (mti = 1; mti < MT19937_N; mti++) { - state->mt[mti] = 1812433253 * (state->mt[mti - 1] ^ (state->mt[mti - 1] >> 30)) + mti; - - /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ - /* In the previous versions, MSBs of the seed affect */ - /* only MSBs of the array mt19937[]. */ - /* 2002/01/09 modified by Makoto Matsumoto */ - } - state->mti = mti; + state->mt[0] = s; + uint32_t mti; + for (mti = 1; mti < MT19937_N; mti++) { + state->mt[mti] = 1812433253 * (state->mt[mti - 1] ^ (state->mt[mti - 1] >> 30)) + mti; + + /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ + /* In the previous versions, MSBs of the seed affect */ + /* only MSBs of the array mt19937[]. */ + /* 2002/01/09 modified by Makoto Matsumoto */ + } + state->mti = mti; } /** @@ -150,48 +150,46 @@ Generates a random double using MT19937 RNG. template __global__ void fillArray(size_t elements, Destination *dst, Source value) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = value; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = value; } } void print_binary_16bit(int number) { int i; - for (i = 15; i >= 0; i--) { - printf((number & (1 << i)) ? 
"1" : "0"); - } - printf("\n"); + for (i = 15; i >= 0; i--) { printf((number & (1 << i)) ? "1" : "0"); } + printf("\n"); } template __global__ void fillRandom(T *data, int64_t elements, Lower lower, Upper upper, int64_t *seeds, - int64_t numSeeds) { - int64_t gid = blockDim.x * blockIdx.x + threadIdx.x; - int64_t seedIndex = gid % numSeeds; - mt19937_state state; - mt19937_seed(&state, seeds[seedIndex]); - - for (int64_t i = gid; i < elements; i += blockDim.x * gridDim.x) { - data[i] = (T)(mt19937_double(state) * (upper - lower) + lower); - } - - // Change the seed for the next thread - seeds[seedIndex] = mt19937_ulong(state); + int64_t numSeeds) { + int64_t gid = blockDim.x * blockIdx.x + threadIdx.x; + int64_t seedIndex = gid % numSeeds; + mt19937_state state; + mt19937_seed(&state, seeds[seedIndex]); + + for (int64_t i = gid; i < elements; i += blockDim.x * gridDim.x) { + data[i] = (T)(mt19937_double(state) * (upper - lower) + lower); + } + + // Change the seed for the next thread + seeds[seedIndex] = mt19937_ulong(state); } template __global__ void fillRandomHalf(T *data, int64_t elements, Lower lower, Upper upper, int64_t *seeds, - int64_t numSeeds) { - int64_t gid = blockDim.x * blockIdx.x + threadIdx.x; - int64_t seedIndex = gid % numSeeds; - mt19937_state state; - mt19937_seed(&state, seeds[seedIndex]); - - for (int64_t i = gid; i < elements; i += blockDim.x * gridDim.x) { - float lowerF = (float)lower; - float upperF = (float)upper; - data[i] = (T)(mt19937_float(state) * (upperF - lowerF) + lowerF); - } - - // Change the seed for the next thread - seeds[seedIndex] = mt19937_ulong(state); + int64_t numSeeds) { + int64_t gid = blockDim.x * blockIdx.x + threadIdx.x; + int64_t seedIndex = gid % numSeeds; + mt19937_state state; + mt19937_seed(&state, seeds[seedIndex]); + + for (int64_t i = gid; i < elements; i += blockDim.x * gridDim.x) { + float lowerF = (float)lower; + float upperF = (float)upper; + data[i] = (T)(mt19937_float(state) * (upperF - lowerF) + 
lowerF); + } + + // Change the seed for the next thread + seeds[seedIndex] = mt19937_ulong(state); } diff --git a/librapid/include/librapid/cuda/kernels/floorCeilRound.cu b/librapid/include/librapid/cuda/kernels/floorCeilRound.cu index fdeddcd0..9eca62db 100644 --- a/librapid/include/librapid/cuda/kernels/floorCeilRound.cu +++ b/librapid/include/librapid/cuda/kernels/floorCeilRound.cu @@ -1,12 +1,12 @@ template __global__ void floorArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = floor(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = floor(src[kernelIndex]); } } template __global__ void ceilArrays(size_t elements, Destination *dst, Data *src) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = ceil(src[kernelIndex]); } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = ceil(src[kernelIndex]); } } diff --git a/librapid/include/librapid/cuda/kernels/kernelHelper.cuh b/librapid/include/librapid/cuda/kernels/kernelHelper.cuh index 383031cc..3b294631 100644 --- a/librapid/include/librapid/cuda/kernels/kernelHelper.cuh +++ b/librapid/include/librapid/cuda/kernels/kernelHelper.cuh @@ -1,7 +1,7 @@ #ifndef LIBRAPID_CUDA_KERNEL_HELPER #define LIBRAPID_CUDA_KERNEL_HELPER -#define LIBRAPID_INLINE inline +#define LIBRAPID_INLINE inline #define LIBRAPID_ALWAYS_INLINE inline #define LIBRAPID_NODISCARD @@ -12,7 +12,7 @@ #include namespace librapid { - using half = __half; + using half = __half; } template diff --git a/librapid/include/librapid/cuda/kernels/negate.cu b/librapid/include/librapid/cuda/kernels/negate.cu index ef033ea7..0423e9f2 100644 --- a/librapid/include/librapid/cuda/kernels/negate.cu +++ 
b/librapid/include/librapid/cuda/kernels/negate.cu @@ -3,66 +3,66 @@ template __global__ void negateArrays(size_t elements, Destination *dst, DATA *data) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { dst[kernelIndex] = -data[kernelIndex]; } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { dst[kernelIndex] = -data[kernelIndex]; } } template<> __global__ void negateArrays(size_t elements, float2 *dst, float2 *data) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { - dst[kernelIndex].x = -data[kernelIndex].x; - dst[kernelIndex].y = -data[kernelIndex].y; - } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { + dst[kernelIndex].x = -data[kernelIndex].x; + dst[kernelIndex].y = -data[kernelIndex].y; + } } template<> __global__ void negateArrays(size_t elements, float3 *dst, float3 *data) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { - dst[kernelIndex].x = -data[kernelIndex].x; - dst[kernelIndex].y = -data[kernelIndex].y; - dst[kernelIndex].z = -data[kernelIndex].z; - } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { + dst[kernelIndex].x = -data[kernelIndex].x; + dst[kernelIndex].y = -data[kernelIndex].y; + dst[kernelIndex].z = -data[kernelIndex].z; + } } template<> __global__ void negateArrays(size_t elements, float4 *dst, float4 *data) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { - dst[kernelIndex].x = -data[kernelIndex].x; - dst[kernelIndex].y = -data[kernelIndex].y; - dst[kernelIndex].z = -data[kernelIndex].z; - dst[kernelIndex].w = -data[kernelIndex].w; - } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { + dst[kernelIndex].x = -data[kernelIndex].x; + 
dst[kernelIndex].y = -data[kernelIndex].y; + dst[kernelIndex].z = -data[kernelIndex].z; + dst[kernelIndex].w = -data[kernelIndex].w; + } } template<> __global__ void negateArrays(size_t elements, double2 *dst, double2 *data) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { - dst[kernelIndex].x = -data[kernelIndex].x; - dst[kernelIndex].y = -data[kernelIndex].y; - } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { + dst[kernelIndex].x = -data[kernelIndex].x; + dst[kernelIndex].y = -data[kernelIndex].y; + } } template<> __global__ void negateArrays(size_t elements, double3 *dst, double3 *data) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { - dst[kernelIndex].x = -data[kernelIndex].x; - dst[kernelIndex].y = -data[kernelIndex].y; - dst[kernelIndex].z = -data[kernelIndex].z; - } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { + dst[kernelIndex].x = -data[kernelIndex].x; + dst[kernelIndex].y = -data[kernelIndex].y; + dst[kernelIndex].z = -data[kernelIndex].z; + } } template<> __global__ void negateArrays(size_t elements, double4 *dst, double4 *data) { - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; - if (kernelIndex < elements) { - dst[kernelIndex].x = -data[kernelIndex].x; - dst[kernelIndex].y = -data[kernelIndex].y; - dst[kernelIndex].z = -data[kernelIndex].z; - dst[kernelIndex].w = -data[kernelIndex].w; - } + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; + if (kernelIndex < elements) { + dst[kernelIndex].x = -data[kernelIndex].x; + dst[kernelIndex].y = -data[kernelIndex].y; + dst[kernelIndex].z = -data[kernelIndex].z; + dst[kernelIndex].w = -data[kernelIndex].w; + } } diff --git a/librapid/include/librapid/cuda/kernels/trigonometry.cu b/librapid/include/librapid/cuda/kernels/trigonometry.cu index 58d67efe..16f828e3 100644 
--- a/librapid/include/librapid/cuda/kernels/trigonometry.cu +++ b/librapid/include/librapid/cuda/kernels/trigonometry.cu @@ -1,9 +1,9 @@ #define TRIG_IMPL(NAME) \ - template \ - __global__ void NAME##Arrays(size_t elements, Destination *dst, Data *data) { \ - const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; \ - if (kernelIndex < elements) { dst[kernelIndex] = NAME(data[kernelIndex]); } \ - } + template \ + __global__ void NAME##Arrays(size_t elements, Destination *dst, Data *data) { \ + const size_t kernelIndex = blockDim.x * blockIdx.x + threadIdx.x; \ + if (kernelIndex < elements) { dst[kernelIndex] = NAME(data[kernelIndex]); } \ + } TRIG_IMPL(sin) TRIG_IMPL(cos) diff --git a/librapid/include/librapid/cuda/kernels/vectorOps.cuh b/librapid/include/librapid/cuda/kernels/vectorOps.cuh index d1bb7985..1c21e503 100644 --- a/librapid/include/librapid/cuda/kernels/vectorOps.cuh +++ b/librapid/include/librapid/cuda/kernels/vectorOps.cuh @@ -28,359 +28,359 @@ #define IOF(R, O) inline R operator O -IOF(F2,+)(CO F2 &a,CO F2 &b){MF2(a.x+b.x,a.y+b.y);} -IOF(F2,+)(CO F2 &a,CO FL &b){MF2(a.x+b,a.y+b);} -IOF(F2,+)(CO FL &a,CO F2 &b){MF2(a+b.x,a+b.y);} -IOF(F3,+)(CO F3 &a,CO F3 &b){MF3(a.x+b.x,a.y+b.y,a.z+b.z);} -IOF(F3,+)(CO F3 &a,CO FL &b){MF3(a.x+b,a.y+b,a.z+b);} -IOF(F3,+)(CO FL &a,CO F3 &b){MF3(a+b.x,a+b.y,a+b.z);} -IOF(F4,+)(CO F4 &a,CO F4 &b){MF4(a.x+b.x,a.y+b.y,a.z+b.z,a.w+b.w);} -IOF(F4,+)(CO F4 &a,CO FL &b){MF4(a.x+b,a.y+b,a.z+b,a.w+b);} -IOF(F4,+)(CO FL &a,CO F4 &b){MF4(a+b.x,a+b.y,a+b.z,a+b.w);} -IOF(F2,-)(CO F2 &a,CO F2 &b){MF2(a.x-b.x,a.y-b.y);} -IOF(F2,-)(CO F2 &a,CO FL &b){MF2(a.x-b,a.y-b);} -IOF(F2,-)(CO FL &a,CO F2 &b){MF2(a-b.x,a-b.y);} -IOF(F3,-)(CO F3 &a,CO F3 &b){MF3(a.x-b.x,a.y-b.y,a.z-b.z);} -IOF(F3,-)(CO F3 &a,CO FL &b){MF3(a.x-b,a.y-b,a.z-b);} -IOF(F3,-)(CO FL &a,CO F3 &b){MF3(a-b.x,a-b.y,a-b.z);} -IOF(F4,-)(CO F4 &a,CO F4 &b){MF4(a.x-b.x,a.y-b.y,a.z-b.z,a.w-b.w);} -IOF(F4,-)(CO F4 &a,CO FL &b){MF4(a.x-b,a.y-b,a.z-b,a.w-b);} 
-IOF(F4,-)(CO FL &a,CO F4 &b){MF4(a-b.x,a-b.y,a-b.z,a-b.w);} -IOF(F2,*)(CO F2 &a,CO F2 &b){MF2(a.x*b.x,a.y*b.y);} -IOF(F2,*)(CO F2 &a,CO FL &b){MF2(a.x*b,a.y*b);} -IOF(F2,*)(CO FL &a,CO F2 &b){MF2(a*b.x,a*b.y);} -IOF(F3,*)(CO F3 &a,CO F3 &b){MF3(a.x*b.x,a.y*b.y,a.z*b.z);} -IOF(F3,*)(CO F3 &a,CO FL &b){MF3(a.x*b,a.y*b,a.z*b);} -IOF(F3,*)(CO FL &a,CO F3 &b){MF3(a*b.x,a*b.y,a*b.z);} -IOF(F4,*)(CO F4 &a,CO F4 &b){MF4(a.x*b.x,a.y*b.y,a.z*b.z,a.w*b.w);} -IOF(F4,*)(CO F4 &a,CO FL &b){MF4(a.x*b,a.y*b,a.z*b,a.w*b);} -IOF(F4,*)(CO FL &a,CO F4 &b){MF4(a*b.x,a*b.y,a*b.z,a*b.w);} -IOF(F2,/)(CO F2 &a,CO F2 &b){MF2(a.x/b.x,a.y/b.y);} -IOF(F2,/)(CO F2 &a,CO FL &b){MF2(a.x/b,a.y/b);} -IOF(F2,/)(CO FL &a,CO F2 &b){MF2(a/b.x,a/b.y);} -IOF(F3,/)(CO F3 &a,CO F3 &b){MF3(a.x/b.x,a.y/b.y,a.z/b.z);} -IOF(F3,/)(CO F3 &a,CO FL &b){MF3(a.x/b,a.y/b,a.z/b);} -IOF(F3,/)(CO FL &a,CO F3 &b){MF3(a/b.x,a/b.y,a/b.z);} -IOF(F4,/)(CO F4 &a,CO F4 &b){MF4(a.x/b.x,a.y/b.y,a.z/b.z,a.w/b.w);} -IOF(F4,/)(CO F4 &a,CO FL &b){MF4(a.x/b,a.y/b,a.z/b,a.w/b);} -IOF(F4,/)(CO FL &a,CO F4 &b){MF4(a/b.x,a/b.y,a/b.z,a/b.w);} -IOF(F2,+=)(F2 &a,CO F2 &b){MF2N(a.x+b.x,a.y+b.y);} -IOF(F2,+=)(F2 &a,CO FL &b){MF2N(a.x + b,a.y + b);} -IOF(F3,+=)(F3 &a,CO F3 &b){MF3N(a.x+b.x,a.y+b.y,a.z+b.z);} -IOF(F3,+=)(F3 &a,CO FL &b){MF3N(a.x + b,a.y + b,a.z + b);} -IOF(F4,+=)(F4 &a,CO F4 &b){MF4N(a.x+b.x,a.y+b.y,a.z+b.z,a.w+b.w);} -IOF(F4,+=)(F4 &a,CO FL &b){MF4N(a.x + b,a.y + b,a.z + b,a.w + b);} -IOF(F2,-=)(F2 &a,CO F2 &b){MF2N(a.x-b.x,a.y-b.y);} -IOF(F2,-=)(F2 &a,CO FL &b){MF2N(a.x - b,a.y - b);} -IOF(F3,-=)(F3 &a,CO F3 &b){MF3N(a.x-b.x,a.y-b.y,a.z-b.z);} -IOF(F3,-=)(F3 &a,CO FL &b){MF3N(a.x - b,a.y - b,a.z - b);} -IOF(F4,-=)(F4 &a,CO F4 &b){MF4N(a.x-b.x,a.y-b.y,a.z-b.z,a.w-b.w);} -IOF(F4,-=)(F4 &a,CO FL &b){MF4N(a.x - b,a.y - b,a.z - b,a.w - b);} -IOF(F2,*=)(F2 &a,CO F2 &b){MF2N(a.x*b.x,a.y*b.y);} -IOF(F2,*=)(F2 &a,CO FL &b){MF2N(a.x * b,a.y * b);} -IOF(F3,*=)(F3 &a,CO F3 &b){MF3N(a.x*b.x,a.y*b.y,a.z*b.z);} -IOF(F3,*=)(F3 &a,CO FL 
&b){MF3N(a.x * b,a.y * b,a.z * b);} -IOF(F4,*=)(F4 &a,CO F4 &b){MF4N(a.x*b.x,a.y*b.y,a.z*b.z,a.w*b.w);} -IOF(F4,*=)(F4 &a,CO FL &b){MF4N(a.x * b,a.y * b,a.z * b,a.w * b);} -IOF(F2,/=)(F2 &a,CO F2 &b){MF2N(a.x/b.x,a.y/b.y);} -IOF(F2,/=)(F2 &a,CO FL &b){MF2N(a.x / b,a.y / b);} -IOF(F3,/=)(F3 &a,CO F3 &b){MF3N(a.x/b.x,a.y/b.y,a.z/b.z);} -IOF(F3,/=)(F3 &a,CO FL &b){MF3N(a.x / b,a.y / b,a.z / b);} -IOF(F4,/=)(F4 &a,CO F4 &b){MF4N(a.x/b.x,a.y/b.y,a.z/b.z,a.w/b.w);} -IOF(F4,/=)(F4 &a,CO FL &b){MF4N(a.x / b,a.y / b,a.z / b,a.w / b);} -IOF(F2,>)(CO F2 &a,CO F2 &b){MF2(a.x>b.x,a.y>b.y);} -IOF(F2,>)(CO F2 &a,CO FL &b){MF2(a.x>b,a.y>b);} -IOF(F2,>)(CO FL &a,CO F2 &b){MF2(a>b.x,a>b.y);} -IOF(F3,>)(CO F3 &a,CO F3 &b){MF3(a.x>b.x,a.y>b.y,a.z>b.z);} -IOF(F3,>)(CO F3 &a,CO FL &b){MF3(a.x>b,a.y>b,a.z>b);} -IOF(F3,>)(CO FL &a,CO F3 &b){MF3(a>b.x,a>b.y,a>b.z);} -IOF(F4,>)(CO F4 &a,CO F4 &b){MF4(a.x>b.x,a.y>b.y,a.z>b.z,a.w>b.w);} -IOF(F4,>)(CO F4 &a,CO FL &b){MF4(a.x>b,a.y>b,a.z>b,a.w>b);} -IOF(F4,>)(CO FL &a,CO F4 &b){MF4(a>b.x,a>b.y,a>b.z,a>b.w);} -IOF(F2,<)(CO F2 &a,CO F2 &b){MF2(a.x=)(CO F2 &a,CO F2 &b){MF2(a.x>=b.x,a.y>=b.y);} -IOF(F2,>=)(CO F2 &a,CO FL &b){MF2(a.x>=b,a.y>=b);} -IOF(F2,>=)(CO FL &a,CO F2 &b){MF2(a>=b.x,a>=b.y);} -IOF(F3,>=)(CO F3 &a,CO F3 &b){MF3(a.x>=b.x,a.y>=b.y,a.z>=b.z);} -IOF(F3,>=)(CO F3 &a,CO FL &b){MF3(a.x>=b,a.y>=b,a.z>=b);} -IOF(F3,>=)(CO FL &a,CO F3 &b){MF3(a>=b.x,a>=b.y,a>=b.z);} -IOF(F4,>=)(CO F4 &a,CO F4 &b){MF4(a.x>=b.x,a.y>=b.y,a.z>=b.z,a.w>=b.w);} -IOF(F4,>=)(CO F4 &a,CO FL &b){MF4(a.x>=b,a.y>=b,a.z>=b,a.w>=b);} -IOF(F4,>=)(CO FL &a,CO F4 &b){MF4(a>=b.x,a>=b.y,a>=b.z,a>=b.w);} -IOF(F2,<=)(CO F2 &a,CO F2 &b){MF2(a.x<=b.x,a.y<=b.y);} -IOF(F2,<=)(CO F2 &a,CO FL &b){MF2(a.x<=b,a.y<=b);} -IOF(F2,<=)(CO FL &a,CO F2 &b){MF2(a<=b.x,a<=b.y);} -IOF(F3,<=)(CO F3 &a,CO F3 &b){MF3(a.x<=b.x,a.y<=b.y,a.z<=b.z);} -IOF(F3,<=)(CO F3 &a,CO FL &b){MF3(a.x<=b,a.y<=b,a.z<=b);} -IOF(F3,<=)(CO FL &a,CO F3 &b){MF3(a<=b.x,a<=b.y,a<=b.z);} -IOF(F4,<=)(CO F4 &a,CO F4 
&b){MF4(a.x<=b.x,a.y<=b.y,a.z<=b.z,a.w<=b.w);} -IOF(F4,<=)(CO F4 &a,CO FL &b){MF4(a.x<=b,a.y<=b,a.z<=b,a.w<=b);} -IOF(F4,<=)(CO FL &a,CO F4 &b){MF4(a<=b.x,a<=b.y,a<=b.z,a<=b.w);} -IOF(F2,==)(CO F2 &a,CO F2 &b){MF2(a.x==b.x,a.y==b.y);} -IOF(F2,==)(CO F2 &a,CO FL &b){MF2(a.x==b,a.y==b);} -IOF(F2,==)(CO FL &a,CO F2 &b){MF2(a==b.x,a==b.y);} -IOF(F3,==)(CO F3 &a,CO F3 &b){MF3(a.x==b.x,a.y==b.y,a.z==b.z);} -IOF(F3,==)(CO F3 &a,CO FL &b){MF3(a.x==b,a.y==b,a.z==b);} -IOF(F3,==)(CO FL &a,CO F3 &b){MF3(a==b.x,a==b.y,a==b.z);} -IOF(F4,==)(CO F4 &a,CO F4 &b){MF4(a.x==b.x,a.y==b.y,a.z==b.z,a.w==b.w);} -IOF(F4,==)(CO F4 &a,CO FL &b){MF4(a.x==b,a.y==b,a.z==b,a.w==b);} -IOF(F4,==)(CO FL &a,CO F4 &b){MF4(a==b.x,a==b.y,a==b.z,a==b.w);} -IOF(F2,!=)(CO F2 &a,CO F2 &b){MF2(a.x!=b.x,a.y!=b.y);} -IOF(F2,!=)(CO F2 &a,CO FL &b){MF2(a.x!=b,a.y!=b);} -IOF(F2,!=)(CO FL &a,CO F2 &b){MF2(a!=b.x,a!=b.y);} -IOF(F3,!=)(CO F3 &a,CO F3 &b){MF3(a.x!=b.x,a.y!=b.y,a.z!=b.z);} -IOF(F3,!=)(CO F3 &a,CO FL &b){MF3(a.x!=b,a.y!=b,a.z!=b);} -IOF(F3,!=)(CO FL &a,CO F3 &b){MF3(a!=b.x,a!=b.y,a!=b.z);} -IOF(F4,!=)(CO F4 &a,CO F4 &b){MF4(a.x!=b.x,a.y!=b.y,a.z!=b.z,a.w!=b.w);} -IOF(F4,!=)(CO F4 &a,CO FL &b){MF4(a.x!=b,a.y!=b,a.z!=b,a.w!=b);} -IOF(F4,!=)(CO FL &a,CO F4 &b){MF4(a!=b.x,a!=b.y,a!=b.z,a!=b.w);} -inline F2 sin(CO F2 &a){MF2(sin(a.x),sin(a.y));} -inline F3 sin(CO F3 &a){MF3(sin(a.x),sin(a.y),sin(a.z));} -inline F4 sin(CO F4 &a){MF4(sin(a.x),sin(a.y),sin(a.z),sin(a.w));} -inline F2 cos(CO F2 &a){MF2(cos(a.x),cos(a.y));} -inline F3 cos(CO F3 &a){MF3(cos(a.x),cos(a.y),cos(a.z));} -inline F4 cos(CO F4 &a){MF4(cos(a.x),cos(a.y),cos(a.z),cos(a.w));} -inline F2 tan(CO F2 &a){MF2(tan(a.x),tan(a.y));} -inline F3 tan(CO F3 &a){MF3(tan(a.x),tan(a.y),tan(a.z));} -inline F4 tan(CO F4 &a){MF4(tan(a.x),tan(a.y),tan(a.z),tan(a.w));} -inline F2 asin(CO F2 &a){MF2(asin(a.x),asin(a.y));} -inline F3 asin(CO F3 &a){MF3(asin(a.x),asin(a.y),asin(a.z));} -inline F4 asin(CO F4 &a){MF4(asin(a.x),asin(a.y),asin(a.z),asin(a.w));} 
-inline F2 acos(CO F2 &a){MF2(acos(a.x),acos(a.y));} -inline F3 acos(CO F3 &a){MF3(acos(a.x),acos(a.y),acos(a.z));} -inline F4 acos(CO F4 &a){MF4(acos(a.x),acos(a.y),acos(a.z),acos(a.w));} -inline F2 atan(CO F2 &a){MF2(atan(a.x),atan(a.y));} -inline F3 atan(CO F3 &a){MF3(atan(a.x),atan(a.y),atan(a.z));} -inline F4 atan(CO F4 &a){MF4(atan(a.x),atan(a.y),atan(a.z),atan(a.w));} -inline F2 sinh(CO F2 &a){MF2(sinh(a.x),sinh(a.y));} -inline F3 sinh(CO F3 &a){MF3(sinh(a.x),sinh(a.y),sinh(a.z));} -inline F4 sinh(CO F4 &a){MF4(sinh(a.x),sinh(a.y),sinh(a.z),sinh(a.w));} -inline F2 cosh(CO F2 &a){MF2(cosh(a.x),cosh(a.y));} -inline F3 cosh(CO F3 &a){MF3(cosh(a.x),cosh(a.y),cosh(a.z));} -inline F4 cosh(CO F4 &a){MF4(cosh(a.x),cosh(a.y),cosh(a.z),cosh(a.w));} -inline F2 tanh(CO F2 &a){MF2(tanh(a.x),tanh(a.y));} -inline F3 tanh(CO F3 &a){MF3(tanh(a.x),tanh(a.y),tanh(a.z));} -inline F4 tanh(CO F4 &a){MF4(tanh(a.x),tanh(a.y),tanh(a.z),tanh(a.w));} -inline F2 asinh(CO F2 &a){MF2(asinh(a.x),asinh(a.y));} -inline F3 asinh(CO F3 &a){MF3(asinh(a.x),asinh(a.y),asinh(a.z));} -inline F4 asinh(CO F4 &a){MF4(asinh(a.x),asinh(a.y),asinh(a.z),asinh(a.w));} -inline F2 acosh(CO F2 &a){MF2(acosh(a.x),acosh(a.y));} -inline F3 acosh(CO F3 &a){MF3(acosh(a.x),acosh(a.y),acosh(a.z));} -inline F4 acosh(CO F4 &a){MF4(acosh(a.x),acosh(a.y),acosh(a.z),acosh(a.w));} -inline F2 atanh(CO F2 &a){MF2(atanh(a.x),atanh(a.y));} -inline F3 atanh(CO F3 &a){MF3(atanh(a.x),atanh(a.y),atanh(a.z));} -inline F4 atanh(CO F4 &a){MF4(atanh(a.x),atanh(a.y),atanh(a.z),atanh(a.w));} -inline F2 exp(CO F2 &a){MF2(exp(a.x),exp(a.y));} -inline F3 exp(CO F3 &a){MF3(exp(a.x),exp(a.y),exp(a.z));} -inline F4 exp(CO F4 &a){MF4(exp(a.x),exp(a.y),exp(a.z),exp(a.w));} -inline F2 log(CO F2 &a){MF2(log(a.x),log(a.y));} -inline F3 log(CO F3 &a){MF3(log(a.x),log(a.y),log(a.z));} -inline F4 log(CO F4 &a){MF4(log(a.x),log(a.y),log(a.z),log(a.w));} -inline F2 log2(CO F2 &a){MF2(log2(a.x),log2(a.y));} -inline F3 log2(CO F3 
&a){MF3(log2(a.x),log2(a.y),log2(a.z));} -inline F4 log2(CO F4 &a){MF4(log2(a.x),log2(a.y),log2(a.z),log2(a.w));} -inline F2 log10(CO F2 &a){MF2(log10(a.x),log10(a.y));} -inline F3 log10(CO F3 &a){MF3(log10(a.x),log10(a.y),log10(a.z));} -inline F4 log10(CO F4 &a){MF4(log10(a.x),log10(a.y),log10(a.z),log10(a.w));} -inline F2 sqrt(CO F2 &a){MF2(sqrt(a.x),sqrt(a.y));} -inline F3 sqrt(CO F3 &a){MF3(sqrt(a.x),sqrt(a.y),sqrt(a.z));} -inline F4 sqrt(CO F4 &a){MF4(sqrt(a.x),sqrt(a.y),sqrt(a.z),sqrt(a.w));} -inline F2 cbrt(CO F2 &a){MF2(cbrt(a.x),cbrt(a.y));} -inline F3 cbrt(CO F3 &a){MF3(cbrt(a.x),cbrt(a.y),cbrt(a.z));} -inline F4 cbrt(CO F4 &a){MF4(cbrt(a.x),cbrt(a.y),cbrt(a.z),cbrt(a.w));} -inline F2 abs(CO F2 &a){MF2(abs(a.x),abs(a.y));} -inline F3 abs(CO F3 &a){MF3(abs(a.x),abs(a.y),abs(a.z));} -inline F4 abs(CO F4 &a){MF4(abs(a.x),abs(a.y),abs(a.z),abs(a.w));} -inline F2 ceil(CO F2 &a){MF2(ceil(a.x),ceil(a.y));} -inline F3 ceil(CO F3 &a){MF3(ceil(a.x),ceil(a.y),ceil(a.z));} -inline F4 ceil(CO F4 &a){MF4(ceil(a.x),ceil(a.y),ceil(a.z),ceil(a.w));} -inline F2 floor(CO F2 &a){MF2(floor(a.x),floor(a.y));} -inline F3 floor(CO F3 &a){MF3(floor(a.x),floor(a.y),floor(a.z));} -inline F4 floor(CO F4 &a){MF4(floor(a.x),floor(a.y),floor(a.z),floor(a.w));} -IOF(D2,+)(CO D2 &a,CO D2 &b){MD2(a.x+b.x,a.y+b.y);} -IOF(D2,+)(CO D2 &a,CO DL &b){MD2(a.x+b,a.y+b);} -IOF(D2,+)(CO DL &a,CO D2 &b){MD2(a+b.x,a+b.y);} -IOF(D3,+)(CO D3 &a,CO D3 &b){MD3(a.x+b.x,a.y+b.y,a.z+b.z);} -IOF(D3,+)(CO D3 &a,CO DL &b){MD3(a.x+b,a.y+b,a.z+b);} -IOF(D3,+)(CO DL &a,CO D3 &b){MD3(a+b.x,a+b.y,a+b.z);} -IOF(D4,+)(CO D4 &a,CO D4 &b){MD4(a.x+b.x,a.y+b.y,a.z+b.z,a.w+b.w);} -IOF(D4,+)(CO D4 &a,CO DL &b){MD4(a.x+b,a.y+b,a.z+b,a.w+b);} -IOF(D4,+)(CO DL &a,CO D4 &b){MD4(a+b.x,a+b.y,a+b.z,a+b.w);} -IOF(D2,-)(CO D2 &a,CO D2 &b){MD2(a.x-b.x,a.y-b.y);} -IOF(D2,-)(CO D2 &a,CO DL &b){MD2(a.x-b,a.y-b);} -IOF(D2,-)(CO DL &a,CO D2 &b){MD2(a-b.x,a-b.y);} -IOF(D3,-)(CO D3 &a,CO D3 &b){MD3(a.x-b.x,a.y-b.y,a.z-b.z);} -IOF(D3,-)(CO 
D3 &a,CO DL &b){MD3(a.x-b,a.y-b,a.z-b);} -IOF(D3,-)(CO DL &a,CO D3 &b){MD3(a-b.x,a-b.y,a-b.z);} -IOF(D4,-)(CO D4 &a,CO D4 &b){MD4(a.x-b.x,a.y-b.y,a.z-b.z,a.w-b.w);} -IOF(D4,-)(CO D4 &a,CO DL &b){MD4(a.x-b,a.y-b,a.z-b,a.w-b);} -IOF(D4,-)(CO DL &a,CO D4 &b){MD4(a-b.x,a-b.y,a-b.z,a-b.w);} -IOF(D2,*)(CO D2 &a,CO D2 &b){MD2(a.x*b.x,a.y*b.y);} -IOF(D2,*)(CO D2 &a,CO DL &b){MD2(a.x*b,a.y*b);} -IOF(D2,*)(CO DL &a,CO D2 &b){MD2(a*b.x,a*b.y);} -IOF(D3,*)(CO D3 &a,CO D3 &b){MD3(a.x*b.x,a.y*b.y,a.z*b.z);} -IOF(D3,*)(CO D3 &a,CO DL &b){MD3(a.x*b,a.y*b,a.z*b);} -IOF(D3,*)(CO DL &a,CO D3 &b){MD3(a*b.x,a*b.y,a*b.z);} -IOF(D4,*)(CO D4 &a,CO D4 &b){MD4(a.x*b.x,a.y*b.y,a.z*b.z,a.w*b.w);} -IOF(D4,*)(CO D4 &a,CO DL &b){MD4(a.x*b,a.y*b,a.z*b,a.w*b);} -IOF(D4,*)(CO DL &a,CO D4 &b){MD4(a*b.x,a*b.y,a*b.z,a*b.w);} -IOF(D2,/)(CO D2 &a,CO D2 &b){MD2(a.x/b.x,a.y/b.y);} -IOF(D2,/)(CO D2 &a,CO DL &b){MD2(a.x/b,a.y/b);} -IOF(D2,/)(CO DL &a,CO D2 &b){MD2(a/b.x,a/b.y);} -IOF(D3,/)(CO D3 &a,CO D3 &b){MD3(a.x/b.x,a.y/b.y,a.z/b.z);} -IOF(D3,/)(CO D3 &a,CO DL &b){MD3(a.x/b,a.y/b,a.z/b);} -IOF(D3,/)(CO DL &a,CO D3 &b){MD3(a/b.x,a/b.y,a/b.z);} -IOF(D4,/)(CO D4 &a,CO D4 &b){MD4(a.x/b.x,a.y/b.y,a.z/b.z,a.w/b.w);} -IOF(D4,/)(CO D4 &a,CO DL &b){MD4(a.x/b,a.y/b,a.z/b,a.w/b);} -IOF(D4,/)(CO DL &a,CO D4 &b){MD4(a/b.x,a/b.y,a/b.z,a/b.w);} -IOF(D2,+=)(D2 &a,CO D2 &b){MD2N(a.x+b.x,a.y+b.y);} -IOF(D2,+=)(D2 &a,CO DL &b){MD2N(a.x + b,a.y + b);} -IOF(D3,+=)(D3 &a,CO D3 &b){MD3N(a.x+b.x,a.y+b.y,a.z+b.z);} -IOF(D3,+=)(D3 &a,CO DL &b){MD3N(a.x + b,a.y + b,a.z + b);} -IOF(D4,+=)(D4 &a,CO D4 &b){MD4N(a.x+b.x,a.y+b.y,a.z+b.z,a.w+b.w);} -IOF(D4,+=)(D4 &a,CO DL &b){MD4N(a.x + b,a.y + b,a.z + b,a.w + b);} -IOF(D2,-=)(D2 &a,CO D2 &b){MD2N(a.x-b.x,a.y-b.y);} -IOF(D2,-=)(D2 &a,CO DL &b){MD2N(a.x - b,a.y - b);} -IOF(D3,-=)(D3 &a,CO D3 &b){MD3N(a.x-b.x,a.y-b.y,a.z-b.z);} -IOF(D3,-=)(D3 &a,CO DL &b){MD3N(a.x - b,a.y - b,a.z - b);} -IOF(D4,-=)(D4 &a,CO D4 &b){MD4N(a.x-b.x,a.y-b.y,a.z-b.z,a.w-b.w);} -IOF(D4,-=)(D4 &a,CO DL 
&b){MD4N(a.x - b,a.y - b,a.z - b,a.w - b);} -IOF(D2,*=)(D2 &a,CO D2 &b){MD2N(a.x*b.x,a.y*b.y);} -IOF(D2,*=)(D2 &a,CO DL &b){MD2N(a.x * b,a.y * b);} -IOF(D3,*=)(D3 &a,CO D3 &b){MD3N(a.x*b.x,a.y*b.y,a.z*b.z);} -IOF(D3,*=)(D3 &a,CO DL &b){MD3N(a.x * b,a.y * b,a.z * b);} -IOF(D4,*=)(D4 &a,CO D4 &b){MD4N(a.x*b.x,a.y*b.y,a.z*b.z,a.w*b.w);} -IOF(D4,*=)(D4 &a,CO DL &b){MD4N(a.x * b,a.y * b,a.z * b,a.w * b);} -IOF(D2,/=)(D2 &a,CO D2 &b){MD2N(a.x/b.x,a.y/b.y);} -IOF(D2,/=)(D2 &a,CO DL &b){MD2N(a.x / b,a.y / b);} -IOF(D3,/=)(D3 &a,CO D3 &b){MD3N(a.x/b.x,a.y/b.y,a.z/b.z);} -IOF(D3,/=)(D3 &a,CO DL &b){MD3N(a.x / b,a.y / b,a.z / b);} -IOF(D4,/=)(D4 &a,CO D4 &b){MD4N(a.x/b.x,a.y/b.y,a.z/b.z,a.w/b.w);} -IOF(D4,/=)(D4 &a,CO DL &b){MD4N(a.x / b,a.y / b,a.z / b,a.w / b);} -IOF(D2,>)(CO D2 &a,CO D2 &b){MD2(a.x>b.x,a.y>b.y);} -IOF(D2,>)(CO D2 &a,CO DL &b){MD2(a.x>b,a.y>b);} -IOF(D2,>)(CO DL &a,CO D2 &b){MD2(a>b.x,a>b.y);} -IOF(D3,>)(CO D3 &a,CO D3 &b){MD3(a.x>b.x,a.y>b.y,a.z>b.z);} -IOF(D3,>)(CO D3 &a,CO DL &b){MD3(a.x>b,a.y>b,a.z>b);} -IOF(D3,>)(CO DL &a,CO D3 &b){MD3(a>b.x,a>b.y,a>b.z);} -IOF(D4,>)(CO D4 &a,CO D4 &b){MD4(a.x>b.x,a.y>b.y,a.z>b.z,a.w>b.w);} -IOF(D4,>)(CO D4 &a,CO DL &b){MD4(a.x>b,a.y>b,a.z>b,a.w>b);} -IOF(D4,>)(CO DL &a,CO D4 &b){MD4(a>b.x,a>b.y,a>b.z,a>b.w);} -IOF(D2,<)(CO D2 &a,CO D2 &b){MD2(a.x=)(CO D2 &a,CO D2 &b){MD2(a.x>=b.x,a.y>=b.y);} -IOF(D2,>=)(CO D2 &a,CO DL &b){MD2(a.x>=b,a.y>=b);} -IOF(D2,>=)(CO DL &a,CO D2 &b){MD2(a>=b.x,a>=b.y);} -IOF(D3,>=)(CO D3 &a,CO D3 &b){MD3(a.x>=b.x,a.y>=b.y,a.z>=b.z);} -IOF(D3,>=)(CO D3 &a,CO DL &b){MD3(a.x>=b,a.y>=b,a.z>=b);} -IOF(D3,>=)(CO DL &a,CO D3 &b){MD3(a>=b.x,a>=b.y,a>=b.z);} -IOF(D4,>=)(CO D4 &a,CO D4 &b){MD4(a.x>=b.x,a.y>=b.y,a.z>=b.z,a.w>=b.w);} -IOF(D4,>=)(CO D4 &a,CO DL &b){MD4(a.x>=b,a.y>=b,a.z>=b,a.w>=b);} -IOF(D4,>=)(CO DL &a,CO D4 &b){MD4(a>=b.x,a>=b.y,a>=b.z,a>=b.w);} -IOF(D2,<=)(CO D2 &a,CO D2 &b){MD2(a.x<=b.x,a.y<=b.y);} -IOF(D2,<=)(CO D2 &a,CO DL &b){MD2(a.x<=b,a.y<=b);} -IOF(D2,<=)(CO DL &a,CO D2 
&b){MD2(a<=b.x,a<=b.y);} -IOF(D3,<=)(CO D3 &a,CO D3 &b){MD3(a.x<=b.x,a.y<=b.y,a.z<=b.z);} -IOF(D3,<=)(CO D3 &a,CO DL &b){MD3(a.x<=b,a.y<=b,a.z<=b);} -IOF(D3,<=)(CO DL &a,CO D3 &b){MD3(a<=b.x,a<=b.y,a<=b.z);} -IOF(D4,<=)(CO D4 &a,CO D4 &b){MD4(a.x<=b.x,a.y<=b.y,a.z<=b.z,a.w<=b.w);} -IOF(D4,<=)(CO D4 &a,CO DL &b){MD4(a.x<=b,a.y<=b,a.z<=b,a.w<=b);} -IOF(D4,<=)(CO DL &a,CO D4 &b){MD4(a<=b.x,a<=b.y,a<=b.z,a<=b.w);} -IOF(D2,==)(CO D2 &a,CO D2 &b){MD2(a.x==b.x,a.y==b.y);} -IOF(D2,==)(CO D2 &a,CO DL &b){MD2(a.x==b,a.y==b);} -IOF(D2,==)(CO DL &a,CO D2 &b){MD2(a==b.x,a==b.y);} -IOF(D3,==)(CO D3 &a,CO D3 &b){MD3(a.x==b.x,a.y==b.y,a.z==b.z);} -IOF(D3,==)(CO D3 &a,CO DL &b){MD3(a.x==b,a.y==b,a.z==b);} -IOF(D3,==)(CO DL &a,CO D3 &b){MD3(a==b.x,a==b.y,a==b.z);} -IOF(D4,==)(CO D4 &a,CO D4 &b){MD4(a.x==b.x,a.y==b.y,a.z==b.z,a.w==b.w);} -IOF(D4,==)(CO D4 &a,CO DL &b){MD4(a.x==b,a.y==b,a.z==b,a.w==b);} -IOF(D4,==)(CO DL &a,CO D4 &b){MD4(a==b.x,a==b.y,a==b.z,a==b.w);} -IOF(D2,!=)(CO D2 &a,CO D2 &b){MD2(a.x!=b.x,a.y!=b.y);} -IOF(D2,!=)(CO D2 &a,CO DL &b){MD2(a.x!=b,a.y!=b);} -IOF(D2,!=)(CO DL &a,CO D2 &b){MD2(a!=b.x,a!=b.y);} -IOF(D3,!=)(CO D3 &a,CO D3 &b){MD3(a.x!=b.x,a.y!=b.y,a.z!=b.z);} -IOF(D3,!=)(CO D3 &a,CO DL &b){MD3(a.x!=b,a.y!=b,a.z!=b);} -IOF(D3,!=)(CO DL &a,CO D3 &b){MD3(a!=b.x,a!=b.y,a!=b.z);} -IOF(D4,!=)(CO D4 &a,CO D4 &b){MD4(a.x!=b.x,a.y!=b.y,a.z!=b.z,a.w!=b.w);} -IOF(D4,!=)(CO D4 &a,CO DL &b){MD4(a.x!=b,a.y!=b,a.z!=b,a.w!=b);} -IOF(D4,!=)(CO DL &a,CO D4 &b){MD4(a!=b.x,a!=b.y,a!=b.z,a!=b.w);} -inline D2 sin(CO D2 &a){MD2(sin(a.x),sin(a.y));} -inline D3 sin(CO D3 &a){MD3(sin(a.x),sin(a.y),sin(a.z));} -inline D4 sin(CO D4 &a){MD4(sin(a.x),sin(a.y),sin(a.z),sin(a.w));} -inline D2 cos(CO D2 &a){MD2(cos(a.x),cos(a.y));} -inline D3 cos(CO D3 &a){MD3(cos(a.x),cos(a.y),cos(a.z));} -inline D4 cos(CO D4 &a){MD4(cos(a.x),cos(a.y),cos(a.z),cos(a.w));} -inline D2 tan(CO D2 &a){MD2(tan(a.x),tan(a.y));} -inline D3 tan(CO D3 &a){MD3(tan(a.x),tan(a.y),tan(a.z));} -inline D4 tan(CO D4 
&a){MD4(tan(a.x),tan(a.y),tan(a.z),tan(a.w));} -inline D2 asin(CO D2 &a){MD2(asin(a.x),asin(a.y));} -inline D3 asin(CO D3 &a){MD3(asin(a.x),asin(a.y),asin(a.z));} -inline D4 asin(CO D4 &a){MD4(asin(a.x),asin(a.y),asin(a.z),asin(a.w));} -inline D2 acos(CO D2 &a){MD2(acos(a.x),acos(a.y));} -inline D3 acos(CO D3 &a){MD3(acos(a.x),acos(a.y),acos(a.z));} -inline D4 acos(CO D4 &a){MD4(acos(a.x),acos(a.y),acos(a.z),acos(a.w));} -inline D2 atan(CO D2 &a){MD2(atan(a.x),atan(a.y));} -inline D3 atan(CO D3 &a){MD3(atan(a.x),atan(a.y),atan(a.z));} -inline D4 atan(CO D4 &a){MD4(atan(a.x),atan(a.y),atan(a.z),atan(a.w));} -inline D2 sinh(CO D2 &a){MD2(sinh(a.x),sinh(a.y));} -inline D3 sinh(CO D3 &a){MD3(sinh(a.x),sinh(a.y),sinh(a.z));} -inline D4 sinh(CO D4 &a){MD4(sinh(a.x),sinh(a.y),sinh(a.z),sinh(a.w));} -inline D2 cosh(CO D2 &a){MD2(cosh(a.x),cosh(a.y));} -inline D3 cosh(CO D3 &a){MD3(cosh(a.x),cosh(a.y),cosh(a.z));} -inline D4 cosh(CO D4 &a){MD4(cosh(a.x),cosh(a.y),cosh(a.z),cosh(a.w));} -inline D2 tanh(CO D2 &a){MD2(tanh(a.x),tanh(a.y));} -inline D3 tanh(CO D3 &a){MD3(tanh(a.x),tanh(a.y),tanh(a.z));} -inline D4 tanh(CO D4 &a){MD4(tanh(a.x),tanh(a.y),tanh(a.z),tanh(a.w));} -inline D2 asinh(CO D2 &a){MD2(asinh(a.x),asinh(a.y));} -inline D3 asinh(CO D3 &a){MD3(asinh(a.x),asinh(a.y),asinh(a.z));} -inline D4 asinh(CO D4 &a){MD4(asinh(a.x),asinh(a.y),asinh(a.z),asinh(a.w));} -inline D2 acosh(CO D2 &a){MD2(acosh(a.x),acosh(a.y));} -inline D3 acosh(CO D3 &a){MD3(acosh(a.x),acosh(a.y),acosh(a.z));} -inline D4 acosh(CO D4 &a){MD4(acosh(a.x),acosh(a.y),acosh(a.z),acosh(a.w));} -inline D2 atanh(CO D2 &a){MD2(atanh(a.x),atanh(a.y));} -inline D3 atanh(CO D3 &a){MD3(atanh(a.x),atanh(a.y),atanh(a.z));} -inline D4 atanh(CO D4 &a){MD4(atanh(a.x),atanh(a.y),atanh(a.z),atanh(a.w));} -inline D2 exp(CO D2 &a){MD2(exp(a.x),exp(a.y));} -inline D3 exp(CO D3 &a){MD3(exp(a.x),exp(a.y),exp(a.z));} -inline D4 exp(CO D4 &a){MD4(exp(a.x),exp(a.y),exp(a.z),exp(a.w));} -inline D2 log(CO D2 
&a){MD2(log(a.x),log(a.y));} -inline D3 log(CO D3 &a){MD3(log(a.x),log(a.y),log(a.z));} -inline D4 log(CO D4 &a){MD4(log(a.x),log(a.y),log(a.z),log(a.w));} -inline D2 log2(CO D2 &a){MD2(log2(a.x),log2(a.y));} -inline D3 log2(CO D3 &a){MD3(log2(a.x),log2(a.y),log2(a.z));} -inline D4 log2(CO D4 &a){MD4(log2(a.x),log2(a.y),log2(a.z),log2(a.w));} -inline D2 log10(CO D2 &a){MD2(log10(a.x),log10(a.y));} -inline D3 log10(CO D3 &a){MD3(log10(a.x),log10(a.y),log10(a.z));} -inline D4 log10(CO D4 &a){MD4(log10(a.x),log10(a.y),log10(a.z),log10(a.w));} -inline D2 sqrt(CO D2 &a){MD2(sqrt(a.x),sqrt(a.y));} -inline D3 sqrt(CO D3 &a){MD3(sqrt(a.x),sqrt(a.y),sqrt(a.z));} -inline D4 sqrt(CO D4 &a){MD4(sqrt(a.x),sqrt(a.y),sqrt(a.z),sqrt(a.w));} -inline D2 cbrt(CO D2 &a){MD2(cbrt(a.x),cbrt(a.y));} -inline D3 cbrt(CO D3 &a){MD3(cbrt(a.x),cbrt(a.y),cbrt(a.z));} -inline D4 cbrt(CO D4 &a){MD4(cbrt(a.x),cbrt(a.y),cbrt(a.z),cbrt(a.w));} -inline D2 abs(CO D2 &a){MD2(abs(a.x),abs(a.y));} -inline D3 abs(CO D3 &a){MD3(abs(a.x),abs(a.y),abs(a.z));} -inline D4 abs(CO D4 &a){MD4(abs(a.x),abs(a.y),abs(a.z),abs(a.w));} -inline D2 ceil(CO D2 &a){MD2(ceil(a.x),ceil(a.y));} -inline D3 ceil(CO D3 &a){MD3(ceil(a.x),ceil(a.y),ceil(a.z));} -inline D4 ceil(CO D4 &a){MD4(ceil(a.x),ceil(a.y),ceil(a.z),ceil(a.w));} -inline D2 floor(CO D2 &a){MD2(floor(a.x),floor(a.y));} -inline D3 floor(CO D3 &a){MD3(floor(a.x),floor(a.y),floor(a.z));} -inline D4 floor(CO D4 &a){MD4(floor(a.x),floor(a.y),floor(a.z),floor(a.w));} +IOF(F2, +)(CO F2 &a, CO F2 &b) { MF2(a.x + b.x, a.y + b.y); } +IOF(F2, +)(CO F2 &a, CO FL &b) { MF2(a.x + b, a.y + b); } +IOF(F2, +)(CO FL &a, CO F2 &b) { MF2(a + b.x, a + b.y); } +IOF(F3, +)(CO F3 &a, CO F3 &b) { MF3(a.x + b.x, a.y + b.y, a.z + b.z); } +IOF(F3, +)(CO F3 &a, CO FL &b) { MF3(a.x + b, a.y + b, a.z + b); } +IOF(F3, +)(CO FL &a, CO F3 &b) { MF3(a + b.x, a + b.y, a + b.z); } +IOF(F4, +)(CO F4 &a, CO F4 &b) { MF4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +IOF(F4, +)(CO F4 &a, CO FL &b) 
{ MF4(a.x + b, a.y + b, a.z + b, a.w + b); } +IOF(F4, +)(CO FL &a, CO F4 &b) { MF4(a + b.x, a + b.y, a + b.z, a + b.w); } +IOF(F2, -)(CO F2 &a, CO F2 &b) { MF2(a.x - b.x, a.y - b.y); } +IOF(F2, -)(CO F2 &a, CO FL &b) { MF2(a.x - b, a.y - b); } +IOF(F2, -)(CO FL &a, CO F2 &b) { MF2(a - b.x, a - b.y); } +IOF(F3, -)(CO F3 &a, CO F3 &b) { MF3(a.x - b.x, a.y - b.y, a.z - b.z); } +IOF(F3, -)(CO F3 &a, CO FL &b) { MF3(a.x - b, a.y - b, a.z - b); } +IOF(F3, -)(CO FL &a, CO F3 &b) { MF3(a - b.x, a - b.y, a - b.z); } +IOF(F4, -)(CO F4 &a, CO F4 &b) { MF4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +IOF(F4, -)(CO F4 &a, CO FL &b) { MF4(a.x - b, a.y - b, a.z - b, a.w - b); } +IOF(F4, -)(CO FL &a, CO F4 &b) { MF4(a - b.x, a - b.y, a - b.z, a - b.w); } +IOF(F2, *)(CO F2 &a, CO F2 &b) { MF2(a.x * b.x, a.y * b.y); } +IOF(F2, *)(CO F2 &a, CO FL &b) { MF2(a.x * b, a.y * b); } +IOF(F2, *)(CO FL &a, CO F2 &b) { MF2(a * b.x, a * b.y); } +IOF(F3, *)(CO F3 &a, CO F3 &b) { MF3(a.x * b.x, a.y * b.y, a.z * b.z); } +IOF(F3, *)(CO F3 &a, CO FL &b) { MF3(a.x * b, a.y * b, a.z * b); } +IOF(F3, *)(CO FL &a, CO F3 &b) { MF3(a * b.x, a * b.y, a * b.z); } +IOF(F4, *)(CO F4 &a, CO F4 &b) { MF4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +IOF(F4, *)(CO F4 &a, CO FL &b) { MF4(a.x * b, a.y * b, a.z * b, a.w * b); } +IOF(F4, *)(CO FL &a, CO F4 &b) { MF4(a * b.x, a * b.y, a * b.z, a * b.w); } +IOF(F2, /)(CO F2 &a, CO F2 &b) { MF2(a.x / b.x, a.y / b.y); } +IOF(F2, /)(CO F2 &a, CO FL &b) { MF2(a.x / b, a.y / b); } +IOF(F2, /)(CO FL &a, CO F2 &b) { MF2(a / b.x, a / b.y); } +IOF(F3, /)(CO F3 &a, CO F3 &b) { MF3(a.x / b.x, a.y / b.y, a.z / b.z); } +IOF(F3, /)(CO F3 &a, CO FL &b) { MF3(a.x / b, a.y / b, a.z / b); } +IOF(F3, /)(CO FL &a, CO F3 &b) { MF3(a / b.x, a / b.y, a / b.z); } +IOF(F4, /)(CO F4 &a, CO F4 &b) { MF4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); } +IOF(F4, /)(CO F4 &a, CO FL &b) { MF4(a.x / b, a.y / b, a.z / b, a.w / b); } +IOF(F4, /)(CO FL &a, CO F4 &b) { MF4(a / b.x, a / b.y, a / b.z, 
a / b.w); } +IOF(F2, +=)(F2 &a, CO F2 &b) { MF2N(a.x + b.x, a.y + b.y); } +IOF(F2, +=)(F2 &a, CO FL &b) { MF2N(a.x + b, a.y + b); } +IOF(F3, +=)(F3 &a, CO F3 &b) { MF3N(a.x + b.x, a.y + b.y, a.z + b.z); } +IOF(F3, +=)(F3 &a, CO FL &b) { MF3N(a.x + b, a.y + b, a.z + b); } +IOF(F4, +=)(F4 &a, CO F4 &b) { MF4N(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +IOF(F4, +=)(F4 &a, CO FL &b) { MF4N(a.x + b, a.y + b, a.z + b, a.w + b); } +IOF(F2, -=)(F2 &a, CO F2 &b) { MF2N(a.x - b.x, a.y - b.y); } +IOF(F2, -=)(F2 &a, CO FL &b) { MF2N(a.x - b, a.y - b); } +IOF(F3, -=)(F3 &a, CO F3 &b) { MF3N(a.x - b.x, a.y - b.y, a.z - b.z); } +IOF(F3, -=)(F3 &a, CO FL &b) { MF3N(a.x - b, a.y - b, a.z - b); } +IOF(F4, -=)(F4 &a, CO F4 &b) { MF4N(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +IOF(F4, -=)(F4 &a, CO FL &b) { MF4N(a.x - b, a.y - b, a.z - b, a.w - b); } +IOF(F2, *=)(F2 &a, CO F2 &b) { MF2N(a.x * b.x, a.y * b.y); } +IOF(F2, *=)(F2 &a, CO FL &b) { MF2N(a.x * b, a.y * b); } +IOF(F3, *=)(F3 &a, CO F3 &b) { MF3N(a.x * b.x, a.y * b.y, a.z * b.z); } +IOF(F3, *=)(F3 &a, CO FL &b) { MF3N(a.x * b, a.y * b, a.z * b); } +IOF(F4, *=)(F4 &a, CO F4 &b) { MF4N(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +IOF(F4, *=)(F4 &a, CO FL &b) { MF4N(a.x * b, a.y * b, a.z * b, a.w * b); } +IOF(F2, /=)(F2 &a, CO F2 &b) { MF2N(a.x / b.x, a.y / b.y); } +IOF(F2, /=)(F2 &a, CO FL &b) { MF2N(a.x / b, a.y / b); } +IOF(F3, /=)(F3 &a, CO F3 &b) { MF3N(a.x / b.x, a.y / b.y, a.z / b.z); } +IOF(F3, /=)(F3 &a, CO FL &b) { MF3N(a.x / b, a.y / b, a.z / b); } +IOF(F4, /=)(F4 &a, CO F4 &b) { MF4N(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); } +IOF(F4, /=)(F4 &a, CO FL &b) { MF4N(a.x / b, a.y / b, a.z / b, a.w / b); } +IOF(F2, >)(CO F2 &a, CO F2 &b) { MF2(a.x > b.x, a.y > b.y); } +IOF(F2, >)(CO F2 &a, CO FL &b) { MF2(a.x > b, a.y > b); } +IOF(F2, >)(CO FL &a, CO F2 &b) { MF2(a > b.x, a > b.y); } +IOF(F3, >)(CO F3 &a, CO F3 &b) { MF3(a.x > b.x, a.y > b.y, a.z > b.z); } +IOF(F3, >)(CO F3 &a, CO FL &b) { MF3(a.x > b, a.y > 
b, a.z > b); } +IOF(F3, >)(CO FL &a, CO F3 &b) { MF3(a > b.x, a > b.y, a > b.z); } +IOF(F4, >)(CO F4 &a, CO F4 &b) { MF4(a.x > b.x, a.y > b.y, a.z > b.z, a.w > b.w); } +IOF(F4, >)(CO F4 &a, CO FL &b) { MF4(a.x > b, a.y > b, a.z > b, a.w > b); } +IOF(F4, >)(CO FL &a, CO F4 &b) { MF4(a > b.x, a > b.y, a > b.z, a > b.w); } +IOF(F2, <)(CO F2 &a, CO F2 &b) { MF2(a.x < b.x, a.y < b.y); } +IOF(F2, <)(CO F2 &a, CO FL &b) { MF2(a.x < b, a.y < b); } +IOF(F2, <)(CO FL &a, CO F2 &b) { MF2(a < b.x, a < b.y); } +IOF(F3, <)(CO F3 &a, CO F3 &b) { MF3(a.x < b.x, a.y < b.y, a.z < b.z); } +IOF(F3, <)(CO F3 &a, CO FL &b) { MF3(a.x < b, a.y < b, a.z < b); } +IOF(F3, <)(CO FL &a, CO F3 &b) { MF3(a < b.x, a < b.y, a < b.z); } +IOF(F4, <)(CO F4 &a, CO F4 &b) { MF4(a.x < b.x, a.y < b.y, a.z < b.z, a.w < b.w); } +IOF(F4, <)(CO F4 &a, CO FL &b) { MF4(a.x < b, a.y < b, a.z < b, a.w < b); } +IOF(F4, <)(CO FL &a, CO F4 &b) { MF4(a < b.x, a < b.y, a < b.z, a < b.w); } +IOF(F2, >=)(CO F2 &a, CO F2 &b) { MF2(a.x >= b.x, a.y >= b.y); } +IOF(F2, >=)(CO F2 &a, CO FL &b) { MF2(a.x >= b, a.y >= b); } +IOF(F2, >=)(CO FL &a, CO F2 &b) { MF2(a >= b.x, a >= b.y); } +IOF(F3, >=)(CO F3 &a, CO F3 &b) { MF3(a.x >= b.x, a.y >= b.y, a.z >= b.z); } +IOF(F3, >=)(CO F3 &a, CO FL &b) { MF3(a.x >= b, a.y >= b, a.z >= b); } +IOF(F3, >=)(CO FL &a, CO F3 &b) { MF3(a >= b.x, a >= b.y, a >= b.z); } +IOF(F4, >=)(CO F4 &a, CO F4 &b) { MF4(a.x >= b.x, a.y >= b.y, a.z >= b.z, a.w >= b.w); } +IOF(F4, >=)(CO F4 &a, CO FL &b) { MF4(a.x >= b, a.y >= b, a.z >= b, a.w >= b); } +IOF(F4, >=)(CO FL &a, CO F4 &b) { MF4(a >= b.x, a >= b.y, a >= b.z, a >= b.w); } +IOF(F2, <=)(CO F2 &a, CO F2 &b) { MF2(a.x <= b.x, a.y <= b.y); } +IOF(F2, <=)(CO F2 &a, CO FL &b) { MF2(a.x <= b, a.y <= b); } +IOF(F2, <=)(CO FL &a, CO F2 &b) { MF2(a <= b.x, a <= b.y); } +IOF(F3, <=)(CO F3 &a, CO F3 &b) { MF3(a.x <= b.x, a.y <= b.y, a.z <= b.z); } +IOF(F3, <=)(CO F3 &a, CO FL &b) { MF3(a.x <= b, a.y <= b, a.z <= b); } +IOF(F3, <=)(CO FL &a, CO F3 &b) { MF3(a 
<= b.x, a <= b.y, a <= b.z); } +IOF(F4, <=)(CO F4 &a, CO F4 &b) { MF4(a.x <= b.x, a.y <= b.y, a.z <= b.z, a.w <= b.w); } +IOF(F4, <=)(CO F4 &a, CO FL &b) { MF4(a.x <= b, a.y <= b, a.z <= b, a.w <= b); } +IOF(F4, <=)(CO FL &a, CO F4 &b) { MF4(a <= b.x, a <= b.y, a <= b.z, a <= b.w); } +IOF(F2, ==)(CO F2 &a, CO F2 &b) { MF2(a.x == b.x, a.y == b.y); } +IOF(F2, ==)(CO F2 &a, CO FL &b) { MF2(a.x == b, a.y == b); } +IOF(F2, ==)(CO FL &a, CO F2 &b) { MF2(a == b.x, a == b.y); } +IOF(F3, ==)(CO F3 &a, CO F3 &b) { MF3(a.x == b.x, a.y == b.y, a.z == b.z); } +IOF(F3, ==)(CO F3 &a, CO FL &b) { MF3(a.x == b, a.y == b, a.z == b); } +IOF(F3, ==)(CO FL &a, CO F3 &b) { MF3(a == b.x, a == b.y, a == b.z); } +IOF(F4, ==)(CO F4 &a, CO F4 &b) { MF4(a.x == b.x, a.y == b.y, a.z == b.z, a.w == b.w); } +IOF(F4, ==)(CO F4 &a, CO FL &b) { MF4(a.x == b, a.y == b, a.z == b, a.w == b); } +IOF(F4, ==)(CO FL &a, CO F4 &b) { MF4(a == b.x, a == b.y, a == b.z, a == b.w); } +IOF(F2, !=)(CO F2 &a, CO F2 &b) { MF2(a.x != b.x, a.y != b.y); } +IOF(F2, !=)(CO F2 &a, CO FL &b) { MF2(a.x != b, a.y != b); } +IOF(F2, !=)(CO FL &a, CO F2 &b) { MF2(a != b.x, a != b.y); } +IOF(F3, !=)(CO F3 &a, CO F3 &b) { MF3(a.x != b.x, a.y != b.y, a.z != b.z); } +IOF(F3, !=)(CO F3 &a, CO FL &b) { MF3(a.x != b, a.y != b, a.z != b); } +IOF(F3, !=)(CO FL &a, CO F3 &b) { MF3(a != b.x, a != b.y, a != b.z); } +IOF(F4, !=)(CO F4 &a, CO F4 &b) { MF4(a.x != b.x, a.y != b.y, a.z != b.z, a.w != b.w); } +IOF(F4, !=)(CO F4 &a, CO FL &b) { MF4(a.x != b, a.y != b, a.z != b, a.w != b); } +IOF(F4, !=)(CO FL &a, CO F4 &b) { MF4(a != b.x, a != b.y, a != b.z, a != b.w); } +inline F2 sin(CO F2 &a) { MF2(sin(a.x), sin(a.y)); } +inline F3 sin(CO F3 &a) { MF3(sin(a.x), sin(a.y), sin(a.z)); } +inline F4 sin(CO F4 &a) { MF4(sin(a.x), sin(a.y), sin(a.z), sin(a.w)); } +inline F2 cos(CO F2 &a) { MF2(cos(a.x), cos(a.y)); } +inline F3 cos(CO F3 &a) { MF3(cos(a.x), cos(a.y), cos(a.z)); } +inline F4 cos(CO F4 &a) { MF4(cos(a.x), cos(a.y), cos(a.z), cos(a.w)); 
} +inline F2 tan(CO F2 &a) { MF2(tan(a.x), tan(a.y)); } +inline F3 tan(CO F3 &a) { MF3(tan(a.x), tan(a.y), tan(a.z)); } +inline F4 tan(CO F4 &a) { MF4(tan(a.x), tan(a.y), tan(a.z), tan(a.w)); } +inline F2 asin(CO F2 &a) { MF2(asin(a.x), asin(a.y)); } +inline F3 asin(CO F3 &a) { MF3(asin(a.x), asin(a.y), asin(a.z)); } +inline F4 asin(CO F4 &a) { MF4(asin(a.x), asin(a.y), asin(a.z), asin(a.w)); } +inline F2 acos(CO F2 &a) { MF2(acos(a.x), acos(a.y)); } +inline F3 acos(CO F3 &a) { MF3(acos(a.x), acos(a.y), acos(a.z)); } +inline F4 acos(CO F4 &a) { MF4(acos(a.x), acos(a.y), acos(a.z), acos(a.w)); } +inline F2 atan(CO F2 &a) { MF2(atan(a.x), atan(a.y)); } +inline F3 atan(CO F3 &a) { MF3(atan(a.x), atan(a.y), atan(a.z)); } +inline F4 atan(CO F4 &a) { MF4(atan(a.x), atan(a.y), atan(a.z), atan(a.w)); } +inline F2 sinh(CO F2 &a) { MF2(sinh(a.x), sinh(a.y)); } +inline F3 sinh(CO F3 &a) { MF3(sinh(a.x), sinh(a.y), sinh(a.z)); } +inline F4 sinh(CO F4 &a) { MF4(sinh(a.x), sinh(a.y), sinh(a.z), sinh(a.w)); } +inline F2 cosh(CO F2 &a) { MF2(cosh(a.x), cosh(a.y)); } +inline F3 cosh(CO F3 &a) { MF3(cosh(a.x), cosh(a.y), cosh(a.z)); } +inline F4 cosh(CO F4 &a) { MF4(cosh(a.x), cosh(a.y), cosh(a.z), cosh(a.w)); } +inline F2 tanh(CO F2 &a) { MF2(tanh(a.x), tanh(a.y)); } +inline F3 tanh(CO F3 &a) { MF3(tanh(a.x), tanh(a.y), tanh(a.z)); } +inline F4 tanh(CO F4 &a) { MF4(tanh(a.x), tanh(a.y), tanh(a.z), tanh(a.w)); } +inline F2 asinh(CO F2 &a) { MF2(asinh(a.x), asinh(a.y)); } +inline F3 asinh(CO F3 &a) { MF3(asinh(a.x), asinh(a.y), asinh(a.z)); } +inline F4 asinh(CO F4 &a) { MF4(asinh(a.x), asinh(a.y), asinh(a.z), asinh(a.w)); } +inline F2 acosh(CO F2 &a) { MF2(acosh(a.x), acosh(a.y)); } +inline F3 acosh(CO F3 &a) { MF3(acosh(a.x), acosh(a.y), acosh(a.z)); } +inline F4 acosh(CO F4 &a) { MF4(acosh(a.x), acosh(a.y), acosh(a.z), acosh(a.w)); } +inline F2 atanh(CO F2 &a) { MF2(atanh(a.x), atanh(a.y)); } +inline F3 atanh(CO F3 &a) { MF3(atanh(a.x), atanh(a.y), atanh(a.z)); } +inline F4 
atanh(CO F4 &a) { MF4(atanh(a.x), atanh(a.y), atanh(a.z), atanh(a.w)); } +inline F2 exp(CO F2 &a) { MF2(exp(a.x), exp(a.y)); } +inline F3 exp(CO F3 &a) { MF3(exp(a.x), exp(a.y), exp(a.z)); } +inline F4 exp(CO F4 &a) { MF4(exp(a.x), exp(a.y), exp(a.z), exp(a.w)); } +inline F2 log(CO F2 &a) { MF2(log(a.x), log(a.y)); } +inline F3 log(CO F3 &a) { MF3(log(a.x), log(a.y), log(a.z)); } +inline F4 log(CO F4 &a) { MF4(log(a.x), log(a.y), log(a.z), log(a.w)); } +inline F2 log2(CO F2 &a) { MF2(log2(a.x), log2(a.y)); } +inline F3 log2(CO F3 &a) { MF3(log2(a.x), log2(a.y), log2(a.z)); } +inline F4 log2(CO F4 &a) { MF4(log2(a.x), log2(a.y), log2(a.z), log2(a.w)); } +inline F2 log10(CO F2 &a) { MF2(log10(a.x), log10(a.y)); } +inline F3 log10(CO F3 &a) { MF3(log10(a.x), log10(a.y), log10(a.z)); } +inline F4 log10(CO F4 &a) { MF4(log10(a.x), log10(a.y), log10(a.z), log10(a.w)); } +inline F2 sqrt(CO F2 &a) { MF2(sqrt(a.x), sqrt(a.y)); } +inline F3 sqrt(CO F3 &a) { MF3(sqrt(a.x), sqrt(a.y), sqrt(a.z)); } +inline F4 sqrt(CO F4 &a) { MF4(sqrt(a.x), sqrt(a.y), sqrt(a.z), sqrt(a.w)); } +inline F2 cbrt(CO F2 &a) { MF2(cbrt(a.x), cbrt(a.y)); } +inline F3 cbrt(CO F3 &a) { MF3(cbrt(a.x), cbrt(a.y), cbrt(a.z)); } +inline F4 cbrt(CO F4 &a) { MF4(cbrt(a.x), cbrt(a.y), cbrt(a.z), cbrt(a.w)); } +inline F2 abs(CO F2 &a) { MF2(abs(a.x), abs(a.y)); } +inline F3 abs(CO F3 &a) { MF3(abs(a.x), abs(a.y), abs(a.z)); } +inline F4 abs(CO F4 &a) { MF4(abs(a.x), abs(a.y), abs(a.z), abs(a.w)); } +inline F2 ceil(CO F2 &a) { MF2(ceil(a.x), ceil(a.y)); } +inline F3 ceil(CO F3 &a) { MF3(ceil(a.x), ceil(a.y), ceil(a.z)); } +inline F4 ceil(CO F4 &a) { MF4(ceil(a.x), ceil(a.y), ceil(a.z), ceil(a.w)); } +inline F2 floor(CO F2 &a) { MF2(floor(a.x), floor(a.y)); } +inline F3 floor(CO F3 &a) { MF3(floor(a.x), floor(a.y), floor(a.z)); } +inline F4 floor(CO F4 &a) { MF4(floor(a.x), floor(a.y), floor(a.z), floor(a.w)); } +IOF(D2, +)(CO D2 &a, CO D2 &b) { MD2(a.x + b.x, a.y + b.y); } +IOF(D2, +)(CO D2 &a, CO DL &b) { 
MD2(a.x + b, a.y + b); } +IOF(D2, +)(CO DL &a, CO D2 &b) { MD2(a + b.x, a + b.y); } +IOF(D3, +)(CO D3 &a, CO D3 &b) { MD3(a.x + b.x, a.y + b.y, a.z + b.z); } +IOF(D3, +)(CO D3 &a, CO DL &b) { MD3(a.x + b, a.y + b, a.z + b); } +IOF(D3, +)(CO DL &a, CO D3 &b) { MD3(a + b.x, a + b.y, a + b.z); } +IOF(D4, +)(CO D4 &a, CO D4 &b) { MD4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +IOF(D4, +)(CO D4 &a, CO DL &b) { MD4(a.x + b, a.y + b, a.z + b, a.w + b); } +IOF(D4, +)(CO DL &a, CO D4 &b) { MD4(a + b.x, a + b.y, a + b.z, a + b.w); } +IOF(D2, -)(CO D2 &a, CO D2 &b) { MD2(a.x - b.x, a.y - b.y); } +IOF(D2, -)(CO D2 &a, CO DL &b) { MD2(a.x - b, a.y - b); } +IOF(D2, -)(CO DL &a, CO D2 &b) { MD2(a - b.x, a - b.y); } +IOF(D3, -)(CO D3 &a, CO D3 &b) { MD3(a.x - b.x, a.y - b.y, a.z - b.z); } +IOF(D3, -)(CO D3 &a, CO DL &b) { MD3(a.x - b, a.y - b, a.z - b); } +IOF(D3, -)(CO DL &a, CO D3 &b) { MD3(a - b.x, a - b.y, a - b.z); } +IOF(D4, -)(CO D4 &a, CO D4 &b) { MD4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +IOF(D4, -)(CO D4 &a, CO DL &b) { MD4(a.x - b, a.y - b, a.z - b, a.w - b); } +IOF(D4, -)(CO DL &a, CO D4 &b) { MD4(a - b.x, a - b.y, a - b.z, a - b.w); } +IOF(D2, *)(CO D2 &a, CO D2 &b) { MD2(a.x * b.x, a.y * b.y); } +IOF(D2, *)(CO D2 &a, CO DL &b) { MD2(a.x * b, a.y * b); } +IOF(D2, *)(CO DL &a, CO D2 &b) { MD2(a * b.x, a * b.y); } +IOF(D3, *)(CO D3 &a, CO D3 &b) { MD3(a.x * b.x, a.y * b.y, a.z * b.z); } +IOF(D3, *)(CO D3 &a, CO DL &b) { MD3(a.x * b, a.y * b, a.z * b); } +IOF(D3, *)(CO DL &a, CO D3 &b) { MD3(a * b.x, a * b.y, a * b.z); } +IOF(D4, *)(CO D4 &a, CO D4 &b) { MD4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +IOF(D4, *)(CO D4 &a, CO DL &b) { MD4(a.x * b, a.y * b, a.z * b, a.w * b); } +IOF(D4, *)(CO DL &a, CO D4 &b) { MD4(a * b.x, a * b.y, a * b.z, a * b.w); } +IOF(D2, /)(CO D2 &a, CO D2 &b) { MD2(a.x / b.x, a.y / b.y); } +IOF(D2, /)(CO D2 &a, CO DL &b) { MD2(a.x / b, a.y / b); } +IOF(D2, /)(CO DL &a, CO D2 &b) { MD2(a / b.x, a / b.y); } +IOF(D3, /)(CO D3 &a, CO 
D3 &b) { MD3(a.x / b.x, a.y / b.y, a.z / b.z); } +IOF(D3, /)(CO D3 &a, CO DL &b) { MD3(a.x / b, a.y / b, a.z / b); } +IOF(D3, /)(CO DL &a, CO D3 &b) { MD3(a / b.x, a / b.y, a / b.z); } +IOF(D4, /)(CO D4 &a, CO D4 &b) { MD4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); } +IOF(D4, /)(CO D4 &a, CO DL &b) { MD4(a.x / b, a.y / b, a.z / b, a.w / b); } +IOF(D4, /)(CO DL &a, CO D4 &b) { MD4(a / b.x, a / b.y, a / b.z, a / b.w); } +IOF(D2, +=)(D2 &a, CO D2 &b) { MD2N(a.x + b.x, a.y + b.y); } +IOF(D2, +=)(D2 &a, CO DL &b) { MD2N(a.x + b, a.y + b); } +IOF(D3, +=)(D3 &a, CO D3 &b) { MD3N(a.x + b.x, a.y + b.y, a.z + b.z); } +IOF(D3, +=)(D3 &a, CO DL &b) { MD3N(a.x + b, a.y + b, a.z + b); } +IOF(D4, +=)(D4 &a, CO D4 &b) { MD4N(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); } +IOF(D4, +=)(D4 &a, CO DL &b) { MD4N(a.x + b, a.y + b, a.z + b, a.w + b); } +IOF(D2, -=)(D2 &a, CO D2 &b) { MD2N(a.x - b.x, a.y - b.y); } +IOF(D2, -=)(D2 &a, CO DL &b) { MD2N(a.x - b, a.y - b); } +IOF(D3, -=)(D3 &a, CO D3 &b) { MD3N(a.x - b.x, a.y - b.y, a.z - b.z); } +IOF(D3, -=)(D3 &a, CO DL &b) { MD3N(a.x - b, a.y - b, a.z - b); } +IOF(D4, -=)(D4 &a, CO D4 &b) { MD4N(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); } +IOF(D4, -=)(D4 &a, CO DL &b) { MD4N(a.x - b, a.y - b, a.z - b, a.w - b); } +IOF(D2, *=)(D2 &a, CO D2 &b) { MD2N(a.x * b.x, a.y * b.y); } +IOF(D2, *=)(D2 &a, CO DL &b) { MD2N(a.x * b, a.y * b); } +IOF(D3, *=)(D3 &a, CO D3 &b) { MD3N(a.x * b.x, a.y * b.y, a.z * b.z); } +IOF(D3, *=)(D3 &a, CO DL &b) { MD3N(a.x * b, a.y * b, a.z * b); } +IOF(D4, *=)(D4 &a, CO D4 &b) { MD4N(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } +IOF(D4, *=)(D4 &a, CO DL &b) { MD4N(a.x * b, a.y * b, a.z * b, a.w * b); } +IOF(D2, /=)(D2 &a, CO D2 &b) { MD2N(a.x / b.x, a.y / b.y); } +IOF(D2, /=)(D2 &a, CO DL &b) { MD2N(a.x / b, a.y / b); } +IOF(D3, /=)(D3 &a, CO D3 &b) { MD3N(a.x / b.x, a.y / b.y, a.z / b.z); } +IOF(D3, /=)(D3 &a, CO DL &b) { MD3N(a.x / b, a.y / b, a.z / b); } +IOF(D4, /=)(D4 &a, CO D4 &b) { MD4N(a.x / b.x, a.y / 
b.y, a.z / b.z, a.w / b.w); } +IOF(D4, /=)(D4 &a, CO DL &b) { MD4N(a.x / b, a.y / b, a.z / b, a.w / b); } +IOF(D2, >)(CO D2 &a, CO D2 &b) { MD2(a.x > b.x, a.y > b.y); } +IOF(D2, >)(CO D2 &a, CO DL &b) { MD2(a.x > b, a.y > b); } +IOF(D2, >)(CO DL &a, CO D2 &b) { MD2(a > b.x, a > b.y); } +IOF(D3, >)(CO D3 &a, CO D3 &b) { MD3(a.x > b.x, a.y > b.y, a.z > b.z); } +IOF(D3, >)(CO D3 &a, CO DL &b) { MD3(a.x > b, a.y > b, a.z > b); } +IOF(D3, >)(CO DL &a, CO D3 &b) { MD3(a > b.x, a > b.y, a > b.z); } +IOF(D4, >)(CO D4 &a, CO D4 &b) { MD4(a.x > b.x, a.y > b.y, a.z > b.z, a.w > b.w); } +IOF(D4, >)(CO D4 &a, CO DL &b) { MD4(a.x > b, a.y > b, a.z > b, a.w > b); } +IOF(D4, >)(CO DL &a, CO D4 &b) { MD4(a > b.x, a > b.y, a > b.z, a > b.w); } +IOF(D2, <)(CO D2 &a, CO D2 &b) { MD2(a.x < b.x, a.y < b.y); } +IOF(D2, <)(CO D2 &a, CO DL &b) { MD2(a.x < b, a.y < b); } +IOF(D2, <)(CO DL &a, CO D2 &b) { MD2(a < b.x, a < b.y); } +IOF(D3, <)(CO D3 &a, CO D3 &b) { MD3(a.x < b.x, a.y < b.y, a.z < b.z); } +IOF(D3, <)(CO D3 &a, CO DL &b) { MD3(a.x < b, a.y < b, a.z < b); } +IOF(D3, <)(CO DL &a, CO D3 &b) { MD3(a < b.x, a < b.y, a < b.z); } +IOF(D4, <)(CO D4 &a, CO D4 &b) { MD4(a.x < b.x, a.y < b.y, a.z < b.z, a.w < b.w); } +IOF(D4, <)(CO D4 &a, CO DL &b) { MD4(a.x < b, a.y < b, a.z < b, a.w < b); } +IOF(D4, <)(CO DL &a, CO D4 &b) { MD4(a < b.x, a < b.y, a < b.z, a < b.w); } +IOF(D2, >=)(CO D2 &a, CO D2 &b) { MD2(a.x >= b.x, a.y >= b.y); } +IOF(D2, >=)(CO D2 &a, CO DL &b) { MD2(a.x >= b, a.y >= b); } +IOF(D2, >=)(CO DL &a, CO D2 &b) { MD2(a >= b.x, a >= b.y); } +IOF(D3, >=)(CO D3 &a, CO D3 &b) { MD3(a.x >= b.x, a.y >= b.y, a.z >= b.z); } +IOF(D3, >=)(CO D3 &a, CO DL &b) { MD3(a.x >= b, a.y >= b, a.z >= b); } +IOF(D3, >=)(CO DL &a, CO D3 &b) { MD3(a >= b.x, a >= b.y, a >= b.z); } +IOF(D4, >=)(CO D4 &a, CO D4 &b) { MD4(a.x >= b.x, a.y >= b.y, a.z >= b.z, a.w >= b.w); } +IOF(D4, >=)(CO D4 &a, CO DL &b) { MD4(a.x >= b, a.y >= b, a.z >= b, a.w >= b); } +IOF(D4, >=)(CO DL &a, CO D4 &b) { MD4(a >= b.x, 
a >= b.y, a >= b.z, a >= b.w); } +IOF(D2, <=)(CO D2 &a, CO D2 &b) { MD2(a.x <= b.x, a.y <= b.y); } +IOF(D2, <=)(CO D2 &a, CO DL &b) { MD2(a.x <= b, a.y <= b); } +IOF(D2, <=)(CO DL &a, CO D2 &b) { MD2(a <= b.x, a <= b.y); } +IOF(D3, <=)(CO D3 &a, CO D3 &b) { MD3(a.x <= b.x, a.y <= b.y, a.z <= b.z); } +IOF(D3, <=)(CO D3 &a, CO DL &b) { MD3(a.x <= b, a.y <= b, a.z <= b); } +IOF(D3, <=)(CO DL &a, CO D3 &b) { MD3(a <= b.x, a <= b.y, a <= b.z); } +IOF(D4, <=)(CO D4 &a, CO D4 &b) { MD4(a.x <= b.x, a.y <= b.y, a.z <= b.z, a.w <= b.w); } +IOF(D4, <=)(CO D4 &a, CO DL &b) { MD4(a.x <= b, a.y <= b, a.z <= b, a.w <= b); } +IOF(D4, <=)(CO DL &a, CO D4 &b) { MD4(a <= b.x, a <= b.y, a <= b.z, a <= b.w); } +IOF(D2, ==)(CO D2 &a, CO D2 &b) { MD2(a.x == b.x, a.y == b.y); } +IOF(D2, ==)(CO D2 &a, CO DL &b) { MD2(a.x == b, a.y == b); } +IOF(D2, ==)(CO DL &a, CO D2 &b) { MD2(a == b.x, a == b.y); } +IOF(D3, ==)(CO D3 &a, CO D3 &b) { MD3(a.x == b.x, a.y == b.y, a.z == b.z); } +IOF(D3, ==)(CO D3 &a, CO DL &b) { MD3(a.x == b, a.y == b, a.z == b); } +IOF(D3, ==)(CO DL &a, CO D3 &b) { MD3(a == b.x, a == b.y, a == b.z); } +IOF(D4, ==)(CO D4 &a, CO D4 &b) { MD4(a.x == b.x, a.y == b.y, a.z == b.z, a.w == b.w); } +IOF(D4, ==)(CO D4 &a, CO DL &b) { MD4(a.x == b, a.y == b, a.z == b, a.w == b); } +IOF(D4, ==)(CO DL &a, CO D4 &b) { MD4(a == b.x, a == b.y, a == b.z, a == b.w); } +IOF(D2, !=)(CO D2 &a, CO D2 &b) { MD2(a.x != b.x, a.y != b.y); } +IOF(D2, !=)(CO D2 &a, CO DL &b) { MD2(a.x != b, a.y != b); } +IOF(D2, !=)(CO DL &a, CO D2 &b) { MD2(a != b.x, a != b.y); } +IOF(D3, !=)(CO D3 &a, CO D3 &b) { MD3(a.x != b.x, a.y != b.y, a.z != b.z); } +IOF(D3, !=)(CO D3 &a, CO DL &b) { MD3(a.x != b, a.y != b, a.z != b); } +IOF(D3, !=)(CO DL &a, CO D3 &b) { MD3(a != b.x, a != b.y, a != b.z); } +IOF(D4, !=)(CO D4 &a, CO D4 &b) { MD4(a.x != b.x, a.y != b.y, a.z != b.z, a.w != b.w); } +IOF(D4, !=)(CO D4 &a, CO DL &b) { MD4(a.x != b, a.y != b, a.z != b, a.w != b); } +IOF(D4, !=)(CO DL &a, CO D4 &b) { MD4(a != b.x, a 
!= b.y, a != b.z, a != b.w); } +inline D2 sin(CO D2 &a) { MD2(sin(a.x), sin(a.y)); } +inline D3 sin(CO D3 &a) { MD3(sin(a.x), sin(a.y), sin(a.z)); } +inline D4 sin(CO D4 &a) { MD4(sin(a.x), sin(a.y), sin(a.z), sin(a.w)); } +inline D2 cos(CO D2 &a) { MD2(cos(a.x), cos(a.y)); } +inline D3 cos(CO D3 &a) { MD3(cos(a.x), cos(a.y), cos(a.z)); } +inline D4 cos(CO D4 &a) { MD4(cos(a.x), cos(a.y), cos(a.z), cos(a.w)); } +inline D2 tan(CO D2 &a) { MD2(tan(a.x), tan(a.y)); } +inline D3 tan(CO D3 &a) { MD3(tan(a.x), tan(a.y), tan(a.z)); } +inline D4 tan(CO D4 &a) { MD4(tan(a.x), tan(a.y), tan(a.z), tan(a.w)); } +inline D2 asin(CO D2 &a) { MD2(asin(a.x), asin(a.y)); } +inline D3 asin(CO D3 &a) { MD3(asin(a.x), asin(a.y), asin(a.z)); } +inline D4 asin(CO D4 &a) { MD4(asin(a.x), asin(a.y), asin(a.z), asin(a.w)); } +inline D2 acos(CO D2 &a) { MD2(acos(a.x), acos(a.y)); } +inline D3 acos(CO D3 &a) { MD3(acos(a.x), acos(a.y), acos(a.z)); } +inline D4 acos(CO D4 &a) { MD4(acos(a.x), acos(a.y), acos(a.z), acos(a.w)); } +inline D2 atan(CO D2 &a) { MD2(atan(a.x), atan(a.y)); } +inline D3 atan(CO D3 &a) { MD3(atan(a.x), atan(a.y), atan(a.z)); } +inline D4 atan(CO D4 &a) { MD4(atan(a.x), atan(a.y), atan(a.z), atan(a.w)); } +inline D2 sinh(CO D2 &a) { MD2(sinh(a.x), sinh(a.y)); } +inline D3 sinh(CO D3 &a) { MD3(sinh(a.x), sinh(a.y), sinh(a.z)); } +inline D4 sinh(CO D4 &a) { MD4(sinh(a.x), sinh(a.y), sinh(a.z), sinh(a.w)); } +inline D2 cosh(CO D2 &a) { MD2(cosh(a.x), cosh(a.y)); } +inline D3 cosh(CO D3 &a) { MD3(cosh(a.x), cosh(a.y), cosh(a.z)); } +inline D4 cosh(CO D4 &a) { MD4(cosh(a.x), cosh(a.y), cosh(a.z), cosh(a.w)); } +inline D2 tanh(CO D2 &a) { MD2(tanh(a.x), tanh(a.y)); } +inline D3 tanh(CO D3 &a) { MD3(tanh(a.x), tanh(a.y), tanh(a.z)); } +inline D4 tanh(CO D4 &a) { MD4(tanh(a.x), tanh(a.y), tanh(a.z), tanh(a.w)); } +inline D2 asinh(CO D2 &a) { MD2(asinh(a.x), asinh(a.y)); } +inline D3 asinh(CO D3 &a) { MD3(asinh(a.x), asinh(a.y), asinh(a.z)); } +inline D4 asinh(CO D4 &a) { 
MD4(asinh(a.x), asinh(a.y), asinh(a.z), asinh(a.w)); } +inline D2 acosh(CO D2 &a) { MD2(acosh(a.x), acosh(a.y)); } +inline D3 acosh(CO D3 &a) { MD3(acosh(a.x), acosh(a.y), acosh(a.z)); } +inline D4 acosh(CO D4 &a) { MD4(acosh(a.x), acosh(a.y), acosh(a.z), acosh(a.w)); } +inline D2 atanh(CO D2 &a) { MD2(atanh(a.x), atanh(a.y)); } +inline D3 atanh(CO D3 &a) { MD3(atanh(a.x), atanh(a.y), atanh(a.z)); } +inline D4 atanh(CO D4 &a) { MD4(atanh(a.x), atanh(a.y), atanh(a.z), atanh(a.w)); } +inline D2 exp(CO D2 &a) { MD2(exp(a.x), exp(a.y)); } +inline D3 exp(CO D3 &a) { MD3(exp(a.x), exp(a.y), exp(a.z)); } +inline D4 exp(CO D4 &a) { MD4(exp(a.x), exp(a.y), exp(a.z), exp(a.w)); } +inline D2 log(CO D2 &a) { MD2(log(a.x), log(a.y)); } +inline D3 log(CO D3 &a) { MD3(log(a.x), log(a.y), log(a.z)); } +inline D4 log(CO D4 &a) { MD4(log(a.x), log(a.y), log(a.z), log(a.w)); } +inline D2 log2(CO D2 &a) { MD2(log2(a.x), log2(a.y)); } +inline D3 log2(CO D3 &a) { MD3(log2(a.x), log2(a.y), log2(a.z)); } +inline D4 log2(CO D4 &a) { MD4(log2(a.x), log2(a.y), log2(a.z), log2(a.w)); } +inline D2 log10(CO D2 &a) { MD2(log10(a.x), log10(a.y)); } +inline D3 log10(CO D3 &a) { MD3(log10(a.x), log10(a.y), log10(a.z)); } +inline D4 log10(CO D4 &a) { MD4(log10(a.x), log10(a.y), log10(a.z), log10(a.w)); } +inline D2 sqrt(CO D2 &a) { MD2(sqrt(a.x), sqrt(a.y)); } +inline D3 sqrt(CO D3 &a) { MD3(sqrt(a.x), sqrt(a.y), sqrt(a.z)); } +inline D4 sqrt(CO D4 &a) { MD4(sqrt(a.x), sqrt(a.y), sqrt(a.z), sqrt(a.w)); } +inline D2 cbrt(CO D2 &a) { MD2(cbrt(a.x), cbrt(a.y)); } +inline D3 cbrt(CO D3 &a) { MD3(cbrt(a.x), cbrt(a.y), cbrt(a.z)); } +inline D4 cbrt(CO D4 &a) { MD4(cbrt(a.x), cbrt(a.y), cbrt(a.z), cbrt(a.w)); } +inline D2 abs(CO D2 &a) { MD2(abs(a.x), abs(a.y)); } +inline D3 abs(CO D3 &a) { MD3(abs(a.x), abs(a.y), abs(a.z)); } +inline D4 abs(CO D4 &a) { MD4(abs(a.x), abs(a.y), abs(a.z), abs(a.w)); } +inline D2 ceil(CO D2 &a) { MD2(ceil(a.x), ceil(a.y)); } +inline D3 ceil(CO D3 &a) { MD3(ceil(a.x), 
ceil(a.y), ceil(a.z)); } +inline D4 ceil(CO D4 &a) { MD4(ceil(a.x), ceil(a.y), ceil(a.z), ceil(a.w)); } +inline D2 floor(CO D2 &a) { MD2(floor(a.x), floor(a.y)); } +inline D3 floor(CO D3 &a) { MD3(floor(a.x), floor(a.y), floor(a.z)); } +inline D4 floor(CO D4 &a) { MD4(floor(a.x), floor(a.y), floor(a.z), floor(a.w)); } #endif // LIBRAPID_CUDA_VECTOR_OPS_HELPER \ No newline at end of file diff --git a/librapid/include/librapid/cuda/nvrtc_helper.h b/librapid/include/librapid/cuda/nvrtc_helper.h index fc792f37..21497117 100644 --- a/librapid/include/librapid/cuda/nvrtc_helper.h +++ b/librapid/include/librapid/cuda/nvrtc_helper.h @@ -38,169 +38,169 @@ #include #define NVRTC_SAFE_CALL(Name, x) \ - do { \ - nvrtcResult result = x; \ - if (result != NVRTC_SUCCESS) { \ - std::cerr << "\nerror: " << Name << " failed with error " \ - << nvrtcGetErrorString(result); \ - exit(1); \ - } \ - } while (0) + do { \ + nvrtcResult result = x; \ + if (result != NVRTC_SUCCESS) { \ + std::cerr << "\nerror: " << Name << " failed with error " \ + << nvrtcGetErrorString(result); \ + exit(1); \ + } \ + } while (0) void compileFileToCUBIN(char *filename, int argc, char **argv, char **cubinResult, - size_t *cubinResultSize, int requiresCGheaders) { - std::ifstream inputFile(filename, std::ios::in | std::ios::binary | std::ios::ate); + size_t *cubinResultSize, int requiresCGheaders) { + std::ifstream inputFile(filename, std::ios::in | std::ios::binary | std::ios::ate); - if (!inputFile.is_open()) { - std::cerr << "\nerror: unable to open " << filename << " for reading!\n"; - exit(1); - } + if (!inputFile.is_open()) { + std::cerr << "\nerror: unable to open " << filename << " for reading!\n"; + exit(1); + } - std::streampos pos = inputFile.tellg(); - size_t inputSize = (size_t)pos; - char *memBlock = new char[inputSize + 1]; + std::streampos pos = inputFile.tellg(); + size_t inputSize = (size_t)pos; + char *memBlock = new char[inputSize + 1]; - inputFile.seekg(0, std::ios::beg); - 
inputFile.read(memBlock, inputSize); - inputFile.close(); - memBlock[inputSize] = '\x0'; + inputFile.seekg(0, std::ios::beg); + inputFile.read(memBlock, inputSize); + inputFile.close(); + memBlock[inputSize] = '\x0'; - int numCompileOptions = 0; + int numCompileOptions = 0; - char *compileParams[2]; + char *compileParams[2]; - int major = 0, minor = 0; - char deviceName[256]; + int major = 0, minor = 0; + char deviceName[256]; - // Picks the best CUDA device available - CUdevice cuDevice = findCudaDeviceDRV(argc, (const char **)argv); + // Picks the best CUDA device available + CUdevice cuDevice = findCudaDeviceDRV(argc, (const char **)argv); - // get compute capabilities and the devicename - checkCudaErrors( - cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice)); - checkCudaErrors( - cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice)); + // get compute capabilities and the devicename + checkCudaErrors( + cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice)); + checkCudaErrors( + cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice)); - { - // Compile cubin for the GPU arch on which are going to run cuda kernel. - std::string compileOptions; - compileOptions = "--gpu-architecture=sm_"; + { + // Compile cubin for the GPU arch on which are going to run cuda kernel. 
+ std::string compileOptions; + compileOptions = "--gpu-architecture=sm_"; - compileParams[numCompileOptions] = - reinterpret_cast(malloc(sizeof(char) * (compileOptions.length() + 10))); + compileParams[numCompileOptions] = + reinterpret_cast(malloc(sizeof(char) * (compileOptions.length() + 10))); #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) - sprintf_s(compileParams[numCompileOptions], - sizeof(char) * (compileOptions.length() + 10), - "%s%d%d", - compileOptions.c_str(), - major, - minor); + sprintf_s(compileParams[numCompileOptions], + sizeof(char) * (compileOptions.length() + 10), + "%s%d%d", + compileOptions.c_str(), + major, + minor); #else - snprintf(compileParams[numCompileOptions], - compileOptions.size() + 10, - "%s%d%d", - compileOptions.c_str(), - major, - minor); + snprintf(compileParams[numCompileOptions], + compileOptions.size() + 10, + "%s%d%d", + compileOptions.c_str(), + major, + minor); #endif - } + } - numCompileOptions++; + numCompileOptions++; - if (requiresCGheaders) { - std::string compileOptions; - char HeaderNames[256]; + if (requiresCGheaders) { + std::string compileOptions; + char HeaderNames[256]; #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) - sprintf_s(HeaderNames, sizeof(HeaderNames), "%s", "cooperative_groups.h"); + sprintf_s(HeaderNames, sizeof(HeaderNames), "%s", "cooperative_groups.h"); #else - snprintf(HeaderNames, sizeof(HeaderNames), "%s", "cooperative_groups.h"); + snprintf(HeaderNames, sizeof(HeaderNames), "%s", "cooperative_groups.h"); #endif - compileOptions = "--include-path="; - - std::string path = sdkFindFilePath(HeaderNames, argv[0]); - if (!path.empty()) { - std::size_t found = path.find(HeaderNames); - path.erase(found); - } else { - printf( - "\nCooperativeGroups headers not found, please install it in %s " - "sample directory..\n Exiting..\n", - argv[0]); - } - compileOptions += path.c_str(); - compileParams[numCompileOptions] = - 
reinterpret_cast(malloc(sizeof(char) * (compileOptions.length() + 1))); + compileOptions = "--include-path="; + + std::string path = sdkFindFilePath(HeaderNames, argv[0]); + if (!path.empty()) { + std::size_t found = path.find(HeaderNames); + path.erase(found); + } else { + printf( + "\nCooperativeGroups headers not found, please install it in %s " + "sample directory..\n Exiting..\n", + argv[0]); + } + compileOptions += path.c_str(); + compileParams[numCompileOptions] = + reinterpret_cast(malloc(sizeof(char) * (compileOptions.length() + 1))); #if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) - sprintf_s(compileParams[numCompileOptions], - sizeof(char) * (compileOptions.length() + 1), - "%s", - compileOptions.c_str()); + sprintf_s(compileParams[numCompileOptions], + sizeof(char) * (compileOptions.length() + 1), + "%s", + compileOptions.c_str()); #else - snprintf( - compileParams[numCompileOptions], compileOptions.size(), "%s", compileOptions.c_str()); + snprintf( + compileParams[numCompileOptions], compileOptions.size(), "%s", compileOptions.c_str()); #endif - numCompileOptions++; - } + numCompileOptions++; + } - // compile - nvrtcProgram prog; - NVRTC_SAFE_CALL("nvrtcCreateProgram", - nvrtcCreateProgram(&prog, memBlock, filename, 0, NULL, NULL)); + // compile + nvrtcProgram prog; + NVRTC_SAFE_CALL("nvrtcCreateProgram", + nvrtcCreateProgram(&prog, memBlock, filename, 0, NULL, NULL)); - nvrtcResult res = nvrtcCompileProgram(prog, numCompileOptions, compileParams); + nvrtcResult res = nvrtcCompileProgram(prog, numCompileOptions, compileParams); - // dump log - size_t logSize; - NVRTC_SAFE_CALL("nvrtcGetProgramLogSize", nvrtcGetProgramLogSize(prog, &logSize)); - char *log = reinterpret_cast(malloc(sizeof(char) * logSize + 1)); - NVRTC_SAFE_CALL("nvrtcGetProgramLog", nvrtcGetProgramLog(prog, log)); - log[logSize] = '\x0'; + // dump log + size_t logSize; + NVRTC_SAFE_CALL("nvrtcGetProgramLogSize", nvrtcGetProgramLogSize(prog, &logSize)); + char 
*log = reinterpret_cast(malloc(sizeof(char) * logSize + 1)); + NVRTC_SAFE_CALL("nvrtcGetProgramLog", nvrtcGetProgramLog(prog, log)); + log[logSize] = '\x0'; - if (strlen(log) >= 2) { - std::cerr << "\n compilation log ---\n"; - std::cerr << log; - std::cerr << "\n end log ---\n"; - } + if (strlen(log) >= 2) { + std::cerr << "\n compilation log ---\n"; + std::cerr << log; + std::cerr << "\n end log ---\n"; + } - free(log); + free(log); - NVRTC_SAFE_CALL("nvrtcCompileProgram", res); + NVRTC_SAFE_CALL("nvrtcCompileProgram", res); - size_t codeSize; - NVRTC_SAFE_CALL("nvrtcGetCUBINSize", nvrtcGetCUBINSize(prog, &codeSize)); - char *code = new char[codeSize]; - NVRTC_SAFE_CALL("nvrtcGetCUBIN", nvrtcGetCUBIN(prog, code)); - *cubinResult = code; - *cubinResultSize = codeSize; + size_t codeSize; + NVRTC_SAFE_CALL("nvrtcGetCUBINSize", nvrtcGetCUBINSize(prog, &codeSize)); + char *code = new char[codeSize]; + NVRTC_SAFE_CALL("nvrtcGetCUBIN", nvrtcGetCUBIN(prog, code)); + *cubinResult = code; + *cubinResultSize = codeSize; - for (int i = 0; i < numCompileOptions; i++) { free(compileParams[i]); } + for (int i = 0; i < numCompileOptions; i++) { free(compileParams[i]); } } CUmodule loadCUBIN(char *cubin, int argc, char **argv) { - CUmodule module; - CUcontext context; - int major = 0, minor = 0; - char deviceName[256]; + CUmodule module; + CUcontext context; + int major = 0, minor = 0; + char deviceName[256]; - // Picks the best CUDA device available - CUdevice cuDevice = findCudaDeviceDRV(argc, (const char **)argv); + // Picks the best CUDA device available + CUdevice cuDevice = findCudaDeviceDRV(argc, (const char **)argv); - // get compute capabilities and the devicename - checkCudaErrors( - cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice)); - checkCudaErrors( - cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice)); - checkCudaErrors(cuDeviceGetName(deviceName, 256, cuDevice)); - printf("> GPU Device has SM 
%d.%d compute capability\n", major, minor); + // get compute capabilities and the devicename + checkCudaErrors( + cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, cuDevice)); + checkCudaErrors( + cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, cuDevice)); + checkCudaErrors(cuDeviceGetName(deviceName, 256, cuDevice)); + printf("> GPU Device has SM %d.%d compute capability\n", major, minor); - checkCudaErrors(cuInit(0)); - checkCudaErrors(cuCtxCreate(&context, 0, cuDevice)); + checkCudaErrors(cuInit(0)); + checkCudaErrors(cuCtxCreate(&context, 0, cuDevice)); - checkCudaErrors(cuModuleLoadData(&module, cubin)); - free(cubin); + checkCudaErrors(cuModuleLoadData(&module, cubin)); + free(cubin); - return module; + return module; } #endif // COMMON_NVRTC_HELPER_H_ diff --git a/librapid/include/librapid/math/compileTime.hpp b/librapid/include/librapid/math/compileTime.hpp index fea3d995..36a82cce 100644 --- a/librapid/include/librapid/math/compileTime.hpp +++ b/librapid/include/librapid/math/compileTime.hpp @@ -2,14 +2,14 @@ #define LIBRAPID_MATH_COMPILE_TIME_HPP namespace librapid { - template - constexpr size_t product() { - if constexpr (sizeof...(Rest) == 0) { - return First; - } else { - return First * product(); - } - } + template + constexpr size_t product() { + if constexpr (sizeof...(Rest) == 0) { + return First; + } else { + return First * product(); + } + } } // namespace librapid #endif // LIBRAPID_MATH_COMPILE_TIME_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/complex.hpp b/librapid/include/librapid/math/complex.hpp index abb74f66..49268b1b 100644 --- a/librapid/include/librapid/math/complex.hpp +++ b/librapid/include/librapid/math/complex.hpp @@ -13,2068 +13,2068 @@ */ #if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) -# define USE_X86_X64_INTRINSICS -# include +# define USE_X86_X64_INTRINSICS +# include #elif defined(_M_ARM64) || defined(_M_ARM64EC) -# define 
USE_ARM64_INTRINSICS -# include +# define USE_ARM64_INTRINSICS +# include #endif namespace librapid { - namespace detail { - // Implements floating-point arithmetic for numeric algorithms - namespace multiprec { - template - struct Fmp { - Scalar val0; // Most significant numeric_limits::precision bits - Scalar val1; // Least significant numeric_limits::precision bits - }; - - /// \brief Summarizes two 1x precision values combined into a 2x precision result - /// - /// This function is exact when: - /// 1. The result doesn't overflow - /// 2. Either underflow is gradual, or no internal underflow occurs - /// 3. Intermediate precision is either the same as T, or greater than twice the - /// precision of T - /// 4. Parameters and local variables do not retain extra intermediate precision - /// 5. Rounding mode is rounding to nearest. - /// - /// Violation of condition 3 or 5 could lead to relative error on the order of - /// epsilon^2. - /// - /// Violation of other conditions could lead to worse results - /// - /// \tparam T Template type - /// \param x First value - /// \param y Second value - /// \return Sum of x and y - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto addX2(const T &x, - const T &y) noexcept - -> Fmp { - const T sum0 = x + y; - const T yMod = sum0 - x; - const T xMod = sum0 - yMod; - const T yErr = y - yMod; - const T xErr = x - xMod; - return {sum0, xErr + yErr}; - } - - /// \brief Combines two 1x precision values into a 2x precision result with the - /// requirement of specific exponent relationship - /// - /// Requires: exponent(x) + countr_zero(significand(x)) >= exponent(y) or x == 0 - /// - /// The result is exact when: - /// 1. The requirement above is satisfied - /// 2. No internal overflow occurs - /// 3. Either underflow is gradual, or no internal underflow occurs - /// 4. Intermediate precision is either the same as T, or greater than twice the - /// precision of T - /// 5. 
Parameters and local variables do not retain extra intermediate precision - /// 6. Rounding mode is rounding to nearest - /// - /// Violation of condition 3 or 5 could lead to relative error on the order of - /// epsilon^2. - /// - /// Violation of other conditions could lead to worse results - /// - /// \tparam T Template type - /// \param x First value - /// \param y Second value - /// \return Sum of x and y - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto addSmallX2(const T x, - const T y) noexcept - -> Fmp { - const T sum0 = x + y; - const T yMod = sum0 - x; - const T yErr = y - yMod; - return {sum0, yErr}; - } - - /// \brief Combines a 1x precision value with a 2x precision value - /// - /// Requires: exponent(x) + countr_zero(significand(x)) >= exponent(y.val0) or x == 0 - /// - /// \tparam T Template type - /// \param x First value - /// \param y Second value - /// \return Sum of x and y - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto - addSmallX2(const T &x, const Fmp &y) noexcept -> Fmp { - const Fmp sum0 = addSmallX2(x, y.val0); - return addSmallX2(sum0.val0, sum0.val1 + y.val1); - } - - /// \brief Combines two 2x precision values into a 1x precision result - /// \tparam T Template type - /// \param x First value - /// \param y Second value - /// \return Sum of x and y - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto addX1(const Fmp &x, - const Fmp &y) noexcept - -> T { - const Fmp sum0 = addX2(x.val0, y.val0); - return sum0.val0 + (sum0.val1 + (x.val1 + y.val1)); - } - - /// \brief Rounds a 2x precision value to 26 significant bits - /// \param x Value to round - /// \return Rounded value - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto - highHalf(const double x) noexcept -> double { - const auto bits = bitCast(x); - const auto highHalfBits = (bits + 0x3ff'ffffULL) & 0xffff'ffff'f800'0000ULL; - return bitCast(highHalfBits); - } + namespace detail { + // Implements floating-point 
arithmetic for numeric algorithms + namespace multiprec { + template + struct Fmp { + Scalar val0; // Most significant numeric_limits::precision bits + Scalar val1; // Least significant numeric_limits::precision bits + }; + + /// \brief Summarizes two 1x precision values combined into a 2x precision result + /// + /// This function is exact when: + /// 1. The result doesn't overflow + /// 2. Either underflow is gradual, or no internal underflow occurs + /// 3. Intermediate precision is either the same as T, or greater than twice the + /// precision of T + /// 4. Parameters and local variables do not retain extra intermediate precision + /// 5. Rounding mode is rounding to nearest. + /// + /// Violation of condition 3 or 5 could lead to relative error on the order of + /// epsilon^2. + /// + /// Violation of other conditions could lead to worse results + /// + /// \tparam T Template type + /// \param x First value + /// \param y Second value + /// \return Sum of x and y + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto addX2(const T &x, + const T &y) noexcept + -> Fmp { + const T sum0 = x + y; + const T yMod = sum0 - x; + const T xMod = sum0 - yMod; + const T yErr = y - yMod; + const T xErr = x - xMod; + return {sum0, xErr + yErr}; + } + + /// \brief Combines two 1x precision values into a 2x precision result with the + /// requirement of specific exponent relationship + /// + /// Requires: exponent(x) + countr_zero(significand(x)) >= exponent(y) or x == 0 + /// + /// The result is exact when: + /// 1. The requirement above is satisfied + /// 2. No internal overflow occurs + /// 3. Either underflow is gradual, or no internal underflow occurs + /// 4. Intermediate precision is either the same as T, or greater than twice the + /// precision of T + /// 5. Parameters and local variables do not retain extra intermediate precision + /// 6. 
Rounding mode is rounding to nearest + /// + /// Violation of condition 3 or 5 could lead to relative error on the order of + /// epsilon^2. + /// + /// Violation of other conditions could lead to worse results + /// + /// \tparam T Template type + /// \param x First value + /// \param y Second value + /// \return Sum of x and y + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto addSmallX2(const T x, + const T y) noexcept + -> Fmp { + const T sum0 = x + y; + const T yMod = sum0 - x; + const T yErr = y - yMod; + return {sum0, yErr}; + } + + /// \brief Combines a 1x precision value with a 2x precision value + /// + /// Requires: exponent(x) + countr_zero(significand(x)) >= exponent(y.val0) or x == 0 + /// + /// \tparam T Template type + /// \param x First value + /// \param y Second value + /// \return Sum of x and y + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto + addSmallX2(const T &x, const Fmp &y) noexcept -> Fmp { + const Fmp sum0 = addSmallX2(x, y.val0); + return addSmallX2(sum0.val0, sum0.val1 + y.val1); + } + + /// \brief Combines two 2x precision values into a 1x precision result + /// \tparam T Template type + /// \param x First value + /// \param y Second value + /// \return Sum of x and y + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto addX1(const Fmp &x, + const Fmp &y) noexcept + -> T { + const Fmp sum0 = addX2(x.val0, y.val0); + return sum0.val0 + (sum0.val1 + (x.val1 + y.val1)); + } + + /// \brief Rounds a 2x precision value to 26 significant bits + /// \param x Value to round + /// \return Rounded value + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto + highHalf(const double x) noexcept -> double { + const auto bits = bitCast(x); + const auto highHalfBits = (bits + 0x3ff'ffffULL) & 0xffff'ffff'f800'0000ULL; + return bitCast(highHalfBits); + } #if defined(USE_X86_X64_INTRINSICS) || defined(USE_ARM64_INTRINSICS) // SIMD method - /// \brief Calculates the error between x^2 and its 
faithfully rounded product prod0 - /// - /// The result is exact when: - /// 1. prod0 is x^2 faithfully rounded - /// 2. No internal overflow or underflow occurs - /// - /// Violation of condition 1 could lead to relative error on the order of epsilon. - /// - /// \param x Input value - /// \param prod0 Faithfully rounded product of x^2 - /// \return Error between x^2 and prod0 - LIBRAPID_NODISCARD - LIBRAPID_ALWAYS_INLINE auto sqrError(const double x, const double prod0) noexcept - -> double { -# if defined(USE_X86_X64_INTRINSICS) - const __m128d xVec = _mm_set_sd(x); - const __m128d prodVec = _mm_set_sd(prod0); - const __m128d resultVec = _mm_fmsub_sd(xVec, xVec, prodVec); - double result; - _mm_store_sd(&result, resultVec); - return result; -# else // Only two options, so this is fine - const float64x1_t xVec = vld1_double(&x); - const float64x1_t prod0Vec = vld1_double(&prod0); - const float64x1_t resultVec = vfma_double(vneg_double(prod0Vec), xVec, xVec); - double result; - vst1_double(&result, resultVec); - return result; -# endif - } + /// \brief Calculates the error between x^2 and its faithfully rounded product prod0 + /// + /// The result is exact when: + /// 1. prod0 is x^2 faithfully rounded + /// 2. No internal overflow or underflow occurs + /// + /// Violation of condition 1 could lead to relative error on the order of epsilon. 
+ /// + /// \param x Input value + /// \param prod0 Faithfully rounded product of x^2 + /// \return Error between x^2 and prod0 + LIBRAPID_NODISCARD + LIBRAPID_ALWAYS_INLINE auto sqrError(const double x, const double prod0) noexcept + -> double { +# if defined(USE_X86_X64_INTRINSICS) + const __m128d xVec = _mm_set_sd(x); + const __m128d prodVec = _mm_set_sd(prod0); + const __m128d resultVec = _mm_fmsub_sd(xVec, xVec, prodVec); + double result; + _mm_store_sd(&result, resultVec); + return result; +# else // Only two options, so this is fine + const float64x1_t xVec = vld1_double(&x); + const float64x1_t prod0Vec = vld1_double(&prod0); + const float64x1_t resultVec = vfma_double(vneg_double(prod0Vec), xVec, xVec); + double result; + vst1_double(&result, resultVec); + return result; +# endif + } #else - /// \brief Fallback method for sqrError(const double, const double) when SIMD is not - /// available. - LIBRAPID_NODISCARD - LIBRAPID_ALWAYS_INLINE constexpr double sqrError(const double x, - const double prod0) noexcept { - const double xHigh = highHalf(x); - const double xLow = x - xHigh; - return ((xHigh * xHigh - prod0) + 2.0 * xHigh * xLow) + xLow * xLow; - } + /// \brief Fallback method for sqrError(const double, const double) when SIMD is not + /// available. 
+ LIBRAPID_NODISCARD + LIBRAPID_ALWAYS_INLINE constexpr double sqrError(const double x, + const double prod0) noexcept { + const double xHigh = highHalf(x); + const double xLow = x - xHigh; + return ((xHigh * xHigh - prod0) + 2.0 * xHigh * xLow) + xLow * xLow; + } #endif - /// \brief Type-agnostic version of sqrError(const double, const double) - /// \tparam T Template type - /// \param x Input value - /// \param prod0 Faithfully rounded product of x^2 - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrError(const T x, - const T prod0) noexcept -> T { - const T xHigh = static_cast(highHalf(x)); - const T xLow = x - xHigh; - return ((xHigh * xHigh - prod0) + static_cast(2.0) * xHigh * xLow) + xLow * xLow; - } - - /// \brief Calculates the square of a 1x precision value and returns a 2x precision - /// result - /// - /// The result is exact when no internal overflow or underflow occurs. - /// - /// \param x Input value - /// \return 2x precision square of x - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrX2(const double x) noexcept - -> Fmp { - const double prod0 = x * x; - return {prod0, sqrError(x, prod0)}; - } - - /// \brief Type-agnostic version of sqrX2(const double) - /// \tparam T Template type - /// \param x Input value - /// \return 2x precision square of x - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrX2(const T x) noexcept -> Fmp { - const T prod0 = x * x; - return {prod0, static_cast(sqrError(x, prod0))}; - } - } // namespace multiprec - - namespace algorithm { - // HypotLegHuge = T{0.5} * sqrt((numeric_limits::max())); - // HypotLegTiny = sqrt(T{2.0} * (numeric_limits::min)() / - // numeric_limits::epsilon()); - - template - struct HypotLegHugeHelper { - // If is an integer type, divide by two rather than multiplying by 0.5, as - // 0.5 gets truncated to zero - static inline T val = - (std::is_integral_v) - ? 
(::librapid::sqrt(typetraits::TypeInfo::max()) / T(2)) - : (T(0.5) * ::librapid::sqrt(typetraits::TypeInfo::max())); - }; - - template<> - struct HypotLegHugeHelper { - static constexpr double val = 6.703903964971298e+153; - }; - - template<> - struct HypotLegHugeHelper { - static constexpr double val = 9.2233715e+18f; - }; - - template - struct HypotLegTinyHelper { - // If is an integer type, divide by two rather than multiplying by 0.5, as - // 0.5 gets truncated to zero - static inline T val = ::librapid::sqrt(T(2) * typetraits::TypeInfo::min() / - typetraits::TypeInfo::epsilon()); - }; - - template<> - struct HypotLegTinyHelper { - static constexpr double val = 1.4156865331029228e-146; - }; - - template<> - struct HypotLegTinyHelper { - static constexpr double val = 4.440892e-16f; - }; - - template - static inline T HypotLegHuge = HypotLegHugeHelper::val; - template - static inline T HypotLegTiny = HypotLegTinyHelper::val; - - /// \brief Calculates \f$ x^2 + y^2 - 1 \f$ for - /// \f$ |x| \geq |y| \f$ and \f$ 0.5 \leq |x| < 2^{12} \f$ - /// \tparam T Template type \param x First value \param y Second value - /// \return x * x + y * y - 1 - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto normMinusOne(const T x, - const T y) noexcept -> T { - const multiprec::Fmp xSqr = multiprec::sqrX2(x); - const multiprec::Fmp ySqr = multiprec::sqrX2(y); - const multiprec::Fmp xSqrM1 = multiprec::addSmallX2(T(-1), xSqr); - return multiprec::addX1(xSqrM1, ySqr); - } - - /// \brief Calculates \f$ \log(1 + x) \f$ - /// - /// May be inaccurate for small inputs - /// - /// \tparam safe If true, will check for NaNs and overflow - /// \tparam T Template type - /// \param x Input value - /// \return \f$ \log(1 + x) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto logP1(const T x) -> T { - if constexpr (!safe) return ::librapid::log(x + 1.0); + /// \brief Type-agnostic version of sqrError(const double, const double) + /// \tparam T Template type + /// \param x 
Input value + /// \param prod0 Faithfully rounded product of x^2 + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrError(const T x, + const T prod0) noexcept -> T { + const T xHigh = static_cast(highHalf(x)); + const T xLow = x - xHigh; + return ((xHigh * xHigh - prod0) + static_cast(2.0) * xHigh * xLow) + xLow * xLow; + } + + /// \brief Calculates the square of a 1x precision value and returns a 2x precision + /// result + /// + /// The result is exact when no internal overflow or underflow occurs. + /// + /// \param x Input value + /// \return 2x precision square of x + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrX2(const double x) noexcept + -> Fmp { + const double prod0 = x * x; + return {prod0, sqrError(x, prod0)}; + } + + /// \brief Type-agnostic version of sqrX2(const double) + /// \tparam T Template type + /// \param x Input value + /// \return 2x precision square of x + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrX2(const T x) noexcept -> Fmp { + const T prod0 = x * x; + return {prod0, static_cast(sqrError(x, prod0))}; + } + } // namespace multiprec + + namespace algorithm { + // HypotLegHuge = T{0.5} * sqrt((numeric_limits::max())); + // HypotLegTiny = sqrt(T{2.0} * (numeric_limits::min)() / + // numeric_limits::epsilon()); + + template + struct HypotLegHugeHelper { + // If is an integer type, divide by two rather than multiplying by 0.5, as + // 0.5 gets truncated to zero + static inline T val = + (std::is_integral_v) + ? 
(::librapid::sqrt(typetraits::TypeInfo::max()) / T(2)) + : (T(0.5) * ::librapid::sqrt(typetraits::TypeInfo::max())); + }; + + template<> + struct HypotLegHugeHelper { + static constexpr double val = 6.703903964971298e+153; + }; + + template<> + struct HypotLegHugeHelper { + static constexpr double val = 9.2233715e+18f; + }; + + template + struct HypotLegTinyHelper { + // If is an integer type, divide by two rather than multiplying by 0.5, as + // 0.5 gets truncated to zero + static inline T val = ::librapid::sqrt(T(2) * typetraits::TypeInfo::min() / + typetraits::TypeInfo::epsilon()); + }; + + template<> + struct HypotLegTinyHelper { + static constexpr double val = 1.4156865331029228e-146; + }; + + template<> + struct HypotLegTinyHelper { + static constexpr double val = 4.440892e-16f; + }; + + template + static inline T HypotLegHuge = HypotLegHugeHelper::val; + template + static inline T HypotLegTiny = HypotLegTinyHelper::val; + + /// \brief Calculates \f$ x^2 + y^2 - 1 \f$ for + /// \f$ |x| \geq |y| \f$ and \f$ 0.5 \leq |x| < 2^{12} \f$ + /// \tparam T Template type \param x First value \param y Second value + /// \return x * x + y * y - 1 + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto normMinusOne(const T x, + const T y) noexcept -> T { + const multiprec::Fmp xSqr = multiprec::sqrX2(x); + const multiprec::Fmp ySqr = multiprec::sqrX2(y); + const multiprec::Fmp xSqrM1 = multiprec::addSmallX2(T(-1), xSqr); + return multiprec::addX1(xSqrM1, ySqr); + } + + /// \brief Calculates \f$ \log(1 + x) \f$ + /// + /// May be inaccurate for small inputs + /// + /// \tparam safe If true, will check for NaNs and overflow + /// \tparam T Template type + /// \param x Input value + /// \return \f$ \log(1 + x) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto logP1(const T x) -> T { + if constexpr (!safe) return ::librapid::log(x + 1.0); #if defined(LIBRAPID_USE_MULTIPREC) - // No point doing anything shown below if we're using multiprec - if constexpr 
(std::is_same_v) return ::librapid::log(x + 1.0); + // No point doing anything shown below if we're using multiprec + if constexpr (std::is_same_v) return ::librapid::log(x + 1.0); #endif - if (::librapid::isNaN(x)) return x + x; // Trigger a signaling NaN - - // Naive formula - if (x <= T(-0.5) || T(2) <= x) { - // To avoid overflow - if (x == typetraits::TypeInfo::max()) return ::librapid::log(x); - return ::librapid::log(T(1) + x); - } - - const T absX = ::librapid::abs(x); - if (absX < typetraits::TypeInfo::epsilon()) { - if (x == T(0)) return x; - return x - T(0.5) * x * x; // Honour rounding - } - - // log(1 + x) with fix for small x - const multiprec::Fmp tmp = multiprec::addSmallX2(T(1), x); - return ::librapid::log(tmp.val0) + tmp.val1 / tmp.val0; - } - - // Return log(hypot(x, y)) - - /// \brief Calculates \f$ \log(\sqrt{x^2 + y^2}) \f$ - /// \tparam safe If true, will check for NaNs and overflow - /// \tparam T Template type - /// \param x Horizontal component - /// \param y Vertical component - /// \return \f$ \log(\sqrt{x^2 + y^2}) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto logHypot(const T x, const T y) noexcept - -> T { - if constexpr (!safe) return ::librapid::log(::librapid::sqrt(x * x + y * y)); + if (::librapid::isNaN(x)) return x + x; // Trigger a signaling NaN + + // Naive formula + if (x <= T(-0.5) || T(2) <= x) { + // To avoid overflow + if (x == typetraits::TypeInfo::max()) return ::librapid::log(x); + return ::librapid::log(T(1) + x); + } + + const T absX = ::librapid::abs(x); + if (absX < typetraits::TypeInfo::epsilon()) { + if (x == T(0)) return x; + return x - T(0.5) * x * x; // Honour rounding + } + + // log(1 + x) with fix for small x + const multiprec::Fmp tmp = multiprec::addSmallX2(T(1), x); + return ::librapid::log(tmp.val0) + tmp.val1 / tmp.val0; + } + + // Return log(hypot(x, y)) + + /// \brief Calculates \f$ \log(\sqrt{x^2 + y^2}) \f$ + /// \tparam safe If true, will check for NaNs and overflow + /// \tparam 
T Template type + /// \param x Horizontal component + /// \param y Vertical component + /// \return \f$ \log(\sqrt{x^2 + y^2}) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto logHypot(const T x, const T y) noexcept + -> T { + if constexpr (!safe) return ::librapid::log(::librapid::sqrt(x * x + y * y)); #if defined(LIBRAPID_USE_MULTIPREC) - // No point doing anything shown below if we're using multiprec - if constexpr (std::is_same_v) - return ::librapid::log(::mpfr::hypot(x, y)); - else { + // No point doing anything shown below if we're using multiprec + if constexpr (std::is_same_v) + return ::librapid::log(::mpfr::hypot(x, y)); + else { #endif - if (!::librapid::isFinite(x) || !::librapid::isFinite(y)) { // Inf or NaN - // Return NaN and raise FE_INVALID if either x or y is NaN - if (::librapid::isNaN(x) || ::librapid::isNaN(y)) return x + y; - - // Return Inf if either of them is infinity - if (::librapid::isInf(x)) return x; - if (::librapid::isInf(y)) return y; - - return x + y; // Fallback - } - - T absX = ::librapid::abs(x); - T absY = ::librapid::abs(y); - - if (absX < absY) std::swap(absX, absY); // Ensure absX > absY - if (absY == 0) return ::librapid::log(absX); // One side has zero length - - // Avoid overflow and underflow - if (HypotLegTiny < absX && absX < HypotLegHuge) { - constexpr auto normSmall = T(0.5); - constexpr auto normBig = T(3.0); - - const T absYSqr = absY * absY; - - if (absX == T(1)) return logP1(absYSqr) * T(0.5); - - const T norm = absX * absX + absYSqr; - if (normSmall < norm && norm < normBig) // Avoid cancellation - return logP1(normMinusOne(absX, absY)) * T(0.5); - return ::librapid::log(norm) * T(0.5); - } else { // Use 1 1/2 precision to preserve bits - constexpr T cm = T(22713.0L / 32768.0L); // Not sure where this came from - constexpr T cl = T(1.4286068203094172321214581765680755e-6L); // Or this... 
- - const int exp = std::ilogb(absX); - const T absXScaled = std::scalbn(absX, -exp); - const T absYScaled = std::scalbn(absY, -exp); - const T absYScaledSqr = absYScaled * absYScaled; - const T normScaled = absXScaled * absXScaled + absYScaledSqr; - const T realShifted = ::librapid::log(normScaled) * T(0.5); - const auto fExp = static_cast(exp); - return (realShifted + fExp * cl) + fExp * cm; - } + if (!::librapid::isFinite(x) || !::librapid::isFinite(y)) { // Inf or NaN + // Return NaN and raise FE_INVALID if either x or y is NaN + if (::librapid::isNaN(x) || ::librapid::isNaN(y)) return x + y; + + // Return Inf if either of them is infinity + if (::librapid::isInf(x)) return x; + if (::librapid::isInf(y)) return y; + + return x + y; // Fallback + } + + T absX = ::librapid::abs(x); + T absY = ::librapid::abs(y); + + if (absX < absY) std::swap(absX, absY); // Ensure absX > absY + if (absY == 0) return ::librapid::log(absX); // One side has zero length + + // Avoid overflow and underflow + if (HypotLegTiny < absX && absX < HypotLegHuge) { + constexpr auto normSmall = T(0.5); + constexpr auto normBig = T(3.0); + + const T absYSqr = absY * absY; + + if (absX == T(1)) return logP1(absYSqr) * T(0.5); + + const T norm = absX * absX + absYSqr; + if (normSmall < norm && norm < normBig) // Avoid cancellation + return logP1(normMinusOne(absX, absY)) * T(0.5); + return ::librapid::log(norm) * T(0.5); + } else { // Use 1 1/2 precision to preserve bits + constexpr T cm = T(22713.0L / 32768.0L); // Not sure where this came from + constexpr T cl = T(1.4286068203094172321214581765680755e-6L); // Or this... 
+ + const int exp = std::ilogb(absX); + const T absXScaled = std::scalbn(absX, -exp); + const T absYScaled = std::scalbn(absY, -exp); + const T absYScaledSqr = absYScaled * absYScaled; + const T normScaled = absXScaled * absXScaled + absYScaledSqr; + const T realShifted = ::librapid::log(normScaled) * T(0.5); + const auto fExp = static_cast(exp); + return (realShifted + fExp * cl) + fExp * cm; + } #if defined(LIBRAPID_USE_MULTIPREC) - } // This ensures the "if constexpr" above actually stops compiler errors + } // This ensures the "if constexpr" above actually stops compiler errors #endif - } - - /// \brief Compute \f$e^{\text{pleft}} \times \text{right} \times 2^{\text{exponent}}\f$ - /// - /// \tparam T Template type - /// \param pleft Pointer to the value to be exponentiated - /// \param right Multiplier for the exponentiated value - /// \param exponent Exponent for the power of 2 multiplication - /// \return 1 if the result is NaN or Inf, -1 otherwise - template - auto expMul(T *pleft, T right, short exponent) -> short { + } + + /// \brief Compute \f$e^{\text{pleft}} \times \text{right} \times 2^{\text{exponent}}\f$ + /// + /// \tparam T Template type + /// \param pleft Pointer to the value to be exponentiated + /// \param right Multiplier for the exponentiated value + /// \param exponent Exponent for the power of 2 multiplication + /// \return 1 if the result is NaN or Inf, -1 otherwise + template + auto expMul(T *pleft, T right, short exponent) -> short { #if defined(LIBRAPID_USE_MULTIPREC) - if constexpr (std::is_same_v) { - *pleft = ::mpfr::exp(*pleft) * right * ::mpfr::exp2(exponent); - return (::librapid::isNaN(*pleft) || ::librapid::isInf(*pleft)) ? 1 : -1; - } else { + if constexpr (std::is_same_v) { + *pleft = ::mpfr::exp(*pleft) * right * ::mpfr::exp2(exponent); + return (::librapid::isNaN(*pleft) || ::librapid::isInf(*pleft)) ? 
1 : -1; + } else { #endif #if defined(LIBRAPID_MSVC) - auto tmp = static_cast(*pleft); - short ans = _CSTD _Exp(&tmp, static_cast(right), exponent); - *pleft = static_cast(tmp); - return ans; + auto tmp = static_cast(*pleft); + short ans = _CSTD _Exp(&tmp, static_cast(right), exponent); + *pleft = static_cast(tmp); + return ans; #else - *pleft = ::librapid::exp(*pleft) * right * ::librapid::exp2(exponent); - return (::librapid::isNaN(*pleft) || ::librapid::isInf(*pleft)) ? 1 : -1; + *pleft = ::librapid::exp(*pleft) * right * ::librapid::exp2(exponent); + return (::librapid::isNaN(*pleft) || ::librapid::isInf(*pleft)) ? 1 : -1; #endif #if defined(LIBRAPID_USE_MULTIPREC) - } // This ensures the "if constexpr" above actually stops compiler errors + } // This ensures the "if constexpr" above actually stops compiler errors #endif - } - } // namespace algorithm - } // namespace detail - - /// \brief A class representing a complex number of the form \f$a + bi\f$, where \f$a\f$ and - /// \f$b\f$ are real numbers - /// - /// This class represents a complex number of the form \f$a + bi\f$, where \f$a\f$ and - /// \f$b\f$ are real numbers. The class is templated, allowing the user to specify the type - /// of the real and imaginary components. The default type is ``double``. - /// - /// \tparam T The type of the real and imaginary components - template - class Complex { - public: - using Scalar = typename typetraits::TypeInfo::Scalar; - - /// \brief Default constructor - /// - /// Create a new complex number. Both the real and imaginary components are set to zero - Complex() : m_val {T(0), T(0)} {} - - /// \brief Construct a complex number from a real number - /// - /// Create a complex number, setting only the real component. 
The imaginary component is - /// initialized to zero - /// - /// \tparam R The type of the real component - /// \param realVal The real component - template - explicit Complex(const R &realVal) : m_val {T(realVal), T(0)} {} - - /// \brief Construct a complex number from real and imaginary components - /// - /// Create a new complex number where both the real and imaginary parts are set from the - /// passed parameters - /// - /// \tparam R The type of the real component - /// \tparam I The type of the imaginary component - /// \param realVal The real component - /// \param imagVal The imaginary component - template - Complex(const R &realVal, const I &imagVal) : m_val {T(realVal), T(imagVal)} {} - - /// \brief Complex number copy constructor - /// \param other The complex number to copy - Complex(const Complex &other) : m_val {other.real(), other.imag()} {} - - /// \brief Complex number move constructor - /// \param other The complex number to move - Complex(Complex &&other) noexcept : m_val {other.real(), other.imag()} {} - - /// \brief Construct a complex number from another complex number with a different type - /// \tparam Other Type of the components of the other complex number - /// \param other The complex number to copy - template - Complex(const Complex &other) : m_val {T(other.real()), T(other.imag())} {} - - /// \brief Construct a complex number from a std::complex - /// \param other The std::complex value to copy - explicit Complex(const std::complex &other) : m_val {other.real(), other.imag()} {} - - static constexpr auto size() -> size_t { - return typetraits::TypeInfo::packetWidth; - } - - /// \brief Complex number assignment operator - /// \param other The value to assign - /// \return *this - auto operator=(const Complex &other) -> Complex & { - if (this == &other) return *this; - m_val[RE] = other.real(); - m_val[IM] = other.imag(); - return *this; - } - - // template - // LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { - // auto casted = 
reinterpret_cast(ptr); - // auto ret = Vc::interleave(m_val[RE], m_val[IM]); - // ret.first.store(casted); - // ret.second.store(casted + size()); - // } - - // template - // LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { - // auto casted = reinterpret_cast(ptr); - // Vc::deinterleave(&m_val[RE], &m_val[IM], casted, Vc::Aligned); - // } - - /// \brief Assign to the real component - /// - /// Set the real component of this complex number to \p val - /// - /// \param val The value to assign - LIBRAPID_ALWAYS_INLINE void real(const T &val) { m_val[RE] = val; } - - /// \brief Assign to the imaginary component - /// - /// Set the imaginary component of this complex number to \p val - /// - /// \param val The value to assign - LIBRAPID_ALWAYS_INLINE void imag(const T &val) { m_val[IM] = val; } - - /// \brief Access the real component - /// - /// Returns a const reference to the real component of this complex number - /// - /// \return Real component - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto real() const -> const T & { - return m_val[RE]; - } - - /// \brief Access the imaginary component - /// - /// Returns a const reference to the imaginary component of this complex number - /// - /// \return Imaginary component - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto imag() const -> const T & { - return m_val[IM]; - } - - /// \brief Access the real component - /// - /// Returns a reference to the real component of this complex number. Since this is a - /// reference type, it can be assigned to - /// - /// \return Real component - LIBRAPID_ALWAYS_INLINE auto real() -> T & { return m_val[RE]; } - - /// \brief Access the imaginary component - /// - /// Returns a reference to the imaginary component of this complex number. 
Since this is a - /// reference type, it can be assigned to - /// - /// \return imaginary component - LIBRAPID_ALWAYS_INLINE auto imag() -> T & { return m_val[IM]; } - - /// \brief Complex number assigment operator - /// - /// Set the real component of this complex number to \p other, and the imaginary component - /// to 0 - /// - /// \param other - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator=(const T &other) -> Complex & { - m_val[RE] = other; - m_val[IM] = 0; - return *this; - } - - /// \brief Complex number assigment operator - /// - /// Assign another complex number to this one, copying the real and imaginary components - /// - /// \tparam Other The type of the other complex number - /// \param other Complex number to assign - /// \return *this - template - LIBRAPID_ALWAYS_INLINE auto operator=(const Complex &other) -> Complex & { - m_val[RE] = static_cast(other.real()); - m_val[IM] = static_cast(other.real()); - return *this; - } - - /// \brief Inplace addition - /// - /// Add a scalar value to the real component of this imaginary number - /// - /// \param other Scalar value to add - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator+=(const T &other) -> Complex & { - m_val[RE] = m_val[RE] + other; - return *this; - } - - /// \brief Inplace subtraction - /// - /// Subtract a scalar value from the real component of this imaginary number - /// - /// \param other Scalar value to subtract - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator-=(const T &other) -> Complex & { - m_val[RE] = m_val[RE] - other; - return *this; - } - - /// \brief Inplace multiplication - /// - /// Multiply both the real and imaginary components of this complex number by a scalar - /// - /// \param other Scalar value to multiply by - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator*=(const T &other) -> Complex & { - m_val[RE] = m_val[RE] * other; - m_val[IM] = m_val[IM] * other; - return *this; - } - - /// \brief Inplace division - /// - /// Divide both 
the real and imaginary components of this complex number by a scalar - /// - /// \param other Scalar value to divide by - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator/=(const T &other) -> Complex & { - m_val[RE] = m_val[RE] / other; - m_val[IM] = m_val[IM] / other; - return *this; - } - - /// \brief Inplace addition - /// - /// Add a complex number to this one - /// - /// \param other Complex number to add - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator+=(const Complex &other) -> Complex & { - this->_add(other); - return *this; - } - - /// \brief Inplace subtraction - /// - /// Subtract a complex number from this one - /// - /// \param other Complex number to subtract - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator-=(const Complex &other) -> Complex & { - this->_sub(other); - return *this; - } - - /// \brief Inplace multiplication - /// - /// Multiply this complex number by another one - /// - /// \param other Complex number to multiply by - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator*=(const Complex &other) -> Complex & { - this->_mul(other); - return *this; - } - - /// \brief Inplace division - /// - /// Divide this complex number by another one - /// - /// \param other Complex number to divide by - /// \return *this - LIBRAPID_ALWAYS_INLINE auto operator/=(const Complex &other) -> Complex & { - this->_div(other); - return *this; - } - - /// \brief Cast to scalar types - /// - /// Cast this complex number to a scalar type. This will extract only the real component. 
- /// - /// \tparam To Type to cast to - /// \return Scalar - template - LIBRAPID_ALWAYS_INLINE explicit operator To() const { - return static_cast(m_val[RE]); - } - - /// \brief Cast to a complex number with a different scalar type - /// - /// Cast the real and imaginary components of this complex number to a different type and - /// return the result as a new complex number - /// - /// \tparam To Scalar type to cast to - /// \return Complex number - template - LIBRAPID_ALWAYS_INLINE explicit operator Complex() const { - return Complex(static_cast(m_val[RE]), static_cast(m_val[IM])); - } - - /// \brief Complex number to string - /// - /// Create a std::string representation of a complex number, formatting each component with - /// the format string - /// - /// \param format Format string - /// \return std::string - LIBRAPID_NODISCARD auto str(const std::string &format = "{}") const -> std::string { - if (!::librapid::signBit(m_val[IM])) - return "(" + fmt::format(format, m_val[RE]) + "+" + fmt::format(format, m_val[IM]) + - "j)"; - else - return "(" + fmt::format(format, m_val[RE]) + "-" + - fmt::format(format, -m_val[IM]) + "j)"; - } - - protected: - /// \brief Add a complex number to this one - /// \tparam Other Scalar type of the other complex number - /// \param other Other complex number - template - LIBRAPID_ALWAYS_INLINE void _add(const Complex &other) { - m_val[RE] = m_val[RE] + other.real(); - m_val[IM] = m_val[IM] + other.imag(); - } - - /// \brief Subtract a complex number from this one - /// \tparam Other Scalar type of the other complex number - /// \param other Other complex number - template - LIBRAPID_ALWAYS_INLINE void _sub(const Complex &other) { - m_val[RE] = m_val[RE] - other.real(); - m_val[IM] = m_val[IM] - other.imag(); - } - - /// \brief Multiply this complex number by another one - /// \tparam Other Scalar type of the other complex number - /// \param other Other complex number - template - LIBRAPID_ALWAYS_INLINE void _mul(const Complex 
&other) { - T otherReal = static_cast(other.real()); - T otherImag = static_cast(other.imag()); - - T tmp = m_val[RE] * otherReal - m_val[IM] * otherImag; - m_val[IM] = m_val[RE] * otherImag + m_val[IM] * otherReal; - m_val[RE] = tmp; - } - - /// \brief Divide this complex number by another one - /// \tparam Other Scalar type of the other complex number - /// \param other Other complex number - template - LIBRAPID_ALWAYS_INLINE void _div(const Complex &other) { - T otherReal = static_cast(other.real()); - T otherImag = static_cast(other.imag()); - - if (::librapid::isNaN(otherReal) || ::librapid::isNaN(otherImag)) { // Set result to NaN - m_val[RE] = typetraits::TypeInfo::quietNaN(); - m_val[IM] = m_val[RE]; - } else if ((otherImag < 0 ? T(-otherImag) - : T(+otherImag)) < // |other.imag()| < |other.real()| - (otherReal < 0 ? T(-otherReal) : T(+otherReal))) { - T wr = otherImag / otherReal; - T wd = otherReal + wr * otherImag; - - if (::librapid::isNaN(wd) || wd == 0) { // NaN result - m_val[RE] = typetraits::TypeInfo::quietNaN(); - m_val[IM] = m_val[RE]; - } else { // Valid result - T tmp = (m_val[RE] + m_val[IM] * wr) / wd; - m_val[IM] = (m_val[IM] - m_val[RE] * wr) / wd; - m_val[RE] = tmp; - } - } else if (otherImag == 0) { // Set NaN - m_val[RE] = typetraits::TypeInfo::quietNaN(); - m_val[IM] = m_val[RE]; - } else { // 0 < |other.real()| <= |other.imag()| - T wr = otherReal / otherImag; - T wd = otherImag + wr * otherReal; - - if (::librapid::isNaN(wd) || wd == 0) { // NaN result - m_val[RE] = typetraits::TypeInfo::quietNaN(); - m_val[IM] = m_val[RE]; - } else { - T tmp = (m_val[RE] * wr + m_val[IM]) / wd; - m_val[IM] = (m_val[IM] * wr - m_val[RE]) / wd; - m_val[RE] = tmp; - } - } - } - - private: - T m_val[2]; - static constexpr size_t RE = 0; - static constexpr size_t IM = 1; - }; - - /// \brief Negate a complex number - /// \tparam T Scalar type of the complex number - /// \param other Complex number to negate - /// \return Negated complex number - template - 
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const Complex &other) - -> Complex { - return {-other.real(), -other.imag()}; - } - - /// \brief Add two complex numbers - /// - /// Add two complex numbers together, returning the result - /// - /// \tparam L Scalar type of LHS - /// \tparam R Scalar type of RHS - /// \param left LHS complex number - /// \param right RHS complex number - /// \return Sum of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+(const Complex &left, - const Complex &right) { - using Scalar = typename std::common_type_t; - Complex tmp(left.real(), left.imag()); - tmp += Complex(right.real(), right.imag()); - return tmp; - } - - /// \brief Add a complex number and a scalar - /// - /// Add a real number to the real component of a complex number, returning the result - /// - /// \tparam T Scalar type of the complex number - /// \tparam R Type of the real number - /// \param left LHS complex number - /// \param right RHS scalar - /// \return Sum of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+(const Complex &left, - const R &right) { - Complex tmp(left); - tmp.real(tmp.real() + right); - return tmp; - } - - /// \brief Add a scalar to a complex number - /// - /// Add a real number to the real component of a complex number, returning the result - /// - /// \tparam R Type of the real number - /// \tparam T Scalar type of the complex number - /// \param left LHS scalar - /// \param right RHS complex number - /// \return Sum of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+(const R &left, - const Complex &right) { - Complex tmp(left); - tmp += right; - return tmp; - } - - /// \brief Subtract a complex number from another complex number - /// - /// Subtract the real and imaginary components of the RHS complex number from the corresponding - /// components of the LHS complex number, returning the result - /// - /// \tparam L Scalar type of 
the LHS complex number - /// \tparam R Scalar type of the RHS complex number - /// \param left LHS complex number - /// \param right RHS complex number - /// \return Difference of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const Complex &left, - const Complex &right) { - using Scalar = typename std::common_type_t; - Complex tmp(left.real(), left.imag()); - tmp -= Complex(right.real(), right.imag()); - return tmp; - } - - /// \brief Subtract a scalar from a complex number - /// - /// Subtract a real number from the real component of a complex number, returning the result - /// - /// \tparam T Scalar type of the complex number - /// \tparam R Type of the real number - /// \param left LHS complex number - /// \param right RHS scalar - /// \return Difference of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const Complex &left, - const R &right) { - Complex tmp(left); - tmp.real(tmp.real() - right); - return tmp; - } - - /// \brief Subtract a complex number from a scalar - /// - /// Subtract the real and imaginary components of the RHS complex number from a real number, - /// returning the result - /// - /// \tparam T Scalar type of the complex number - /// \tparam R Type of the real number - /// \param left LHS scalar - /// \param right RHS complex number - /// \return Difference of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const R &left, - const Complex &right) { - Complex tmp(left); - tmp -= right; - return tmp; - } - - /// \brief Multiply two complex numbers - /// - /// Multiply the LHS and RHS complex numbers, returning the result - /// - /// \tparam L Scalar type of the LHS complex number - /// \tparam R Scalar type of the RHS complex number - /// \param left LHS complex number - /// \param right RHS complex number - /// \return Product of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*(const Complex &left, - const 
Complex &right) { - using Scalar = typename std::common_type_t; - Complex tmp(left.real(), left.imag()); - tmp *= Complex(right.real(), right.imag()); - return tmp; - } - - /// \brief Multiply a complex number by a scalar - /// - /// Multiply the real and imaginary components of a complex number by a real number, returning - /// the result - /// - /// \tparam T Scalar type of the complex number - /// \tparam R Type of the real number - /// \param left LHS complex number - /// \param right RHS scalar - /// \return Product of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*(const Complex &left, - const R &right) { - Complex tmp(left); - tmp.real(tmp.real() * right); - tmp.imag(tmp.imag() * right); - return tmp; - } - - /// \brief Multiply a scalar by a complex number - /// - /// Multiply a real number by the real and imaginary components of a complex number, returning - /// the result - /// - /// \tparam T Scalar type of the complex number - /// \tparam R Type of the real number - /// \param left LHS scalar - /// \param right RHS complex number - /// \return Product of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*(const R &left, - const Complex &right) { - Complex tmp(left); - tmp *= right; - return tmp; - } - - /// \brief Divide two complex numbers - /// - /// Divide the LHS complex number by the RHS complex number, returning the result - /// - /// \tparam L Scalar type of the LHS complex number - /// \tparam R Scalar type of the RHS complex number - /// \param left LHS complex number - /// \param right RHS complex number - /// \return Quotient of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/(const Complex &left, - const Complex &right) { - using Scalar = typename std::common_type_t; - Complex tmp(left.real(), left.imag()); - tmp /= Complex(right.real(), right.imag()); - return tmp; - } - - /// \brief Divide a complex number by a scalar - /// - /// Divide the 
real and imaginary components of a complex number by a real number, returning the - /// result - /// - /// \tparam T Scalar type of the complex number - /// \tparam R Type of the real number - /// \param left LHS complex number - /// \param right RHS scalar - /// \return Quotient of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/(const Complex &left, - const R &right) { - Complex tmp(left); - tmp.real(tmp.real() / right); - tmp.imag(tmp.imag() / right); - return tmp; - } - - /// \brief Divide a scalar by a complex number - /// - /// Divide a real number by the real and imaginary components of a complex number, returning the - /// result - /// - /// \tparam T Scalar type of the complex number - /// \tparam R Type of the real number - /// \param left LHS scalar - /// \param right RHS complex number - /// \return Quotient of LHS and RHS - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/(const R &left, - const Complex &right) { - Complex tmp(left); - tmp /= right; - return tmp; - } - - /// \brief Equality comparison of two complex numbers - /// \tparam L Scalar type of LHS complex number - /// \tparam R Scalar type of RHS complex number - /// \param left LHS complex number - /// \param right RHS complex number - /// \return true if equal, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator==(const Complex &left, - const Complex &right) { - return left.real() == right.real() && left.imag() == right.imag(); - } - - /// \brief Equality comparison of complex number and scalar - /// - /// Compares the real component of the complex number to the scalar, and the imaginary component - /// to zero. Returns true if and only if both comparisons are true. 
- /// - /// \tparam T Scalar type of complex number - /// \param left LHS complex number - /// \param right RHS scalar - /// \return true if equal, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator==(const Complex &left, - T &right) { - return left.real() == right && left.imag() == 0; - } + } + } // namespace algorithm + } // namespace detail + + /// \brief A class representing a complex number of the form \f$a + bi\f$, where \f$a\f$ and + /// \f$b\f$ are real numbers + /// + /// This class represents a complex number of the form \f$a + bi\f$, where \f$a\f$ and + /// \f$b\f$ are real numbers. The class is templated, allowing the user to specify the type + /// of the real and imaginary components. The default type is ``double``. + /// + /// \tparam T The type of the real and imaginary components + template + class Complex { + public: + using Scalar = typename typetraits::TypeInfo::Scalar; + + /// \brief Default constructor + /// + /// Create a new complex number. Both the real and imaginary components are set to zero + Complex() : m_val {T(0), T(0)} {} + + /// \brief Construct a complex number from a real number + /// + /// Create a complex number, setting only the real component. 
The imaginary component is + /// initialized to zero + /// + /// \tparam R The type of the real component + /// \param realVal The real component + template + explicit Complex(const R &realVal) : m_val {T(realVal), T(0)} {} + + /// \brief Construct a complex number from real and imaginary components + /// + /// Create a new complex number where both the real and imaginary parts are set from the + /// passed parameters + /// + /// \tparam R The type of the real component + /// \tparam I The type of the imaginary component + /// \param realVal The real component + /// \param imagVal The imaginary component + template + Complex(const R &realVal, const I &imagVal) : m_val {T(realVal), T(imagVal)} {} + + /// \brief Complex number copy constructor + /// \param other The complex number to copy + Complex(const Complex &other) : m_val {other.real(), other.imag()} {} + + /// \brief Complex number move constructor + /// \param other The complex number to move + Complex(Complex &&other) noexcept : m_val {other.real(), other.imag()} {} + + /// \brief Construct a complex number from another complex number with a different type + /// \tparam Other Type of the components of the other complex number + /// \param other The complex number to copy + template + Complex(const Complex &other) : m_val {T(other.real()), T(other.imag())} {} + + /// \brief Construct a complex number from a std::complex + /// \param other The std::complex value to copy + explicit Complex(const std::complex &other) : m_val {other.real(), other.imag()} {} + + static constexpr auto size() -> size_t { + return typetraits::TypeInfo::packetWidth; + } + + /// \brief Complex number assignment operator + /// \param other The value to assign + /// \return *this + auto operator=(const Complex &other) -> Complex & { + if (this == &other) return *this; + m_val[RE] = other.real(); + m_val[IM] = other.imag(); + return *this; + } + + // template + // LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { + // auto casted = 
reinterpret_cast(ptr); + // auto ret = Vc::interleave(m_val[RE], m_val[IM]); + // ret.first.store(casted); + // ret.second.store(casted + size()); + // } + + // template + // LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { + // auto casted = reinterpret_cast(ptr); + // Vc::deinterleave(&m_val[RE], &m_val[IM], casted, Vc::Aligned); + // } + + /// \brief Assign to the real component + /// + /// Set the real component of this complex number to \p val + /// + /// \param val The value to assign + LIBRAPID_ALWAYS_INLINE void real(const T &val) { m_val[RE] = val; } + + /// \brief Assign to the imaginary component + /// + /// Set the imaginary component of this complex number to \p val + /// + /// \param val The value to assign + LIBRAPID_ALWAYS_INLINE void imag(const T &val) { m_val[IM] = val; } + + /// \brief Access the real component + /// + /// Returns a const reference to the real component of this complex number + /// + /// \return Real component + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto real() const -> const T & { + return m_val[RE]; + } + + /// \brief Access the imaginary component + /// + /// Returns a const reference to the imaginary component of this complex number + /// + /// \return Imaginary component + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto imag() const -> const T & { + return m_val[IM]; + } + + /// \brief Access the real component + /// + /// Returns a reference to the real component of this complex number. Since this is a + /// reference type, it can be assigned to + /// + /// \return Real component + LIBRAPID_ALWAYS_INLINE auto real() -> T & { return m_val[RE]; } + + /// \brief Access the imaginary component + /// + /// Returns a reference to the imaginary component of this complex number. 
Since this is a + /// reference type, it can be assigned to + /// + /// \return imaginary component + LIBRAPID_ALWAYS_INLINE auto imag() -> T & { return m_val[IM]; } + + /// \brief Complex number assignment operator + /// + /// Set the real component of this complex number to \p other, and the imaginary component + /// to 0 + /// + /// \param other Scalar value to assign to the real component + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator=(const T &other) -> Complex & { + m_val[RE] = other; + m_val[IM] = 0; + return *this; + } + + /// \brief Complex number assignment operator + /// + /// Assign another complex number to this one, converting its real and imaginary components + /// + /// \tparam Other The type of the other complex number + /// \param other Complex number to assign + /// \return *this + template + LIBRAPID_ALWAYS_INLINE auto operator=(const Complex &other) -> Complex & { + m_val[RE] = static_cast(other.real()); + m_val[IM] = static_cast(other.imag()); + return *this; + } + + /// \brief Inplace addition + /// + /// Add a scalar value to the real component of this complex number + /// + /// \param other Scalar value to add + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator+=(const T &other) -> Complex & { + m_val[RE] = m_val[RE] + other; + return *this; + } + + /// \brief Inplace subtraction + /// + /// Subtract a scalar value from the real component of this complex number + /// + /// \param other Scalar value to subtract + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator-=(const T &other) -> Complex & { + m_val[RE] = m_val[RE] - other; + return *this; + } + + /// \brief Inplace multiplication + /// + /// Multiply both the real and imaginary components of this complex number by a scalar + /// + /// \param other Scalar value to multiply by + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator*=(const T &other) -> Complex & { + m_val[RE] = m_val[RE] * other; + m_val[IM] = m_val[IM] * other; + return *this; + } + + /// \brief Inplace division + /// + /// Divide both
the real and imaginary components of this complex number by a scalar + /// + /// \param other Scalar value to divide by + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator/=(const T &other) -> Complex & { + m_val[RE] = m_val[RE] / other; + m_val[IM] = m_val[IM] / other; + return *this; + } + + /// \brief Inplace addition + /// + /// Add a complex number to this one + /// + /// \param other Complex number to add + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator+=(const Complex &other) -> Complex & { + this->_add(other); + return *this; + } + + /// \brief Inplace subtraction + /// + /// Subtract a complex number from this one + /// + /// \param other Complex number to subtract + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator-=(const Complex &other) -> Complex & { + this->_sub(other); + return *this; + } + + /// \brief Inplace multiplication + /// + /// Multiply this complex number by another one + /// + /// \param other Complex number to multiply by + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator*=(const Complex &other) -> Complex & { + this->_mul(other); + return *this; + } + + /// \brief Inplace division + /// + /// Divide this complex number by another one + /// + /// \param other Complex number to divide by + /// \return *this + LIBRAPID_ALWAYS_INLINE auto operator/=(const Complex &other) -> Complex & { + this->_div(other); + return *this; + } + + /// \brief Cast to scalar types + /// + /// Cast this complex number to a scalar type. This will extract only the real component. 
+ /// + /// \tparam To Type to cast to + /// \return Scalar + template + LIBRAPID_ALWAYS_INLINE explicit operator To() const { + return static_cast(m_val[RE]); + } + + /// \brief Cast to a complex number with a different scalar type + /// + /// Cast the real and imaginary components of this complex number to a different type and + /// return the result as a new complex number + /// + /// \tparam To Scalar type to cast to + /// \return Complex number + template + LIBRAPID_ALWAYS_INLINE explicit operator Complex() const { + return Complex(static_cast(m_val[RE]), static_cast(m_val[IM])); + } + + /// \brief Complex number to string + /// + /// Create a std::string representation of a complex number, formatting each component with + /// the format string + /// + /// \param format Format string + /// \return std::string + LIBRAPID_NODISCARD auto str(const std::string &format = "{}") const -> std::string { + if (!::librapid::signBit(m_val[IM])) + return "(" + fmt::format(format, m_val[RE]) + "+" + fmt::format(format, m_val[IM]) + + "j)"; + else + return "(" + fmt::format(format, m_val[RE]) + "-" + + fmt::format(format, -m_val[IM]) + "j)"; + } + + protected: + /// \brief Add a complex number to this one + /// \tparam Other Scalar type of the other complex number + /// \param other Other complex number + template + LIBRAPID_ALWAYS_INLINE void _add(const Complex &other) { + m_val[RE] = m_val[RE] + other.real(); + m_val[IM] = m_val[IM] + other.imag(); + } + + /// \brief Subtract a complex number from this one + /// \tparam Other Scalar type of the other complex number + /// \param other Other complex number + template + LIBRAPID_ALWAYS_INLINE void _sub(const Complex &other) { + m_val[RE] = m_val[RE] - other.real(); + m_val[IM] = m_val[IM] - other.imag(); + } + + /// \brief Multiply this complex number by another one + /// \tparam Other Scalar type of the other complex number + /// \param other Other complex number + template + LIBRAPID_ALWAYS_INLINE void _mul(const Complex 
&other) { + T otherReal = static_cast(other.real()); + T otherImag = static_cast(other.imag()); + + T tmp = m_val[RE] * otherReal - m_val[IM] * otherImag; + m_val[IM] = m_val[RE] * otherImag + m_val[IM] * otherReal; + m_val[RE] = tmp; + } + + /// \brief Divide this complex number by another one + /// \tparam Other Scalar type of the other complex number + /// \param other Other complex number + template + LIBRAPID_ALWAYS_INLINE void _div(const Complex &other) { + T otherReal = static_cast(other.real()); + T otherImag = static_cast(other.imag()); + + if (::librapid::isNaN(otherReal) || ::librapid::isNaN(otherImag)) { // Set result to NaN + m_val[RE] = typetraits::TypeInfo::quietNaN(); + m_val[IM] = m_val[RE]; + } else if ((otherImag < 0 ? T(-otherImag) + : T(+otherImag)) < // |other.imag()| < |other.real()| + (otherReal < 0 ? T(-otherReal) : T(+otherReal))) { + T wr = otherImag / otherReal; + T wd = otherReal + wr * otherImag; + + if (::librapid::isNaN(wd) || wd == 0) { // NaN result + m_val[RE] = typetraits::TypeInfo::quietNaN(); + m_val[IM] = m_val[RE]; + } else { // Valid result + T tmp = (m_val[RE] + m_val[IM] * wr) / wd; + m_val[IM] = (m_val[IM] - m_val[RE] * wr) / wd; + m_val[RE] = tmp; + } + } else if (otherImag == 0) { // Set NaN + m_val[RE] = typetraits::TypeInfo::quietNaN(); + m_val[IM] = m_val[RE]; + } else { // 0 < |other.real()| <= |other.imag()| + T wr = otherReal / otherImag; + T wd = otherImag + wr * otherReal; + + if (::librapid::isNaN(wd) || wd == 0) { // NaN result + m_val[RE] = typetraits::TypeInfo::quietNaN(); + m_val[IM] = m_val[RE]; + } else { + T tmp = (m_val[RE] * wr + m_val[IM]) / wd; + m_val[IM] = (m_val[IM] * wr - m_val[RE]) / wd; + m_val[RE] = tmp; + } + } + } + + private: + T m_val[2]; + static constexpr size_t RE = 0; + static constexpr size_t IM = 1; + }; + + /// \brief Negate a complex number + /// \tparam T Scalar type of the complex number + /// \param other Complex number to negate + /// \return Negated complex number + template + 
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const Complex &other) + -> Complex { + return {-other.real(), -other.imag()}; + } + + /// \brief Add two complex numbers + /// + /// Add two complex numbers together, returning the result + /// + /// \tparam L Scalar type of LHS + /// \tparam R Scalar type of RHS + /// \param left LHS complex number + /// \param right RHS complex number + /// \return Sum of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+(const Complex &left, + const Complex &right) { + using Scalar = typename std::common_type_t; + Complex tmp(left.real(), left.imag()); + tmp += Complex(right.real(), right.imag()); + return tmp; + } + + /// \brief Add a complex number and a scalar + /// + /// Add a real number to the real component of a complex number, returning the result + /// + /// \tparam T Scalar type of the complex number + /// \tparam R Type of the real number + /// \param left LHS complex number + /// \param right RHS scalar + /// \return Sum of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+(const Complex &left, + const R &right) { + Complex tmp(left); + tmp.real(tmp.real() + right); + return tmp; + } + + /// \brief Add a scalar to a complex number + /// + /// Add a real number to the real component of a complex number, returning the result + /// + /// \tparam R Type of the real number + /// \tparam T Scalar type of the complex number + /// \param left LHS scalar + /// \param right RHS complex number + /// \return Sum of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+(const R &left, + const Complex &right) { + Complex tmp(left); + tmp += right; + return tmp; + } + + /// \brief Subtract a complex number from another complex number + /// + /// Subtract the real and imaginary components of the RHS complex number from the corresponding + /// components of the LHS complex number, returning the result + /// + /// \tparam L Scalar type of 
the LHS complex number + /// \tparam R Scalar type of the RHS complex number + /// \param left LHS complex number + /// \param right RHS complex number + /// \return Difference of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const Complex &left, + const Complex &right) { + using Scalar = typename std::common_type_t; + Complex tmp(left.real(), left.imag()); + tmp -= Complex(right.real(), right.imag()); + return tmp; + } + + /// \brief Subtract a scalar from a complex number + /// + /// Subtract a real number from the real component of a complex number, returning the result + /// + /// \tparam T Scalar type of the complex number + /// \tparam R Type of the real number + /// \param left LHS complex number + /// \param right RHS scalar + /// \return Difference of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const Complex &left, + const R &right) { + Complex tmp(left); + tmp.real(tmp.real() - right); + return tmp; + } + + /// \brief Subtract a complex number from a scalar + /// + /// Subtract the real and imaginary components of the RHS complex number from a real number, + /// returning the result + /// + /// \tparam T Scalar type of the complex number + /// \tparam R Type of the real number + /// \param left LHS scalar + /// \param right RHS complex number + /// \return Difference of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-(const R &left, + const Complex &right) { + Complex tmp(left); + tmp -= right; + return tmp; + } + + /// \brief Multiply two complex numbers + /// + /// Multiply the LHS and RHS complex numbers, returning the result + /// + /// \tparam L Scalar type of the LHS complex number + /// \tparam R Scalar type of the RHS complex number + /// \param left LHS complex number + /// \param right RHS complex number + /// \return Product of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*(const Complex &left, + const 
Complex &right) { + using Scalar = typename std::common_type_t; + Complex tmp(left.real(), left.imag()); + tmp *= Complex(right.real(), right.imag()); + return tmp; + } + + /// \brief Multiply a complex number by a scalar + /// + /// Multiply the real and imaginary components of a complex number by a real number, returning + /// the result + /// + /// \tparam T Scalar type of the complex number + /// \tparam R Type of the real number + /// \param left LHS complex number + /// \param right RHS scalar + /// \return Product of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*(const Complex &left, + const R &right) { + Complex tmp(left); + tmp.real(tmp.real() * right); + tmp.imag(tmp.imag() * right); + return tmp; + } + + /// \brief Multiply a scalar by a complex number + /// + /// Multiply a real number by the real and imaginary components of a complex number, returning + /// the result + /// + /// \tparam T Scalar type of the complex number + /// \tparam R Type of the real number + /// \param left LHS scalar + /// \param right RHS complex number + /// \return Product of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*(const R &left, + const Complex &right) { + Complex tmp(left); + tmp *= right; + return tmp; + } + + /// \brief Divide two complex numbers + /// + /// Divide the LHS complex number by the RHS complex number, returning the result + /// + /// \tparam L Scalar type of the LHS complex number + /// \tparam R Scalar type of the RHS complex number + /// \param left LHS complex number + /// \param right RHS complex number + /// \return Quotient of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/(const Complex &left, + const Complex &right) { + using Scalar = typename std::common_type_t; + Complex tmp(left.real(), left.imag()); + tmp /= Complex(right.real(), right.imag()); + return tmp; + } + + /// \brief Divide a complex number by a scalar + /// + /// Divide the 
real and imaginary components of a complex number by a real number, returning the + /// result + /// + /// \tparam T Scalar type of the complex number + /// \tparam R Type of the real number + /// \param left LHS complex number + /// \param right RHS scalar + /// \return Quotient of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/(const Complex &left, + const R &right) { + Complex tmp(left); + tmp.real(tmp.real() / right); + tmp.imag(tmp.imag() / right); + return tmp; + } + + /// \brief Divide a scalar by a complex number + /// + /// Divide a real number by the real and imaginary components of a complex number, returning the + /// result + /// + /// \tparam T Scalar type of the complex number + /// \tparam R Type of the real number + /// \param left LHS scalar + /// \param right RHS complex number + /// \return Quotient of LHS and RHS + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/(const R &left, + const Complex &right) { + Complex tmp(left); + tmp /= right; + return tmp; + } + + /// \brief Equality comparison of two complex numbers + /// \tparam L Scalar type of LHS complex number + /// \tparam R Scalar type of RHS complex number + /// \param left LHS complex number + /// \param right RHS complex number + /// \return true if equal, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator==(const Complex &left, + const Complex &right) { + return left.real() == right.real() && left.imag() == right.imag(); + } + + /// \brief Equality comparison of complex number and scalar + /// + /// Compares the real component of the complex number to the scalar, and the imaginary component + /// to zero. Returns true if and only if both comparisons are true. 
+ /// + /// \tparam T Scalar type of complex number + /// \param left LHS complex number + /// \param right RHS scalar (const reference, so literals and const values are accepted) + /// \return true if equal, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator==(const Complex &left, + const T &right) { + return left.real() == right && left.imag() == 0; + } #if !defined(LIBRAPID_CXX_20) - /// \brief Equality comparison of scalar and complex number - /// - /// Compares the real component of the complex number to the scalar, and the imaginary component - /// to zero. Returns true if and only if both comparisons are true. - /// - /// \tparam T Scalar type of complex number - /// \param left LHS scalar - /// \param right RHS complex number - /// \return true if equal, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator==(const T &left, - const Complex &right) { - return left == right.real() && 0 == right.imag(); - } + /// \brief Equality comparison of scalar and complex number + /// + /// Compares the real component of the complex number to the scalar, and the imaginary component + /// to zero. Returns true if and only if both comparisons are true.
+ /// + /// \tparam T Scalar type of complex number + /// \param left LHS scalar + /// \param right RHS complex number + /// \return true if equal, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator==(const T &left, + const Complex &right) { + return left == right.real() && 0 == right.imag(); + } #endif #if !defined(LIBRAPID_CXX_20) - /// \brief Inequality comparison of two complex numbers - /// \tparam T Scalar type of complex number - /// \param left LHS complex number - /// \param right RHS complex number - /// \return true if ***not*** equal, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator!=(const Complex &left, - const Complex &right) { - return !(left == right); - } - - /// \brief Inequality comparison of complex number and scalar - /// \see operator==(const Complex &, T &) - /// \tparam T Scalar type of complex number - /// \param left LHS complex number - /// \param right RHS scalar - /// \return true if ***not*** equal, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator!=(const Complex &left, - T &right) { - return !(left == right); - } - - /// \brief Inequality comparison of scalar and complex number - /// \see operator==(const T &, const Complex &) - /// \tparam T Scalar type of complex number - /// \param left LHS scalar - /// \param right RHS complex number - /// \return true if ***not*** equal, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator!=(const T &left, - const Complex &right) { - return !(left == right); - } + /// \brief Inequality comparison of two complex numbers + /// \tparam T Scalar type of complex number + /// \param left LHS complex number + /// \param right RHS complex number + /// \return true if ***not*** equal, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator!=(const Complex &left, + const Complex &right) { + 
return !(left == right); + } + + /// \brief Inequality comparison of complex number and scalar + /// \see operator==(const Complex &, T &) + /// \tparam T Scalar type of complex number + /// \param left LHS complex number + /// \param right RHS scalar + /// \return true if ***not*** equal, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator!=(const Complex &left, + T &right) { + return !(left == right); + } + + /// \brief Inequality comparison of scalar and complex number + /// \see operator==(const T &, const Complex &) + /// \tparam T Scalar type of complex number + /// \param left LHS scalar + /// \param right RHS complex number + /// \return true if ***not*** equal, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr bool operator!=(const T &left, + const Complex &right) { + return !(left == right); + } #endif - /// \brief Return \f$ \mathrm{Re}(z) \f$ - /// \tparam T Scalar type of the complex number - /// \param val Complex number - /// \return Real component of the complex number - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T real(const Complex &val) { - return val.real(); - } - - /// \brief Return \f$ \mathrm{Im}(z) \f$ - /// \tparam T Scalar type of the complex number - /// \param val Complex number - /// \return Imaginary component of the complex number - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T imag(const Complex &val) { - return val.imag(); - } - - /// \brief Return \f$ \sqrt{z} \f$ - /// \tparam T Scalar type of the complex number - /// \param val Complex number - /// \return Square root of the complex number - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex - sqrt(const Complex &val); // Defined later - - /// \brief Return \f$ \sqrt{\mathrm{Re}(z)^2 + \mathrm{Im}(z)^2} \f$ - /// \tparam T Scalar type of the complex number - /// \param val Complex number - /// \return Absolute value of the complex number - template - LIBRAPID_NODISCARD 
LIBRAPID_ALWAYS_INLINE T abs(const Complex &val) { - return ::librapid::hypot(val.real(), val.imag()); - } - - /// \brief Returns \f$z^{*}\f$ - /// \tparam T Scalar type of the complex number - /// \param val Complex number - /// \return Complex conjugate of the complex number - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex conj(const Complex &val) { - return Complex(val.real(), -val.imag()); - } - - /// \brief Compute the complex arc cosine of a complex number - /// - /// This function computes the complex arc cosine of the input complex number, - /// \f$z = \text{acos}(z)\f$ - /// - /// The algorithm handles NaN and infinity values, and avoids overflow. - /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex arc cosine of the input complex number - template - LIBRAPID_NODISCARD Complex acos(const Complex &other) { - const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); - const T pi = []() { + /// \brief Return \f$ \mathrm{Re}(z) \f$ + /// \tparam T Scalar type of the complex number + /// \param val Complex number + /// \return Real component of the complex number + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T real(const Complex &val) { + return val.real(); + } + + /// \brief Return \f$ \mathrm{Im}(z) \f$ + /// \tparam T Scalar type of the complex number + /// \param val Complex number + /// \return Imaginary component of the complex number + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T imag(const Complex &val) { + return val.imag(); + } + + /// \brief Return \f$ \sqrt{z} \f$ + /// \tparam T Scalar type of the complex number + /// \param val Complex number + /// \return Square root of the complex number + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex + sqrt(const Complex &val); // Defined later + + /// \brief Return \f$ \sqrt{\mathrm{Re}(z)^2 + \mathrm{Im}(z)^2} \f$ + /// \tparam T Scalar type of the complex number + /// \param val 
Complex number + /// \return Absolute value of the complex number + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T abs(const Complex &val) { + return ::librapid::hypot(val.real(), val.imag()); + } + + /// \brief Returns \f$z^{*}\f$ + /// \tparam T Scalar type of the complex number + /// \param val Complex number + /// \return Complex conjugate of the complex number + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex conj(const Complex &val) { + return Complex(val.real(), -val.imag()); + } + + /// \brief Compute the complex arc cosine of a complex number + /// + /// This function computes the complex arc cosine of the input complex number, + /// \f$z = \text{acos}(z)\f$ + /// + /// The algorithm handles NaN and infinity values, and avoids overflow. + /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex arc cosine of the input complex number + template + LIBRAPID_NODISCARD Complex acos(const Complex &other) { + const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); + const T pi = []() { #if defined(LIBRAPID_USE_MULTIPREC) - if constexpr (std::is_same_v) - return ::librapid::constPi(); - else - return static_cast(3.1415926535897932384626433832795029L); + if constexpr (std::is_same_v) + return ::librapid::constPi(); + else + return static_cast(3.1415926535897932384626433832795029L); #else - return static_cast(3.1415926535897932384626433832795029L); + return static_cast(3.1415926535897932384626433832795029L); #endif - }(); - - const T re = real(other); - const T im = imag(other); - T ux, vx; - - if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN - ux = typetraits::TypeInfo::quietNaN(); - vx = ux; - } else if (::librapid::isInf(re)) { // +/- Inf - if (::librapid::isInf(im)) { - if (re < 0) - ux = T(0.75) * pi; // (-Inf, +/-Inf) - else - ux = T(0.25) * pi; // (-Inf, +/-Inf) - } else if (re < 0) { - ux = pi; // (-Inf, finite) - } else { - ux = 0; // 
(+Inf, finite) - } - vx = -::librapid::copySign(typetraits::TypeInfo::infinity(), im); - } else if (::librapid::isInf(im)) { // finite, Inf) - ux = T(0.5) * pi; // (finite, +/-Inf) - vx = -im; - } else { // (finite, finite) - const Complex wx = sqrt(Complex(1 + re, -im)); - const Complex zx = sqrt(Complex(1 - re, -im)); - const T wr = real(wx); - const T wi = imag(wx); - const T zr = real(zx); - const T zi = imag(zx); - T alpha, beta; - - ux = 2 * ::librapid::atan2(zr, wr); - - if (arcBig < wr) { // Real part is large - alpha = wr; - beta = zi + wi * (zr / alpha); - } else if (arcBig < wi) { // Imaginary part is large - alpha = wi; - beta = wr * (zi / alpha) + zr; - } else if (wi < -arcBig) { // Imaginary part of w is large negative - alpha = -wi; - beta = wr * (zi / alpha) - zr; - } else { // Shouldn't overflow (?) - alpha = 0; - beta = wr * zi + wi * zr; // Im(w * z) - } - - vx = ::librapid::asinh(beta); - if (alpha != 0) { - // asinh(a * b) = asinh(a) + log(b) - if (0 <= vx) - vx += ::librapid::log(alpha); - else - vx -= ::librapid::log(alpha); - } - } - return Complex(ux, vx); - } - - /// \brief Compute the complex hyperbolic arc cosine of a complex number - /// - /// - /// This function computes the complex area hyperbolic cosine of the input complex number, - /// \f$ z = \text{acosh}(z) \f$ - /// - /// The algorithm handles NaN and infinity values, and avoids overflow. 
- /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex area hyperbolic cosine of the input complex number - template - LIBRAPID_NODISCARD Complex acosh(const Complex &other) { - const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); - const T pi = []() { + }(); + + const T re = real(other); + const T im = imag(other); + T ux, vx; + + if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN + ux = typetraits::TypeInfo::quietNaN(); + vx = ux; + } else if (::librapid::isInf(re)) { // (+/-Inf, not NaN) + if (::librapid::isInf(im)) { + if (re < 0) + ux = T(0.75) * pi; // (-Inf, +/-Inf) + else + ux = T(0.25) * pi; // (+Inf, +/-Inf) + } else if (re < 0) { + ux = pi; // (-Inf, finite) + } else { + ux = 0; // (+Inf, finite) + } + vx = -::librapid::copySign(typetraits::TypeInfo::infinity(), im); + } else if (::librapid::isInf(im)) { // (finite, +/-Inf) + ux = T(0.5) * pi; // (finite, +/-Inf) + vx = -im; + } else { // (finite, finite) + const Complex wx = sqrt(Complex(1 + re, -im)); + const Complex zx = sqrt(Complex(1 - re, -im)); + const T wr = real(wx); + const T wi = imag(wx); + const T zr = real(zx); + const T zi = imag(zx); + T alpha, beta; + + ux = 2 * ::librapid::atan2(zr, wr); + + if (arcBig < wr) { // Real part is large + alpha = wr; + beta = zi + wi * (zr / alpha); + } else if (arcBig < wi) { // Imaginary part is large + alpha = wi; + beta = wr * (zi / alpha) + zr; + } else if (wi < -arcBig) { // Imaginary part of w is large negative + alpha = -wi; + beta = wr * (zi / alpha) - zr; + } else { // Shouldn't overflow (?)
+ alpha = 0; + beta = wr * zi + wi * zr; // Im(w * z) + } + + vx = ::librapid::asinh(beta); + if (alpha != 0) { + // asinh(a * b) = asinh(a) + log(b) + if (0 <= vx) + vx += ::librapid::log(alpha); + else + vx -= ::librapid::log(alpha); + } + } + return Complex(ux, vx); + } + + /// \brief Compute the complex hyperbolic arc cosine of a complex number + /// + /// + /// This function computes the complex area hyperbolic cosine of the input complex number, + /// \f$ z = \text{acosh}(z) \f$ + /// + /// The algorithm handles NaN and infinity values, and avoids overflow. + /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex area hyperbolic cosine of the input complex number + template + LIBRAPID_NODISCARD Complex acosh(const Complex &other) { + const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); + const T pi = []() { #if defined(LIBRAPID_USE_MULTIPREC) - if constexpr (std::is_same_v) - return ::librapid::constPi(); - else - return static_cast(3.1415926535897932384626433832795029L); + if constexpr (std::is_same_v) + return ::librapid::constPi(); + else + return static_cast(3.1415926535897932384626433832795029L); #else - return static_cast(3.1415926535897932384626433832795029L); + return static_cast(3.1415926535897932384626433832795029L); #endif - }(); - - const T re = real(other); - T im = imag(other); - T ux, vx; - - if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN - ux = typetraits::TypeInfo::quietNaN(); - vx = ux; - } else if (::librapid::isInf(re)) { // (+/-Inf, not NaN) - ux = typetraits::TypeInfo::infinity(); - if (::librapid::isInf(im)) { - if (re < 0) - vx = T(0.75) * pi; // (-Inf, +/-Inf) - else - vx = T(0.25) * pi; // (+Inf, +/-Inf) - } else if (re < 0) { - vx = pi; // (-Inf, finite) - } else { - vx = 0; // (+Inf, finite) - } - vx = ::librapid::copySign(vx, im); - } else { // (finite, finite) - const Complex wx = sqrt(Complex(re - 1, -im)); - const 
Complex zx = sqrt(Complex(re + 1, im)); - const T wr = real(wx); - const T wi = imag(wx); - const T zr = real(zx); - const T zi = imag(zx); - T alpha, beta; - - if (arcBig < wr) { // Real parts large - alpha = wr; - beta = zr - wi * (zi / alpha); - } else if (arcBig < wi) { // Imaginary parts large - alpha = wi; - beta = wr * (zr / alpha) - zi; - } else { // Shouldn't overflow (?) - alpha = 0; - beta = wr * zr - wi * zi; // Re(w * z) - } - - ux = ::librapid::asinh(beta); - if (alpha != 0) { - if (0 <= ux) - ux += ::librapid::log(alpha); - else - ux -= ::librapid::log(alpha); - } - vx = 2 * ::librapid::atan2(imag(sqrt(Complex(re - 1, im))), zr); - } - return Complex(ux, vx); - } - - /// \brief Compute the complex arc hyperbolic sine of a complex number - /// - /// This function computes the complex arc hyperbolic sine of the input complex number, - /// \f$ z = \text{asinh}(z) \f$ - /// - /// The algorithm handles NaN and infinity values, and avoids overflow. - /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex arc hyperbolic sine of the input complex number - template - LIBRAPID_NODISCARD Complex asinh(const Complex &other) { - const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); - const T pi = []() { + }(); + + const T re = real(other); + T im = imag(other); + T ux, vx; + + if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN + ux = typetraits::TypeInfo::quietNaN(); + vx = ux; + } else if (::librapid::isInf(re)) { // (+/-Inf, not NaN) + ux = typetraits::TypeInfo::infinity(); + if (::librapid::isInf(im)) { + if (re < 0) + vx = T(0.75) * pi; // (-Inf, +/-Inf) + else + vx = T(0.25) * pi; // (+Inf, +/-Inf) + } else if (re < 0) { + vx = pi; // (-Inf, finite) + } else { + vx = 0; // (+Inf, finite) + } + vx = ::librapid::copySign(vx, im); + } else { // (finite, finite) + const Complex wx = sqrt(Complex(re - 1, -im)); + const Complex zx = sqrt(Complex(re + 1, im)); + 
const T wr = real(wx); + const T wi = imag(wx); + const T zr = real(zx); + const T zi = imag(zx); + T alpha, beta; + + if (arcBig < wr) { // Real parts large + alpha = wr; + beta = zr - wi * (zi / alpha); + } else if (arcBig < wi) { // Imaginary parts large + alpha = wi; + beta = wr * (zr / alpha) - zi; + } else { // Shouldn't overflow (?) + alpha = 0; + beta = wr * zr - wi * zi; // Re(w * z) + } + + ux = ::librapid::asinh(beta); + if (alpha != 0) { + if (0 <= ux) + ux += ::librapid::log(alpha); + else + ux -= ::librapid::log(alpha); + } + vx = 2 * ::librapid::atan2(imag(sqrt(Complex(re - 1, im))), zr); + } + return Complex(ux, vx); + } + + /// \brief Compute the complex arc hyperbolic sine of a complex number + /// + /// This function computes the complex arc hyperbolic sine of the input complex number, + /// \f$ z = \text{asinh}(z) \f$ + /// + /// The algorithm handles NaN and infinity values, and avoids overflow. + /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex arc hyperbolic sine of the input complex number + template + LIBRAPID_NODISCARD Complex asinh(const Complex &other) { + const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); + const T pi = []() { #if defined(LIBRAPID_USE_MULTIPREC) - if constexpr (std::is_same_v) - return ::librapid::constPi(); - else - return static_cast(3.1415926535897932384626433832795029L); + if constexpr (std::is_same_v) + return ::librapid::constPi(); + else + return static_cast(3.1415926535897932384626433832795029L); #else - return static_cast(3.1415926535897932384626433832795029L); + return static_cast(3.1415926535897932384626433832795029L); #endif - }(); - - const T re = real(other); - T im = imag(other); - T ux, vx; - - if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN/Inf - ux = typetraits::TypeInfo::quietNaN(); - vx = ux; - } else if (::librapid::isInf(re)) { // (+/-Inf, not NaN) - if (::librapid::isInf(im)) { // 
(+/-Inf, +/-Inf) - ux = re; - vx = ::librapid::copySign(T(0.25) * pi, im); - } else { // (+/-Inf, finite) - ux = re; - vx = ::librapid::copySign(T(0), im); - } - } else if (::librapid::isInf(im)) { - ux = ::librapid::copySign(typetraits::TypeInfo::infinity(), re); - vx = ::librapid::copySign(T(0.5) * pi, im); - } else { // (finite, finite) - const Complex wx = sqrt(Complex(1 - im, re)); - const Complex zx = sqrt(Complex(1 + im, -re)); - const T wr = real(wx); - const T wi = imag(wx); - const T zr = real(zx); - const T zi = imag(zx); - T alpha, beta; - - if (arcBig < wr) { // Real parts are large - alpha = wr; - beta = wi * (zr / alpha) - zi; - } else if (arcBig < wi) { // Imaginary parts are large - alpha = wi; - beta = zr - wr * (zi / alpha); - } else if (wi < -arcBig) { - alpha = -wi; - beta = -zr - wr * (zi / alpha); - } else { // Shouldn't overflow (?) - alpha = 0; - beta = wi * zr - wr * zi; // Im(w * conj(z)) - } - - ux = ::librapid::asinh(beta); - if (alpha != 0) { - if (0 <= ux) - ux += ::librapid::log(alpha); - else - ux -= ::librapid::log(alpha); - } - vx = ::librapid::atan2(im, real(wx * zx)); - } - return Complex(ux, vx); - } - - /// \brief Compute the complex arc sine of a complex number - /// - /// This function computes the complex arc sine of the input complex number, - /// \f$ z = \text{asin}(z) \f$ - /// - /// It calculates the complex arc sine by using the complex hyperbolic sine function. 
- /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex arc sine of the input complex number - /// \see asinh - template - LIBRAPID_NODISCARD Complex asin(const Complex &other) { - Complex asinhVal = asinh(Complex(-imag(other), real(other))); - return Complex(imag(asinhVal), -real(asinhVal)); - } - - /// \brief Compute the complex arc hyperbolic tangent of a complex number - /// - /// This function computes the complex arc hyperbolic tangent of the input complex number, - /// \f$ z = \text{atanh}(z) \f$ - /// - /// This function performs error checking and supports NaNs and Infs. - /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex arc hyperbolic tangent of the input complex number - template - LIBRAPID_NODISCARD Complex atanh(const Complex &other) { - const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); - const T piBy2 = []() { + }(); + + const T re = real(other); + T im = imag(other); + T ux, vx; + + if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN/Inf + ux = typetraits::TypeInfo::quietNaN(); + vx = ux; + } else if (::librapid::isInf(re)) { // (+/-Inf, not NaN) + if (::librapid::isInf(im)) { // (+/-Inf, +/-Inf) + ux = re; + vx = ::librapid::copySign(T(0.25) * pi, im); + } else { // (+/-Inf, finite) + ux = re; + vx = ::librapid::copySign(T(0), im); + } + } else if (::librapid::isInf(im)) { + ux = ::librapid::copySign(typetraits::TypeInfo::infinity(), re); + vx = ::librapid::copySign(T(0.5) * pi, im); + } else { // (finite, finite) + const Complex wx = sqrt(Complex(1 - im, re)); + const Complex zx = sqrt(Complex(1 + im, -re)); + const T wr = real(wx); + const T wi = imag(wx); + const T zr = real(zx); + const T zi = imag(zx); + T alpha, beta; + + if (arcBig < wr) { // Real parts are large + alpha = wr; + beta = wi * (zr / alpha) - zi; + } else if (arcBig < wi) { // Imaginary parts are large + 
alpha = wi; + beta = zr - wr * (zi / alpha); + } else if (wi < -arcBig) { + alpha = -wi; + beta = -zr - wr * (zi / alpha); + } else { // Shouldn't overflow (?) + alpha = 0; + beta = wi * zr - wr * zi; // Im(w * conj(z)) + } + + ux = ::librapid::asinh(beta); + if (alpha != 0) { + if (0 <= ux) + ux += ::librapid::log(alpha); + else + ux -= ::librapid::log(alpha); + } + vx = ::librapid::atan2(im, real(wx * zx)); + } + return Complex(ux, vx); + } + + /// \brief Compute the complex arc sine of a complex number + /// + /// This function computes the complex arc sine of the input complex number, + /// \f$ z = \text{asin}(z) \f$ + /// + /// It calculates the complex arc sine by using the complex hyperbolic sine function. + /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex arc sine of the input complex number + /// \see asinh + template + LIBRAPID_NODISCARD Complex asin(const Complex &other) { + Complex asinhVal = asinh(Complex(-imag(other), real(other))); + return Complex(imag(asinhVal), -real(asinhVal)); + } + + /// \brief Compute the complex arc hyperbolic tangent of a complex number + /// + /// This function computes the complex arc hyperbolic tangent of the input complex number, + /// \f$ z = \text{atanh}(z) \f$ + /// + /// This function performs error checking and supports NaNs and Infs. 
+ /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex arc hyperbolic tangent of the input complex number + template + LIBRAPID_NODISCARD Complex atanh(const Complex &other) { + const T arcBig = T(0.25) * ::librapid::sqrt(typetraits::TypeInfo::max()); + const T piBy2 = []() { #if defined(LIBRAPID_USE_MULTIPREC) - if constexpr (std::is_same_v) - return ::librapid::constPi() / 2; - else - return static_cast(1.5707963267948966192313216916397514L); + if constexpr (std::is_same_v) + return ::librapid::constPi() / 2; + else + return static_cast(1.5707963267948966192313216916397514L); #else - return static_cast(1.5707963267948966192313216916397514L); + return static_cast(1.5707963267948966192313216916397514L); #endif - }(); - - T re = real(other); - T im = imag(other); - T ux, vx; - - if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN - ux = typetraits::TypeInfo::quietNaN(); - vx = ux; - } else if (::librapid::isInf(re)) { // (+/-Inf, not NaN) - ux = ::librapid::copySign(T(0), re); - vx = ::librapid::copySign(piBy2, im); - } else { // (finite, not NaN) - const T magIm = ::librapid::abs(im); - const T oldRe = re; - - re = ::librapid::abs(re); - - if (arcBig < re) { // |re| is large - T fx = im / re; - ux = 1 / re / (1 + fx * fx); - vx = ::librapid::copySign(piBy2, im); - } else if (arcBig < magIm) { // |im| is large - T fx = re / im; - ux = fx / im / (1 + fx * fx); - vx = ::librapid::copySign(piBy2, im); - } else if (re != 1) { // |re| is small - T reFrom1 = 1 - re; - T imEps2 = magIm * magIm; - ux = T(0.25) * detail::algorithm::logP1(4 * re / (reFrom1 * reFrom1 + imEps2)); - vx = T(0.5) * ::librapid::atan2(2 * im, reFrom1 * (1 + re) - imEps2); - } else if (im == 0) { // {+/-1, 0) - ux = typetraits::TypeInfo::infinity(); - vx = im; - } else { // (+/-1, nonzero) - ux = ::librapid::log(::librapid::sqrt(::librapid::sqrt(4 + im * im)) / - ::librapid::sqrt(magIm)); - vx = 
::librapid::copySign(T(0.5) * (piBy2 + ::librapid::atan2(magIm, T(2))), im); - } - ux = ::librapid::copySign(ux, oldRe); - } - return Complex(ux, vx); - } - - /// \brief Compute the complex arc tangent of a complex number - /// - /// This function computes the complex arc tangent of the input complex number, - /// \f$ z = \text{atan}(z) \f$ - /// - /// The algorithm handles NaN and infinity values, and avoids overflow. - /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex arc tangent of the input complex number - template - LIBRAPID_NODISCARD Complex atan(const Complex &other) { - Complex atanhVal = ::librapid::atanh(Complex(-imag(other), real(other))); - return Complex(imag(atanhVal), -real(atanhVal)); - } - - /// \brief Compute the complex hyperbolic cosine of a complex number - /// - /// This function computes the complex hyperbolic cosine of the input complex number, - /// \f$ z = \text{cosh}(z) \f$ - /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex hyperbolic cosine of the input complex number - template - LIBRAPID_NODISCARD Complex cosh(const Complex &other) { - return Complex(::librapid::cosh(real(other)) * ::librapid::cos(imag(other)), - ::librapid::sinh(real(other)) * ::librapid::sin(imag(other))); - } - - template - LIBRAPID_NODISCARD Complex polarPositiveNanInfZeroRho(const T &rho, const T &theta) { - // Rho is +NaN/+Inf/+0 - if (::librapid::isNaN(theta) || ::librapid::isInf(theta)) { // Theta is NaN/Inf - if (::librapid::isInf(rho)) { - return Complex(rho, ::librapid::sin(theta)); // (Inf, NaN/Inf) - } else { - return Complex(rho, ::librapid::copySign(rho, theta)); // (NaN/0, NaN/Inf) - } - } else if (theta == T(0)) { // Theta is zero - return Complex(rho, theta); // (NaN/Inf/0, 0) - } else { // Theta is finite non-zero - // (NaN/Inf/0, finite non-zero) - return Complex(rho * ::librapid::cos(theta), rho * 
::librapid::sin(theta)); - } - } - - /// \brief Compute the complex exponential of a complex number - /// - /// This function computes the complex exponential of the input complex number, - /// \f$ z = e^z \f$ - /// - /// The algorithm handles NaN and infinity values. - /// - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex exponential of the input complex number - template - LIBRAPID_NODISCARD Complex exp(const Complex &other) { - const T logRho = real(other); - const T theta = imag(other); - - if (!::librapid::isNaN(logRho) && !::librapid::isInf(logRho)) { // Real component is finite - T real = logRho; - T imag = logRho; - detail::algorithm::expMul(&real, static_cast(::librapid::cos(theta)), 0); - detail::algorithm::expMul(&imag, static_cast(::librapid::sin(theta)), 0); - return Complex(real, imag); - } - - // Real component is NaN/Inf - // Return polar(exp(re), im) - if (::librapid::isInf(logRho)) { - if (logRho < 0) { - return polarPositiveNanInfZeroRho(T(0), theta); // exp(-Inf) = +0 - } else { - return polarPositiveNanInfZeroRho(logRho, theta); // exp(+Inf) = +Inf - } - } else { - return polarPositiveNanInfZeroRho(static_cast(::librapid::abs(logRho)), - theta); // exp(NaN) = +NaN - } - } - - /// \brief Compute the complex exponential base 2 of a complex number - /// \see exp - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex exponential base 2 of the input complex number - template - LIBRAPID_NODISCARD Complex exp2(const Complex &other) { - return pow(T(2), other); - } - - /// \brief Compute the complex exponential base 10 of a complex number - /// \see exp - /// \tparam T Scalar type of the complex number - /// \param other Input complex number - /// \return Complex exponential base 10 of the input complex number - template - LIBRAPID_NODISCARD Complex exp10(const Complex &other) { - return pow(T(10), other); - } - - template - T 
_fabs(const Complex &other, int64_t *exp) { - *exp = 0; - T av = ::librapid::abs(real(other)); - T bv = ::librapid::abs(imag(other)); - - if (::librapid::isInf(av) || ::librapid::isInf(bv)) { - return typetraits::TypeInfo::infinity(); // At least one component is Inf - } else if (::librapid::isNaN(av)) { - return av; // Real component is NaN - } else if (::librapid::isNaN(bv)) { - return bv; // Imaginary component is NaN - } else { - if (av < bv) std::swap(av, bv); - if (av == 0) return av; // |0| = 0 - - if (1 <= av) { - *exp = 4; - av = av * T(0.0625); - bv = bv * T(0.0625); - } else { - const T fltEps = typetraits::TypeInfo::epsilon(); - const T legTiny = fltEps == 0 ? T(0) : 2 * typetraits::TypeInfo::min() / fltEps; - - if (av < legTiny) { - int64_t exponent; + }(); + + T re = real(other); + T im = imag(other); + T ux, vx; + + if (::librapid::isNaN(re) || ::librapid::isNaN(im)) { // At least one NaN + ux = typetraits::TypeInfo::quietNaN(); + vx = ux; + } else if (::librapid::isInf(re)) { // (+/-Inf, not NaN) + ux = ::librapid::copySign(T(0), re); + vx = ::librapid::copySign(piBy2, im); + } else { // (finite, not NaN) + const T magIm = ::librapid::abs(im); + const T oldRe = re; + + re = ::librapid::abs(re); + + if (arcBig < re) { // |re| is large + T fx = im / re; + ux = 1 / re / (1 + fx * fx); + vx = ::librapid::copySign(piBy2, im); + } else if (arcBig < magIm) { // |im| is large + T fx = re / im; + ux = fx / im / (1 + fx * fx); + vx = ::librapid::copySign(piBy2, im); + } else if (re != 1) { // |re| is small + T reFrom1 = 1 - re; + T imEps2 = magIm * magIm; + ux = T(0.25) * detail::algorithm::logP1(4 * re / (reFrom1 * reFrom1 + imEps2)); + vx = T(0.5) * ::librapid::atan2(2 * im, reFrom1 * (1 + re) - imEps2); + } else if (im == 0) { // {+/-1, 0) + ux = typetraits::TypeInfo::infinity(); + vx = im; + } else { // (+/-1, nonzero) + ux = ::librapid::log(::librapid::sqrt(::librapid::sqrt(4 + im * im)) / + ::librapid::sqrt(magIm)); + vx = ::librapid::copySign(T(0.5) * 
(piBy2 + ::librapid::atan2(magIm, T(2))), im); + } + ux = ::librapid::copySign(ux, oldRe); + } + return Complex(ux, vx); + } + + /// \brief Compute the complex arc tangent of a complex number + /// + /// This function computes the complex arc tangent of the input complex number, + /// \f$ z = \text{atan}(z) \f$ + /// + /// The algorithm handles NaN and infinity values, and avoids overflow. + /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex arc tangent of the input complex number + template + LIBRAPID_NODISCARD Complex atan(const Complex &other) { + Complex atanhVal = ::librapid::atanh(Complex(-imag(other), real(other))); + return Complex(imag(atanhVal), -real(atanhVal)); + } + + /// \brief Compute the complex hyperbolic cosine of a complex number + /// + /// This function computes the complex hyperbolic cosine of the input complex number, + /// \f$ z = \text{cosh}(z) \f$ + /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex hyperbolic cosine of the input complex number + template + LIBRAPID_NODISCARD Complex cosh(const Complex &other) { + return Complex(::librapid::cosh(real(other)) * ::librapid::cos(imag(other)), + ::librapid::sinh(real(other)) * ::librapid::sin(imag(other))); + } + + template + LIBRAPID_NODISCARD Complex polarPositiveNanInfZeroRho(const T &rho, const T &theta) { + // Rho is +NaN/+Inf/+0 + if (::librapid::isNaN(theta) || ::librapid::isInf(theta)) { // Theta is NaN/Inf + if (::librapid::isInf(rho)) { + return Complex(rho, ::librapid::sin(theta)); // (Inf, NaN/Inf) + } else { + return Complex(rho, ::librapid::copySign(rho, theta)); // (NaN/0, NaN/Inf) + } + } else if (theta == T(0)) { // Theta is zero + return Complex(rho, theta); // (NaN/Inf/0, 0) + } else { // Theta is finite non-zero + // (NaN/Inf/0, finite non-zero) + return Complex(rho * ::librapid::cos(theta), rho * ::librapid::sin(theta)); + } + } + + /// \brief 
Compute the complex exponential of a complex number + /// + /// This function computes the complex exponential of the input complex number, + /// \f$ z = e^z \f$ + /// + /// The algorithm handles NaN and infinity values. + /// + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex exponential of the input complex number + template + LIBRAPID_NODISCARD Complex exp(const Complex &other) { + const T logRho = real(other); + const T theta = imag(other); + + if (!::librapid::isNaN(logRho) && !::librapid::isInf(logRho)) { // Real component is finite + T real = logRho; + T imag = logRho; + detail::algorithm::expMul(&real, static_cast(::librapid::cos(theta)), 0); + detail::algorithm::expMul(&imag, static_cast(::librapid::sin(theta)), 0); + return Complex(real, imag); + } + + // Real component is NaN/Inf + // Return polar(exp(re), im) + if (::librapid::isInf(logRho)) { + if (logRho < 0) { + return polarPositiveNanInfZeroRho(T(0), theta); // exp(-Inf) = +0 + } else { + return polarPositiveNanInfZeroRho(logRho, theta); // exp(+Inf) = +Inf + } + } else { + return polarPositiveNanInfZeroRho(static_cast(::librapid::abs(logRho)), + theta); // exp(NaN) = +NaN + } + } + + /// \brief Compute the complex exponential base 2 of a complex number + /// \see exp + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex exponential base 2 of the input complex number + template + LIBRAPID_NODISCARD Complex exp2(const Complex &other) { + return pow(T(2), other); + } + + /// \brief Compute the complex exponential base 10 of a complex number + /// \see exp + /// \tparam T Scalar type of the complex number + /// \param other Input complex number + /// \return Complex exponential base 10 of the input complex number + template + LIBRAPID_NODISCARD Complex exp10(const Complex &other) { + return pow(T(10), other); + } + + template + T _fabs(const Complex &other, int64_t *exp) { + *exp = 0; 
+ T av = ::librapid::abs(real(other)); + T bv = ::librapid::abs(imag(other)); + + if (::librapid::isInf(av) || ::librapid::isInf(bv)) { + return typetraits::TypeInfo::infinity(); // At least one component is Inf + } else if (::librapid::isNaN(av)) { + return av; // Real component is NaN + } else if (::librapid::isNaN(bv)) { + return bv; // Imaginary component is NaN + } else { + if (av < bv) std::swap(av, bv); + if (av == 0) return av; // |0| = 0 + + if (1 <= av) { + *exp = 4; + av = av * T(0.0625); + bv = bv * T(0.0625); + } else { + const T fltEps = typetraits::TypeInfo::epsilon(); + const T legTiny = fltEps == 0 ? T(0) : 2 * typetraits::TypeInfo::min() / fltEps; + + if (av < legTiny) { + int64_t exponent; #if defined(LIBRAPID_USE_MULTIPREC) - if constexpr (std::is_same_v) { - exponent = -2 * ::mpfr::mpreal::get_default_prec(); - } else { - exponent = -2 * std::numeric_limits::digits; - } + if constexpr (std::is_same_v) { + exponent = -2 * ::mpfr::mpreal::get_default_prec(); + } else { + exponent = -2 * std::numeric_limits::digits; + } #else - exponent = -2 * std::numeric_limits::digits; + exponent = -2 * std::numeric_limits::digits; #endif - *exp = exponent; - av = ::librapid::ldexp(av, -exponent); - bv = ::librapid::ldexp(bv, -exponent); - } else { - *exp = -2; - av = av * 4; - bv = bv * 4; - } - } - - const T tmp = av - bv; - if (tmp == av) { - return av; // bv is unimportant - } else { + *exp = exponent; + av = ::librapid::ldexp(av, -exponent); + bv = ::librapid::ldexp(bv, -exponent); + } else { + *exp = -2; + av = av * 4; + bv = bv * 4; + } + } + + const T tmp = av - bv; + if (tmp == av) { + return av; // bv is unimportant + } else { #if defined(LIBRAPID_USE_MULTIPREC) - if constexpr (std::is_same_v) { // No approximations - const T root2 = ::librapid::sqrt(mpfr(2)); - const T onePlusRoot2 = root2 + 1; - - const T qv = tmp / bv; - const T rv = (qv + 2) * qv; - const T sv = rv / (root2 + ::librapid::sqrt(rv + 2)) + onePlusRoot2 + qv; - return av + bv / sv; - 
} else { + if constexpr (std::is_same_v) { // No approximations + const T root2 = ::librapid::sqrt(mpfr(2)); + const T onePlusRoot2 = root2 + 1; + + const T qv = tmp / bv; + const T rv = (qv + 2) * qv; + const T sv = rv / (root2 + ::librapid::sqrt(rv + 2)) + onePlusRoot2 + qv; + return av + bv / sv; + } else { #endif - if (bv < tmp) { // Use a simple approximation - const T qv = av / bv; - return av + bv / (qv + ::librapid::sqrt(qv * qv + 1)); - } else { // Use 1 1/2 precision to preserve bits - constexpr T root2 = static_cast(1.4142135623730950488016887242096981L); - constexpr T onePlusRoot2High = static_cast(10125945.0 / 4194304.0); - constexpr T onePlusRoot2Low = - static_cast(1.4341252375973918872420969807856967e-7L); - - const T qv = tmp / bv; - const T rv = (qv + 2) * qv; - const T sv = rv / (root2 + ::librapid::sqrt(rv + 2)) + onePlusRoot2Low + - qv + onePlusRoot2High; - return av + bv / sv; - } + if (bv < tmp) { // Use a simple approximation + const T qv = av / bv; + return av + bv / (qv + ::librapid::sqrt(qv * qv + 1)); + } else { // Use 1 1/2 precision to preserve bits + constexpr T root2 = static_cast(1.4142135623730950488016887242096981L); + constexpr T onePlusRoot2High = static_cast(10125945.0 / 4194304.0); + constexpr T onePlusRoot2Low = + static_cast(1.4341252375973918872420969807856967e-7L); + + const T qv = tmp / bv; + const T rv = (qv + 2) * qv; + const T sv = rv / (root2 + ::librapid::sqrt(rv + 2)) + onePlusRoot2Low + + qv + onePlusRoot2High; + return av + bv / sv; + } #if defined(LIBRAPID_USE_MULTIPREC) - } + } #endif - } - } - } + } + } + } - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T _logAbs(const Complex &other) noexcept { - return static_cast(detail::algorithm::logHypot(static_cast(real(other)), - static_cast(imag(other)))); - } + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T _logAbs(const Complex &other) noexcept { + return static_cast(detail::algorithm::logHypot(static_cast(real(other)), + static_cast(imag(other)))); + 
} #if defined(LIBRAPID_USE_MULTIPREC) - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE mpfr _logAbs(const Complex &other) noexcept { - return detail::algorithm::logHypot(real(other), imag(other)); - } + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE mpfr _logAbs(const Complex &other) noexcept { + return detail::algorithm::logHypot(real(other), imag(other)); + } #endif - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE float _logAbs(const Complex &other) noexcept { - return detail::algorithm::logHypot(real(other), imag(other)); - } - - /// \brief Calculates the natural logarithm of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return Natural logarithm of the complex number - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log(const Complex &other) { - const T logAbs = _logAbs(other); - const T theta = ::librapid::atan2(imag(other), real(other)); - return Complex(logAbs, theta); - } - - /// \brief Calculates the logarithm of a complex number with a complex base - /// - /// \f$ \log_{\mathrm{base}}(z) = \log(z) / \log(\mathrm{base}) \f$ - /// \tparam T Scalar type - /// \tparam B Base type - /// \param other Complex number - /// \param base Base of the logarithm - /// \return Logarithm of the complex number with the given base - /// \see log - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log(const Complex &other, - const Complex &base) { - return log(other) / log(base); - } - - /// \brief Calculates the logarithm of a complex number with a real base - /// - /// \f$ \log_{\mathrm{base}}(z) = \log(z) / \log(\mathrm{base}) \f$ - /// \tparam T Scalar type of the complex number - /// \tparam B Scalar type of the base - /// \param other Complex number - /// \param base Base of the logarithm (real) - /// \return Logarithm of the complex number with the given base - /// \see log - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log(const Complex &other, - const B &base) { - 
const T logAbs = _logAbs(other); - const T theta = ::librapid::atan2(imag(other), real(other)); - return Complex(logAbs, theta) / ::librapid::log(base); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex _pow(const T &left, const T &right) { - if (0 <= left) { - return Complex(::librapid::pow(left, right), ::librapid::copySign(T(0), right)); - } else { - return exp(right * log(Complex(left))); - } - } - - /// \brief Calculate \f$ \text{left}^{\text{right}} \f$ for a complex-valued left-hand side - /// \tparam T Value type for the left-hand side - /// \tparam V Value type for the right-hand side - /// \param left Complex base - /// \param right Real exponent - /// \return \f$ \text{left}^{\text{right}} \f$ - template::type == detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD Complex pow(const Complex &left, const V &right) { - if (imag(left) == 0) { - if (::librapid::signBit(imag(left))) { - return conj(_pow(real(left), static_cast(right))); - } else { - return _pow(real(left), static_cast(right)); - } - } else { - return exp(static_cast(right) * log(left)); - } - } - - /// \brief Calculate \f$ \text{left}^{\text{right}} \f$ for a complex-valued right-hand side - /// \tparam T Value type for the left-hand side - /// \tparam V Value type for the right-hand side - /// \param left Real base - /// \param right Complex exponent - /// \return \f$ \text{left}^{\text{right}} \f$ - template::type == detail::LibRapidType::Scalar, int> = 0> - LIBRAPID_NODISCARD Complex pow(const V &left, const Complex &right) { - if (imag(right) == 0) { - return _pow(static_cast(left), real(right)); - } else if (0 < left) { - return exp(right * ::librapid::log(static_cast(left))); - } else { - return exp(right * log(Complex(static_cast(left)))); - } - } - - /// \brief Calculate \f$ \text{left}^{\text{right}} \f$ for complex numbers - /// \tparam T Complex number component type - /// \param left Complex base - /// \param right Complex exponent - /// \return \f$ 
\text{left}^{\text{right}} \f$ - template - LIBRAPID_NODISCARD Complex pow(const Complex &left, const Complex &right) { - if (imag(right) == 0) { - return pow(left, real(right)); - } else if (imag(left) == 0 && 0 < real(left)) { - return exp(right * ::librapid::log(real(left))); - } else { - return exp(right * log(left)); - } - } - - /// \brief Calculate the hyperbolic sine of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \sinh(z) \f$ - template - LIBRAPID_NODISCARD Complex sinh(const Complex &other) { - return Complex(::librapid::sinh(real(other)) * ::librapid::cos(imag(other)), - ::librapid::cosh(real(other)) * ::librapid::sin(imag(other))); - } - - template - LIBRAPID_NODISCARD Complex sqrt(const Complex &other) { - int64_t otherExp; - T rho = _fabs(other, &otherExp); // Get magnitude and scale factor - - if (otherExp == 0) { // Argument is zero, Inf or NaN - if (rho == 0) { - return Complex(T(0), imag(other)); - } else if (::librapid::isInf(rho)) { - const T re = real(other); - const T im = imag(other); - - if (::librapid::isInf(im)) { - return Complex(typetraits::TypeInfo::infinity(), im); // (any, +/-Inf) - } else if (::librapid::isNaN(im)) { - if (re < 0) { - // (-Inf, NaN) - return Complex(::librapid::abs(im), ::librapid::copySign(re, im)); - } else { - return other; // (+Inf, NaN) - } - } else { - if (re < 0) { - return Complex(T(0), ::librapid::copySign(re, im)); // (-Inf, finite) - } else { - return Complex(re, ::librapid::copySign(T(0), im)); // (+Inf, finite) - } - } - } else { - return Complex(rho, rho); - } - } else { // Compute in safest quadrant - T realMag = ::librapid::ldexp(::librapid::abs(real(other)), -otherExp); - rho = ::librapid::ldexp(::librapid::sqrt(2 * (realMag + rho)), otherExp / 2 - 1); - if (0 <= real(other)) { - return Complex(rho, imag(other) / (2 * rho)); - } else { - return Complex(::librapid::abs(imag(other) / (2 * rho)), - ::librapid::copySign(rho, imag(other))); - } - } - } - 
- /// \brief Calculate the hyperbolic tangent of a complex number - /// - /// This function supports propagation of NaNs and Infs. - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \tanh(z) \f$ - template - LIBRAPID_NODISCARD Complex tanh(const Complex &other) { - T tv = ::librapid::tan(imag(other)); - T sv = ::librapid::sinh(real(other)); - T bv = sv * (T(1) + tv * tv); - T dv = T(1) + bv * sv; - - if (::librapid::isInf(dv)) { - T real; - if (sv < T(0)) - real = T(-1); - else - real = T(1); - return Complex(real, T(0)); - } - return Complex((::librapid::sqrt(T(1) + sv * sv)) * bv / dv, tv / dv); - } - - // Return the phase angle of a complex value as a real - - /// \brief Return the phase angle of a complex value as a real - /// - /// This function calls \f$ \text{atan2}(\text{imag}(z), \text{real}(z)) \f$. - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \arg(z) \f$ - /// \see atan2 - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T arg(const Complex &other) { - return ::librapid::atan2(imag(other), real(other)); - } - - /// \brief Project a complex number onto the Riemann sphere - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \text{proj}(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex proj(const Complex &other) { - if (::librapid::isInf(real(other)) || ::librapid::isInf(imag(other))) { - const T im = ::librapid::copySign(T(0), imag(other)); - return Complex(typetraits::TypeInfo::infinity(), im); - } - return other; - } - - /// \brief Calculate the cosine of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \cos(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex cos(const Complex &other) { - return Complex(::librapid::cosh(imag(other)) * ::librapid::cos(real(other)), - -::librapid::sinh(imag(other)) * ::librapid::sin(real(other))); - } - - /// \brief Calculate the cosecant of a 
complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \csc(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex csc(const Complex &other) { - return T(1) / sin(other); - } - - /// \brief Calculate the secant of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \sec(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex sec(const Complex &other) { - return T(1) / cos(other); - } - - /// \brief Calculate the cotangent of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \cot(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex cot(const Complex &other) { - return T(1) / tan(other); - } - - /// \brief Calculate the arc cosecant of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \operatorname{arccsc}(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex acsc(const Complex &other) { - return asin(T(1) / other); - } - - /// \brief Calculate the arc secant of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \operatorname{arcsec}(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex asec(const Complex &other) { - return acos(T(1) / other); - } - - /// \brief Calculate the arc cotangent of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \operatorname{arccot}(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex acot(const Complex &other) { - return atan(T(1) / other); - } - - /// \brief Calculate the logarithm base 2 of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \log_2(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log2(const Complex &other) { - return log(other) / ::librapid::log(T(2)); - } - - /// \brief Calculate the 
logarithm base 10 of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \log_{10}(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log10(const Complex &other) { - return log(other) / ::librapid::log(10); - } - - // Return magnitude squared - - /// \brief Calculate the magnitude squared of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ |z|^2 \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T norm(const Complex &other) { - return real(other) * real(other) + imag(other) * imag(other); - } - - /// \brief Return a complex number from polar coordinates - /// - /// Given a radius, \p rho, and an angle, \p theta, this function returns the complex number - /// \f$ \rho e^{i\theta} \f$. - /// - /// The function returns NaN, infinity or zero based on the input values of rho. - /// \tparam T Scalar type of the complex number - /// \param rho Radius of the polar coordinate system - /// \param theta Angle of the polar coordinate system - /// \return Complex number in polar form. 
- template - LIBRAPID_NODISCARD Complex polar(const T &rho, const T &theta) { - if (!::librapid::isNaN(rho) && !::librapid::isInf(rho) && rho != T(0)) { - // Rho is finite and non-zero - return Complex(rho * ::librapid::cos(theta), rho * ::librapid::sin(theta)); - } - - // Rho is NaN/Inf/0 - if (::librapid::signBit(rho)) - return -polarPositiveNanInfZeroRho(-rho, theta); - else - return polarPositiveNanInfZeroRho(rho, theta); - } - - /// \brief Compute the sine of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \sin(z) \f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex sin(const Complex &other) { - return Complex(::librapid::cosh(imag(other)) * ::librapid::sin(real(other)), - ::librapid::sinh(imag(other)) * ::librapid::cos(real(other))); - } - - /// \brief Compute the tangent of a complex number - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ \tan(z) \f$ - template - LIBRAPID_NODISCARD Complex tan(const Complex &other) { - Complex zv(tanh(Complex(-imag(other), real(other)))); - return Complex(imag(zv), -real(zv)); - } - - /// \brief Round the real and imaginary parts of a complex number towards \f$ -\infty \f$ - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$ (\lfloor\operatorname{real}(z)\rfloor,\lfloor\operatorname{imag}(z)\rfloor )\f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex floor(const Complex &other) { - return Complex(::librapid::floor(real(other)), ::librapid::floor(imag(other))); - } - - /// \brief Round the real and imaginary parts of a complex number towards \f$ +\infty \f$ - /// \tparam T Scalar type - /// \param other Complex number - /// \return \f$(\lceil\operatorname{real}(z)\rceil,\lceil\operatorname{imag}(z)\rceil )\f$ - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex ceil(const Complex &other) { - return Complex(::librapid::ceil(real(other)), ::librapid::ceil(imag(other))); - } - - /// 
\brief Generate a random complex number between two given complex numbers - /// - /// This function generates a random complex number in the range [min, max], where min - /// and max are given as input. The function uses a default seed if none is provided. - /// - /// \tparam T Scalar type of the complex number - /// \param min Minimum complex number - /// \param max Maximum complex number - /// \param seed Seed for the random number generator - /// \return Random complex number between min and max - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto random(const Complex &min, - const Complex &max, uint64_t seed = -1) - -> Complex { - return Complex(::librapid::random(real(min), real(max), seed), - ::librapid::random(imag(min), imag(max), seed)); - } - - namespace typetraits { - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = Complex; - using Packet = std::false_type; - // typename std::conditional_t<(TypeInfo::packetWidth > 1), - // Complex::Packet>, std::false_type>; - static constexpr int64_t packetWidth = - 0; // TypeInfo::Scalar>::packetWidth; - static constexpr char name[] = "Complex"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = false; + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE float _logAbs(const Complex &other) noexcept { + return detail::algorithm::logHypot(real(other), imag(other)); + } + + /// \brief Calculates the natural logarithm of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return Natural logarithm of the complex number + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log(const Complex &other) { + const T logAbs = _logAbs(other); + const T theta = ::librapid::atan2(imag(other), real(other)); + return Complex(logAbs, theta); + } + + /// \brief Calculates 
the logarithm of a complex number with a complex base + /// + /// \f$ \log_{\mathrm{base}}(z) = \log(z) / \log(\mathrm{base}) \f$ + /// \tparam T Scalar type + /// \tparam B Base type + /// \param other Complex number + /// \param base Base of the logarithm + /// \return Logarithm of the complex number with the given base + /// \see log + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log(const Complex &other, + const Complex &base) { + return log(other) / log(base); + } + + /// \brief Calculates the logarithm of a complex number with a real base + /// + /// \f$ \log_{\mathrm{base}}(z) = \log(z) / \log(\mathrm{base}) \f$ + /// \tparam T Scalar type of the complex number + /// \tparam B Scalar type of the base + /// \param other Complex number + /// \param base Base of the logarithm (real) + /// \return Logarithm of the complex number with the given base + /// \see log + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log(const Complex &other, + const B &base) { + const T logAbs = _logAbs(other); + const T theta = ::librapid::atan2(imag(other), real(other)); + return Complex(logAbs, theta) / ::librapid::log(base); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex _pow(const T &left, const T &right) { + if (0 <= left) { + return Complex(::librapid::pow(left, right), ::librapid::copySign(T(0), right)); + } else { + return exp(right * log(Complex(left))); + } + } + + /// \brief Calculate \f$ \text{left}^{\text{right}} \f$ for a complex-valued left-hand side + /// \tparam T Value type for the left-hand side + /// \tparam V Value type for the right-hand side + /// \param left Complex base + /// \param right Real exponent + /// \return \f$ \text{left}^{\text{right}} \f$ + template::type == detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD Complex pow(const Complex &left, const V &right) { + if (imag(left) == 0) { + if (::librapid::signBit(imag(left))) { + return conj(_pow(real(left), static_cast(right))); + } else { 
+ return _pow(real(left), static_cast(right)); + } + } else { + return exp(static_cast(right) * log(left)); + } + } + + /// \brief Calculate \f$ \text{left}^{\text{right}} \f$ for a complex-valued right-hand side + /// \tparam T Value type for the left-hand side + /// \tparam V Value type for the right-hand side + /// \param left Real base + /// \param right Complex exponent + /// \return \f$ \text{left}^{\text{right}} \f$ + template::type == detail::LibRapidType::Scalar, int> = 0> + LIBRAPID_NODISCARD Complex pow(const V &left, const Complex &right) { + if (imag(right) == 0) { + return _pow(static_cast(left), real(right)); + } else if (0 < left) { + return exp(right * ::librapid::log(static_cast(left))); + } else { + return exp(right * log(Complex(static_cast(left)))); + } + } + + /// \brief Calculate \f$ \text{left}^{\text{right}} \f$ for complex numbers + /// \tparam T Complex number component type + /// \param left Complex base + /// \param right Complex exponent + /// \return \f$ \text{left}^{\text{right}} \f$ + template + LIBRAPID_NODISCARD Complex pow(const Complex &left, const Complex &right) { + if (imag(right) == 0) { + return pow(left, real(right)); + } else if (imag(left) == 0 && 0 < real(left)) { + return exp(right * ::librapid::log(real(left))); + } else { + return exp(right * log(left)); + } + } + + /// \brief Calculate the hyperbolic sine of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \sinh(z) \f$ + template + LIBRAPID_NODISCARD Complex sinh(const Complex &other) { + return Complex(::librapid::sinh(real(other)) * ::librapid::cos(imag(other)), + ::librapid::cosh(real(other)) * ::librapid::sin(imag(other))); + } + + template + LIBRAPID_NODISCARD Complex sqrt(const Complex &other) { + int64_t otherExp; + T rho = _fabs(other, &otherExp); // Get magnitude and scale factor + + if (otherExp == 0) { // Argument is zero, Inf or NaN + if (rho == 0) { + return Complex(T(0), imag(other)); + } else if 
(::librapid::isInf(rho)) { + const T re = real(other); + const T im = imag(other); + + if (::librapid::isInf(im)) { + return Complex(typetraits::TypeInfo::infinity(), im); // (any, +/-Inf) + } else if (::librapid::isNaN(im)) { + if (re < 0) { + // (-Inf, NaN) + return Complex(::librapid::abs(im), ::librapid::copySign(re, im)); + } else { + return other; // (+Inf, NaN) + } + } else { + if (re < 0) { + return Complex(T(0), ::librapid::copySign(re, im)); // (-Inf, finite) + } else { + return Complex(re, ::librapid::copySign(T(0), im)); // (+Inf, finite) + } + } + } else { + return Complex(rho, rho); + } + } else { // Compute in safest quadrant + T realMag = ::librapid::ldexp(::librapid::abs(real(other)), -otherExp); + rho = ::librapid::ldexp(::librapid::sqrt(2 * (realMag + rho)), otherExp / 2 - 1); + if (0 <= real(other)) { + return Complex(rho, imag(other) / (2 * rho)); + } else { + return Complex(::librapid::abs(imag(other) / (2 * rho)), + ::librapid::copySign(rho, imag(other))); + } + } + } + + /// \brief Calculate the hyperbolic tangent of a complex number + /// + /// This function supports propagation of NaNs and Infs. + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \tanh(z) \f$ + template + LIBRAPID_NODISCARD Complex tanh(const Complex &other) { + T tv = ::librapid::tan(imag(other)); + T sv = ::librapid::sinh(real(other)); + T bv = sv * (T(1) + tv * tv); + T dv = T(1) + bv * sv; + + if (::librapid::isInf(dv)) { + T real; + if (sv < T(0)) + real = T(-1); + else + real = T(1); + return Complex(real, T(0)); + } + return Complex((::librapid::sqrt(T(1) + sv * sv)) * bv / dv, tv / dv); + } + + // Return the phase angle of a complex value as a real + + /// \brief Return the phase angle of a complex value as a real + /// + /// This function calls \f$ \text{atan2}(\text{imag}(z), \text{real}(z)) \f$. 
+ /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \arg(z) \f$ + /// \see atan2 + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T arg(const Complex &other) { + return ::librapid::atan2(imag(other), real(other)); + } + + /// \brief Project a complex number onto the Riemann sphere + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \text{proj}(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex proj(const Complex &other) { + if (::librapid::isInf(real(other)) || ::librapid::isInf(imag(other))) { + const T im = ::librapid::copySign(T(0), imag(other)); + return Complex(typetraits::TypeInfo::infinity(), im); + } + return other; + } + + /// \brief Calculate the cosine of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \cos(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex cos(const Complex &other) { + return Complex(::librapid::cosh(imag(other)) * ::librapid::cos(real(other)), + -::librapid::sinh(imag(other)) * ::librapid::sin(real(other))); + } + + /// \brief Calculate the cosecant of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \csc(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex csc(const Complex &other) { + return T(1) / sin(other); + } + + /// \brief Calculate the secant of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \sec(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex sec(const Complex &other) { + return T(1) / cos(other); + } + + /// \brief Calculate the cotangent of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \cot(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex cot(const Complex &other) { + return T(1) / tan(other); + } + + /// \brief Calculate the arc cosecant of a complex number + /// \tparam 
T Scalar type + /// \param other Complex number + /// \return \f$ \operatorname{arccsc}(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex acsc(const Complex &other) { + return asin(T(1) / other); + } + + /// \brief Calculate the arc secant of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \operatorname{arcsec}(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex asec(const Complex &other) { + return acos(T(1) / other); + } + + /// \brief Calculate the arc cotangent of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \operatorname{arccot}(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex acot(const Complex &other) { + return atan(T(1) / other); + } + + /// \brief Calculate the logarithm base 2 of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \log_2(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log2(const Complex &other) { + return log(other) / ::librapid::log(T(2)); + } + + /// \brief Calculate the logarithm base 10 of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \log_{10}(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex log10(const Complex &other) { + return log(other) / ::librapid::log(10); + } + + // Return magnitude squared + + /// \brief Calculate the magnitude squared of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ |z|^2 \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T norm(const Complex &other) { + return real(other) * real(other) + imag(other) * imag(other); + } + + /// \brief Return a complex number from polar coordinates + /// + /// Given a radius, \p rho, and an angle, \p theta, this function returns the complex number + /// \f$ \rho e^{i\theta} \f$. 
+ /// + /// The function returns NaN, infinity or zero based on the input values of rho. + /// \tparam T Scalar type of the complex number + /// \param rho Radius of the polar coordinate system + /// \param theta Angle of the polar coordinate system + /// \return Complex number in polar form. + template + LIBRAPID_NODISCARD Complex polar(const T &rho, const T &theta) { + if (!::librapid::isNaN(rho) && !::librapid::isInf(rho) && rho != T(0)) { + // Rho is finite and non-zero + return Complex(rho * ::librapid::cos(theta), rho * ::librapid::sin(theta)); + } + + // Rho is NaN/Inf/0 + if (::librapid::signBit(rho)) + return -polarPositiveNanInfZeroRho(-rho, theta); + else + return polarPositiveNanInfZeroRho(rho, theta); + } + + /// \brief Compute the sine of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \sin(z) \f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex sin(const Complex &other) { + return Complex(::librapid::cosh(imag(other)) * ::librapid::sin(real(other)), + ::librapid::sinh(imag(other)) * ::librapid::cos(real(other))); + } + + /// \brief Compute the tangent of a complex number + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ \tan(z) \f$ + template + LIBRAPID_NODISCARD Complex tan(const Complex &other) { + Complex zv(tanh(Complex(-imag(other), real(other)))); + return Complex(imag(zv), -real(zv)); + } + + /// \brief Round the real and imaginary parts of a complex number towards \f$ -\infty \f$ + /// \tparam T Scalar type + /// \param other Complex number + /// \return \f$ (\lfloor\operatorname{real}(z)\rfloor,\lfloor\operatorname{imag}(z)\rfloor )\f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex floor(const Complex &other) { + return Complex(::librapid::floor(real(other)), ::librapid::floor(imag(other))); + } + + /// \brief Round the real and imaginary parts of a complex number towards \f$ +\infty \f$ + /// \tparam T Scalar type + /// \param other 
Complex number + /// \return \f$(\lceil\operatorname{real}(z)\rceil,\lceil\operatorname{imag}(z)\rceil )\f$ + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex ceil(const Complex &other) { + return Complex(::librapid::ceil(real(other)), ::librapid::ceil(imag(other))); + } + + /// \brief Generate a random complex number between two given complex numbers + /// + /// This function generates a random complex number in the range [min, max], where min + /// and max are given as input. The function uses a default seed if none is provided. + /// + /// \tparam T Scalar type of the complex number + /// \param min Minimum complex number + /// \param max Maximum complex number + /// \param seed Seed for the random number generator + /// \return Random complex number between min and max + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto random(const Complex &min, + const Complex &max, uint64_t seed = -1) + -> Complex { + return Complex(::librapid::random(real(min), real(max), seed), + ::librapid::random(imag(min), imag(max), seed)); + } + + namespace typetraits { + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = Complex; + using Packet = std::false_type; + // typename std::conditional_t<(TypeInfo::packetWidth > 1), + // Complex::Packet>, std::false_type>; + static constexpr int64_t packetWidth = + 0; // TypeInfo::Scalar>::packetWidth; + static constexpr char name[] = "Complex"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = false; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_C_64F; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_C_64F; #endif - static constexpr bool canAlign = TypeInfo::canAlign; - static constexpr bool canMemcpy = TypeInfo::canMemcpy; - - LIMIT_IMPL(min) 
{ return TypeInfo::min(); } - LIMIT_IMPL(max) { return TypeInfo::max(); } - LIMIT_IMPL(epsilon) { return TypeInfo::epsilon(); } - LIMIT_IMPL(roundError) { return TypeInfo::roundError(); } - LIMIT_IMPL(denormMin) { return TypeInfo::denormMin(); } - LIMIT_IMPL(infinity) { return TypeInfo::infinity(); } - LIMIT_IMPL(quietNaN) { return TypeInfo::quietNaN(); } - LIMIT_IMPL(signalingNaN) { return TypeInfo::signalingNaN(); } - }; - } // namespace typetraits + static constexpr bool canAlign = TypeInfo::canAlign; + static constexpr bool canMemcpy = TypeInfo::canMemcpy; + + LIMIT_IMPL(min) { return TypeInfo::min(); } + LIMIT_IMPL(max) { return TypeInfo::max(); } + LIMIT_IMPL(epsilon) { return TypeInfo::epsilon(); } + LIMIT_IMPL(roundError) { return TypeInfo::roundError(); } + LIMIT_IMPL(denormMin) { return TypeInfo::denormMin(); } + LIMIT_IMPL(infinity) { return TypeInfo::infinity(); } + LIMIT_IMPL(quietNaN) { return TypeInfo::quietNaN(); } + LIMIT_IMPL(signalingNaN) { return TypeInfo::signalingNaN(); } + }; + } // namespace typetraits } // namespace librapid // Support FMT printing @@ -2083,11 +2083,11 @@ LIBRAPID_SIMPLE_IO_IMPL(typename Scalar, librapid::Complex) #endif // FMT_API #ifdef USE_X86_X64_INTRINSICS -# undef USE_X86_X64_INTRINSICS +# undef USE_X86_X64_INTRINSICS #endif #ifdef USE_ARM64_INTRINSICS -# undef USE_ARM64_INTRINSICS +# undef USE_ARM64_INTRINSICS #endif #endif // LIBRAPID_MATH_COMPLEX_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/constants.hpp b/librapid/include/librapid/math/constants.hpp index 85d1b53e..6d5f0604 100644 --- a/librapid/include/librapid/math/constants.hpp +++ b/librapid/include/librapid/math/constants.hpp @@ -4,158 +4,158 @@ namespace librapid { #define CTYPED static const double - /// 32bit float minimum value - CTYPED EPSILON32 = FLT_MIN; + /// 32bit float minimum value + CTYPED EPSILON32 = FLT_MIN; - /// 64bit float minimum value - CTYPED EPSILON64 = DBL_MIN; + /// 64bit float minimum value + CTYPED 
EPSILON64 = DBL_MIN; - /// PI squared on 6 - CTYPED PISQRDIV6 = 1.6449340668482264364724151666460251892189499012067984377355582293; + /// PI squared on 6 + CTYPED PISQRDIV6 = 1.6449340668482264364724151666460251892189499012067984377355582293; - /// 180/PI -- Radians to Degrees - CTYPED RADTODEG = 57.295779513082320876798154814105170332405472466564321549160243861; + /// 180/PI -- Radians to Degrees + CTYPED RADTODEG = 57.295779513082320876798154814105170332405472466564321549160243861; - /// PI / 180 -- Degrees to Radians - CTYPED DEGTORAD = 0.0174532925199432957692369076848861271344287188854172545609719144; + /// PI / 180 -- Degrees to Radians + CTYPED DEGTORAD = 0.0174532925199432957692369076848861271344287188854172545609719144; - /// PI - CTYPED PI = 3.1415926535897932384626433832795028841971693993751058209749445923; + /// PI + CTYPED PI = 3.1415926535897932384626433832795028841971693993751058209749445923; - /// Sqrt(PI) - CTYPED SQRTPI = 1.7724538509055160272981674833411451827975494561223871282138077898; + /// Sqrt(PI) + CTYPED SQRTPI = 1.7724538509055160272981674833411451827975494561223871282138077898; - /// Tau - CTYPED TAU = 6.2831853071795864769252867665590057683943387987502116419498891846; + /// Tau + CTYPED TAU = 6.2831853071795864769252867665590057683943387987502116419498891846; - /// PI/2 - CTYPED HALFPI = 1.5707963267948966192313216916397514420985846996875529104874722961; + /// PI/2 + CTYPED HALFPI = 1.5707963267948966192313216916397514420985846996875529104874722961; - /// 2*PI - CTYPED TWOPI = 6.2831853071795864769252867665590057683943387987502116419498891846156; + /// 2*PI + CTYPED TWOPI = 6.2831853071795864769252867665590057683943387987502116419498891846156; - /// E - CTYPED EULER = 2.7182818284590452353602874713526624977572470936999595749669676277; + /// E + CTYPED EULER = 2.7182818284590452353602874713526624977572470936999595749669676277; - /// Sqrt(E) - CTYPED SQRTE = 1.6487212707001281468486507878141635716537761007101480115750793116; + /// Sqrt(E) 
+ CTYPED SQRTE = 1.6487212707001281468486507878141635716537761007101480115750793116; - /// Sqrt(2) - CTYPED SQRT2 = 1.4142135623730950488016887242096980785696718753769480731766797379; + /// Sqrt(2) + CTYPED SQRT2 = 1.4142135623730950488016887242096980785696718753769480731766797379; - /// Sqrt(3) - CTYPED SQRT3 = 1.7320508075688772935274463415058723669428052538103806280558069794; + /// Sqrt(3) + CTYPED SQRT3 = 1.7320508075688772935274463415058723669428052538103806280558069794; - /// Sqrt(5) - CTYPED SQRT5 = 2.2360679774997896964091736687312762354406183596115257242708972454; + /// Sqrt(5) + CTYPED SQRT5 = 2.2360679774997896964091736687312762354406183596115257242708972454; - /// Golden Ratio - CTYPED GOLDENRATIO = 1.6180339887498948482045868343656381177203091798057628621354486227; + /// Golden Ratio + CTYPED GOLDENRATIO = 1.6180339887498948482045868343656381177203091798057628621354486227; - /// Euler-Mascheroni constant - CTYPED EULERMASCHERONI = 0.5772156649015328606065120900824024310421593359399235988057672348; + /// Euler-Mascheroni constant + CTYPED EULERMASCHERONI = 0.5772156649015328606065120900824024310421593359399235988057672348; - /// Twin Primes Constant - CTYPED TWINPRIMES = 0.6601618158468695739278121100145557784326233602847334133194484233; + /// Twin Primes Constant + CTYPED TWINPRIMES = 0.6601618158468695739278121100145557784326233602847334133194484233; - /// Ln(2) - CTYPED LN2 = 0.6931471805599453094172321214581765680755001343602552541206800094; + /// Ln(2) + CTYPED LN2 = 0.6931471805599453094172321214581765680755001343602552541206800094; - /// Ln(3) - CTYPED LN3 = 1.0986122886681096913952452369225257046474905578227494517346943336; + /// Ln(3) + CTYPED LN3 = 1.0986122886681096913952452369225257046474905578227494517346943336; - /// Ln(5) - CTYPED LN5 = 1.6094379124341003746007593332261876395256013542685177219126478914; + /// Ln(5) + CTYPED LN5 = 1.6094379124341003746007593332261876395256013542685177219126478914; - /// Zeta(3) - CTYPED ZETA3 = 
1.2020569031595942853997381615114499907649862923404988817922715553; + /// Zeta(3) + CTYPED ZETA3 = 1.2020569031595942853997381615114499907649862923404988817922715553; - /// CubeRoot(2) - CTYPED CUBEROOT2 = 1.2599210498948731647672106072782283505702514647015079800819751121; + /// CubeRoot(2) + CTYPED CUBEROOT2 = 1.2599210498948731647672106072782283505702514647015079800819751121; - /// CubeRoot(3) - CTYPED CUBEROOT3 = 1.4422495703074083823216383107801095883918692534993505775464161945; + /// CubeRoot(3) + CTYPED CUBEROOT3 = 1.4422495703074083823216383107801095883918692534993505775464161945; - /// Speed of Light - CTYPED LIGHTSPEED = 299792458.0; + /// Speed of Light + CTYPED LIGHTSPEED = 299792458.0; - /// Acceleration due to gravity on Earth - CTYPED EARTHGRAVITY = 9.80665; + /// Acceleration due to gravity on Earth + CTYPED EARTHGRAVITY = 9.80665; - /// Wallis Constant - CTYPED WALLISCONST = 2.0945514815423265914823865405793029638573061056282391803041285290; + /// Wallis Constant + CTYPED WALLISCONST = 2.0945514815423265914823865405793029638573061056282391803041285290; - /// Laplace limit - CTYPED LAPLACELIMIT = 0.6627434193491815809747420971092529070562335491150224175203925349; + /// Laplace limit + CTYPED LAPLACELIMIT = 0.6627434193491815809747420971092529070562335491150224175203925349; - /// Gauss's constant - CTYPED GAUSSCONST = 0.8346268416740731862814297327990468089939930134903470024498273701; + /// Gauss's constant + CTYPED GAUSSCONST = 0.8346268416740731862814297327990468089939930134903470024498273701; - /// Cahen's constant - CTYPED CAHENSCONST = 0.6434105462883380261822543077575647632865878602682395059870309203; + /// Cahen's constant + CTYPED CAHENSCONST = 0.6434105462883380261822543077575647632865878602682395059870309203; - /// Parabolic constant -- P_2 - CTYPED P2 = 2.2955871493926380740342980491894903875978322036385834839299753466; + /// Parabolic constant -- P_2 + CTYPED P2 = 2.2955871493926380740342980491894903875978322036385834839299753466; - /// 
Dottie number - CTYPED DOTTIENUMBER = 0.7390851332151606416553120876738734040134117589007574649656806357; + /// Dottie number + CTYPED DOTTIENUMBER = 0.7390851332151606416553120876738734040134117589007574649656806357; - /// Meissel-Mertens constant - CTYPED MEISSELMERTENS = 0.2614972128476427837554268386086958590515666482611992061920642139; + /// Meissel-Mertens constant + CTYPED MEISSELMERTENS = 0.2614972128476427837554268386086958590515666482611992061920642139; - /// E^PI -- Gelfond's constant - CTYPED ETOPI = 23.140692632779269005729086367948547380266106242600211993445046409; + /// E^PI -- Gelfond's constant + CTYPED ETOPI = 23.140692632779269005729086367948547380266106242600211993445046409; - /// Golden angle - CTYPED GOLDENANGLE = 2.3999632297286533222315555066336138531249990110581150429351127507; + /// Golden angle + CTYPED GOLDENANGLE = 2.3999632297286533222315555066336138531249990110581150429351127507; - /// Area of the Mandelbrot fractal. - CTYPED MANDELBROTAREA = 1.5065918849; + /// Area of the Mandelbrot fractal. + CTYPED MANDELBROTAREA = 1.5065918849; - /// Gieseking constant. - CTYPED GIESEKINGCONST = 1.0149416064096536250212025542745202859416893075302997920174891067; + /// Gieseking constant. + CTYPED GIESEKINGCONST = 1.0149416064096536250212025542745202859416893075302997920174891067; - /// Bloch-Landau constant. - CTYPED BLOCHLANDAU = 0.5432589653429767069527282953006132311388632937583569889557325691; + /// Bloch-Landau constant. + CTYPED BLOCHLANDAU = 0.5432589653429767069527282953006132311388632937583569889557325691; - /// Golomb-Dickman constant. - CTYPED GOLOMBDICKMAN = 0.6243299885435508709929363831008372441796426201805292869735519024; + /// Golomb-Dickman constant. + CTYPED GOLOMBDICKMAN = 0.6243299885435508709929363831008372441796426201805292869735519024; - /// Feller-Tornier constant. - CTYPED FELLERTORNIER = 0.6613170494696223352897658462741185332854752898329; + /// Feller-Tornier constant. 
+ CTYPED FELLERTORNIER = 0.6613170494696223352897658462741185332854752898329; - /// 2^Sqrt(2) - CTYPED TWOTOSQRT2 = 2.6651441426902251886502972498731398482742113137146594928359795933; + /// 2^Sqrt(2) + CTYPED TWOTOSQRT2 = 2.6651441426902251886502972498731398482742113137146594928359795933; - /// Khinchin's constant - CTYPED KHINCHINSCONST = 2.6854520010653064453097148354817956938203822939944629530511523455; + /// Khinchin's constant + CTYPED KHINCHINSCONST = 2.6854520010653064453097148354817956938203822939944629530511523455; - /// Mill's constant - CTYPED MILLSCONST = 1.3063778838630806904686144926026057129167845851567136443680537599; + /// Mill's constant + CTYPED MILLSCONST = 1.3063778838630806904686144926026057129167845851567136443680537599; - /// PI/Ln(2) - CTYPED PIOVERLN2 = 4.5323601418271938096276829457166668101718614677237955841860165479; + /// PI/Ln(2) + CTYPED PIOVERLN2 = 4.5323601418271938096276829457166668101718614677237955841860165479; - /// Loch's constant - CTYPED LOCHSCONST = 0.9702701143920339257402560192100108337812847047851612866103505299; + /// Loch's constant + CTYPED LOCHSCONST = 0.9702701143920339257402560192100108337812847047851612866103505299; - /// Niven's constant - CTYPED NIVENSCONST = 1.7052111401053677642885514534345081607620276516534690999942849065; + /// Niven's constant + CTYPED NIVENSCONST = 1.7052111401053677642885514534345081607620276516534690999942849065; - /// Reciprocal Fibonacci constant - CTYPED RECIPFIBCONST = 3.3598856662431775531720113029189271796889051337319684864955538153; + /// Reciprocal Fibonacci constant + CTYPED RECIPFIBCONST = 3.3598856662431775531720113029189271796889051337319684864955538153; - /// Backhouse's constant - CTYPED BACKHOUSECONST = 1.4560749485826896713995953511165435576531783748471315402707024374; + /// Backhouse's constant + CTYPED BACKHOUSECONST = 1.4560749485826896713995953511165435576531783748471315402707024374; - /// MRB constant - CTYPED MRBCONST = 
0.1878596424620671202485179340542732300559030949001387861720046840; + /// MRB constant + CTYPED MRBCONST = 0.1878596424620671202485179340542732300559030949001387861720046840; - /// Somos' quadratic recurrence constant - CTYPED QUADRECURR = 1.6616879496335941212958189227499507499644186350250682081897111680; + /// Somos' quadratic recurrence constant + CTYPED QUADRECURR = 1.6616879496335941212958189227499507499644186350250682081897111680; - /// Plastic number - CTYPED PLASTICNUMBER = 1.3247179572447460259609088544780973407344040569017333645340150503; + /// Plastic number + CTYPED PLASTICNUMBER = 1.3247179572447460259609088544780973407344040569017333645340150503; } // namespace librapid #endif // LIBRAPID_MATH_CONSTANTS diff --git a/librapid/include/librapid/math/coreMath.hpp b/librapid/include/librapid/math/coreMath.hpp index 7ace3ee2..91ac1e2d 100644 --- a/librapid/include/librapid/math/coreMath.hpp +++ b/librapid/include/librapid/math/coreMath.hpp @@ -9,450 +9,450 @@ */ namespace librapid { - namespace detail { - template - struct ContainsArrayType; - } // namespace detail - - /// Return the smallest value of a given set of values - /// \tparam T Data type - /// \param val Input set - /// \return Smallest element of the input set - template - T &&min(T &&val) { - return std::forward(val); - } - - /// Return the smallest value of a given set of values - /// \tparam Types Data types of the input values - /// \param vals Input values - /// \return The smallest element of the input values - template - auto min(T0 &&val1, T1 &&val2, Ts &&...vs) { - return (val1 < val2) ? min(val1, std::forward(vs)...) 
- : min(val2, std::forward(vs)...); - } - - /// Return the largest value of a given set of values - /// \tparam T Data type - /// \param val Input set - /// \return Largest element of the input set - template - T &&max(T &&val) { - return std::forward(val); - } - - /// Return the largest value of a given set of values - /// \tparam Types Data types of the input values - /// \param vals Input values - /// \return The largest element of the input values - template - auto max(T0 &&val1, T1 &&val2, Ts &&...vs) { - return (val1 > val2) ? max(val1, std::forward(vs)...) - : max(val2, std::forward(vs)...); - } - - /// Return the absolute value of a given value - /// \tparam T Data type - /// \param val Input value - /// \return Absolute value of the input value - template, int> = 0> - constexpr T abs(T val) { - return (val < T(0)) ? -val : val; - } - - /// Map a value from one range to another - /// \tparam V Data type of the value to map - /// \tparam B1 Data type of the lower bound of the input range - /// \tparam E1 Data type of the upper bound of the input range - /// \tparam B2 Data type of the lower bound of the output range - /// \tparam E2 Data type of the upper bound of the output range - /// \param val Value to map - /// \param start1 Lower bound of the input range - /// \param stop1 Upper bound of the input range - /// \param start2 Lower bound of the output range - /// \param stop2 Upper bound of the output range - /// \return Mapped value - template - LIBRAPID_INLINE auto map(const V &val, const B1 &start1, const E1 &stop1, const B2 &start2, - const E2 &stop2) { - return start2 + (val - start1) * (stop2 - start2) / (stop1 - start1); - - // if constexpr (detail::ContainsArrayType::val) { - // return start2 + (val - start1) * (stop2 - start2) / (stop1 - start1); - // } else { - // using T = decltype((val - start1) * (stop2 - start2) / (stop1 - start1) + start2); - // return static_cast(start2) + (static_cast(stop2) - static_cast(start2)) * - // 
((static_cast(val) - static_cast(start1)) / - // (static_cast(stop1) - static_cast(start1))); - // } - } - - template - LIBRAPID_INLINE auto mod(const T1 &val, const T2 &mod) { - if constexpr (std::is_floating_point_v || std::is_floating_point_v) { - return std::fmod(val, mod); - } else { - return val % mod; - } - } - - /// Return the floor of a given value - /// \tparam T Data type - /// \param val Input value - /// \return Floor of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr T floor(T val) { - return std::floor(val); - } - - /// Return the ceiling of a given value - /// \tparam T Data type - /// \param val Input value - /// \return Ceiling of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr T ceil(T val) { - return std::ceil(val); - } - - /// Return the square root of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the square root. - /// \tparam T Data type - /// \param val Input value - /// \return Square root of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto sqrt(T val) { - if constexpr (std::is_integral_v) { - return std::sqrt(static_cast(val)); - } else { - return std::sqrt(val); - } - } - - /// Return the hypotenuse of a right triangle given the lengths of the two legs. Note that, - /// for integer values, this function will cast the input values to a floating point type - /// before calculating the hypotenuse. 
- /// \tparam T Data type - /// \param leg1 Length of the first leg - /// \param leg2 Length of the second leg - /// \return Hypotenuse of the right triangle - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto hypot(T leg1, T leg2) { - if constexpr (std::is_integral_v) { - return std::hypot(static_cast(leg1), static_cast(leg2)); - } else { - return std::hypot(leg1, leg2); - } - } - - /// Return the cube root of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the cube root. - /// \tparam T Data type - /// \param val Input value - /// \return Cube root of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto cbrt(T val) { - if constexpr (std::is_integral_v) { - return std::cbrt(static_cast(val)); - } else { - return std::cbrt(val); - } - } - - /// Return the first number raised to the power of the second number. The return value will be - /// promoted to the larger of the two input types. - /// \tparam T0 Data type of the first input value - /// \tparam T1 Data type of the second input value - /// \param val1 First input value - /// \param val2 Second input value - /// \return First input value raised to the power of the second input value - template< - typename T0, typename T1, - typename std::enable_if_t && std::is_fundamental_v, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto pow(T0 val1, T1 val2) { - if constexpr (std::is_integral_v && std::is_integral_v) { - return std::pow(static_cast(val1), static_cast(val2)); - } else if constexpr (std::is_integral_v) { - return std::pow(static_cast(val1), val2); - } else if constexpr (std::is_integral_v) { - return std::pow(val1, static_cast(val2)); - } else { - return std::pow(val1, val2); - } - } - - /// Return the exponential of a given value. 
Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the exponential. - /// \tparam T Data type - /// \param val Input value - /// \return Exponential of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto exp(T val) { - if constexpr (std::is_integral_v) { - return std::exp(static_cast(val)); - } else { - return std::exp(val); - } - } - - /// Return 2 raised to a given power. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the exponential. - /// \tparam T Data type - /// \param val Input value - /// \return 2 raised to the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto exp2(T val) { - if constexpr (std::is_integral_v) { - return std::exp2(static_cast(val)); - } else { - return std::exp2(val); - } - } - - // Return 10 raised to a given power. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the exponential. - /// \tparam T Data type - /// \param val Input value - /// \return 10 raised to the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto exp10(T val) { - // C++ standard does not implement exp10 - - if constexpr (std::is_integral_v) { - return std::pow(10.0, static_cast(val)); - } else { - return std::pow(10.0, val); - } - } - - /// Return the natural logarithm of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the logarithm. 
- /// \tparam T Data type - /// \param val Input value - /// \return Natural logarithm of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto log(T val) { - if constexpr (std::is_integral_v) { - return std::log(static_cast(val)); - } else { - return std::log(val); - } - } - - /// Return the logarithm base-10 of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the logarithm. - /// \tparam T Data type - /// \param val Input value - /// \return Logarithm base-10 of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto log10(T val) { - if constexpr (std::is_integral_v) { - return std::log10(static_cast(val)); - } else { - return std::log10(val); - } - } - - /// Return the logarithm base-2 of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the logarithm. - /// \tparam T Data type - /// \param val Input value - /// \return Logarithm base-2 of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto log2(T val) { - if constexpr (std::is_integral_v) { - return std::log2(static_cast(val)); - } else { - return std::log2(val); - } - } - - /// Return the sine of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the sine. - /// \tparam T Data type - /// \param val Input value - /// \return Sine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto sin(T val) { - if constexpr (std::is_integral_v) { - return std::sin(static_cast(val)); - } else { - return std::sin(val); - } - } - - /// Return the cosine of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the cosine. 
- /// \tparam T Data type - /// \param val Input value - /// \return Cosine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto cos(T val) { - if constexpr (std::is_integral_v) { - return std::cos(static_cast(val)); - } else { - return std::cos(val); - } - } - - /// Return the tangent of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the tangent. - /// \tparam T Data type - /// \param val Input value - /// \return Tangent of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto tan(T val) { - if constexpr (std::is_integral_v) { - return std::tan(static_cast(val)); - } else { - return std::tan(val); - } - } - - /// Return the arcsine of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the arcsine. - /// \tparam T Data type - /// \param val Input value - /// \return Arcsine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto asin(T val) { - if constexpr (std::is_integral_v) { - return std::asin(static_cast(val)); - } else { - return std::asin(val); - } - } - - /// Return the arccosine of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the arccosine. - /// \tparam T Data type - /// \param val Input value - /// \return Arccosine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto acos(T val) { - if constexpr (std::is_integral_v) { - return std::acos(static_cast(val)); - } else { - return std::acos(val); - } - } - - /// Return the arctangent of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before calculating the arctangent. 
- /// \tparam T Data type - /// \param val Input value - /// \return Arctangent of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto atan(T val) { - if constexpr (std::is_integral_v) { - return std::atan(static_cast(val)); - } else { - return std::atan(val); - } - } - - /// Return the angle formed by a given y and x offset. This is often more useful than using - /// atan, since it gives more usable outputs. Note that, for integer values, this function - /// will cast the input values to a floating point type before calculating the angle. - /// \tparam TY Data type of the y offset - /// \tparam TX Data type of the x offset - /// \param dy Y offset - /// \param dx X offset - /// \return Angle formed by the given offsets - template< - typename TY, typename TX, - typename std::enable_if_t && std::is_fundamental_v, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto atan2(TY dy, TX dx) { - if constexpr (std::is_integral_v || std::is_integral_v) { - return std::atan2(static_cast(dy), static_cast(dx)); - } else { - return std::atan2(dy, dx); - } - } - - /// Return the hyperbolic sin of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before computing the result. - /// \tparam T Data type - /// \param val Input value - /// \return Hyperbolic sine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto sinh(T val) { - if constexpr (std::is_integral_v) { - return std::sinh(static_cast(val)); - } else { - return std::sinh(val); - } - } - - /// Return the hyperbolic cosine of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before computing the result. 
- /// \tparam T Data type - /// \param val Input value - /// \return Hyperbolic cosine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto cosh(T val) { - if constexpr (std::is_integral_v) { - return std::cosh(static_cast(val)); - } else { - return std::cosh(val); - } - } - - /// Return the hyperbolic tangent of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before computing the result. - /// \tparam T Data type - /// \param val Input value - /// \return Hyperbolic tangent of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto tanh(T val) { - if constexpr (std::is_integral_v) { - return std::tanh(static_cast(val)); - } else { - return std::tanh(val); - } - } - - /// Return the hyperbolic arcsine of a given value. Note that, for integer values, this function - /// will cast the input value to a floating point type before computing the result. - /// \tparam T Data type - /// \param val Input value - /// \return Hyperbolic arcsine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto asinh(T val) { - if constexpr (std::is_integral_v) { - return std::asinh(static_cast(val)); - } else { - return std::asinh(val); - } - } - - /// Return the hyperbolic arccosine of a given value. Note that, for integer values, this - /// function will cast the input value to a floating point type before computing the result. - /// \tparam T Data type - /// \param val Input value - /// \return Hyperbolic arccosine of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto acosh(T val) { - if constexpr (std::is_integral_v) { - return std::acosh(static_cast(val)); - } else { - return std::acosh(val); - } - } - - /// Return the hyperbolic arctangent of a given value. 
Note that, for integer values, this - /// function will cast the input value to a floating point type before computing the result. - /// \tparam T Data type - /// \param val Input value - /// \return Hyperbolic arctangent of the input value - template, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto atanh(T val) { - if constexpr (std::is_integral_v) { - return std::atanh(static_cast(val)); - } else { - return std::atanh(val); - } - } + namespace detail { + template + struct ContainsArrayType; + } // namespace detail + + /// Return the smallest value of a given set of values + /// \tparam T Data type + /// \param val Input set + /// \return Smallest element of the input set + template + T &&min(T &&val) { + return std::forward(val); + } + + /// Return the smallest value of a given set of values + /// \tparam Types Data types of the input values + /// \param vals Input values + /// \return The smallest element of the input values + template + auto min(T0 &&val1, T1 &&val2, Ts &&...vs) { + return (val1 < val2) ? min(val1, std::forward(vs)...) + : min(val2, std::forward(vs)...); + } + + /// Return the largest value of a given set of values + /// \tparam T Data type + /// \param val Input set + /// \return Largest element of the input set + template + T &&max(T &&val) { + return std::forward(val); + } + + /// Return the largest value of a given set of values + /// \tparam Types Data types of the input values + /// \param vals Input values + /// \return The largest element of the input values + template + auto max(T0 &&val1, T1 &&val2, Ts &&...vs) { + return (val1 > val2) ? max(val1, std::forward(vs)...) + : max(val2, std::forward(vs)...); + } + + /// Return the absolute value of a given value + /// \tparam T Data type + /// \param val Input value + /// \return Absolute value of the input value + template, int> = 0> + constexpr T abs(T val) { + return (val < T(0)) ? 
-val : val; + } + + /// Map a value from one range to another + /// \tparam V Data type of the value to map + /// \tparam B1 Data type of the lower bound of the input range + /// \tparam E1 Data type of the upper bound of the input range + /// \tparam B2 Data type of the lower bound of the output range + /// \tparam E2 Data type of the upper bound of the output range + /// \param val Value to map + /// \param start1 Lower bound of the input range + /// \param stop1 Upper bound of the input range + /// \param start2 Lower bound of the output range + /// \param stop2 Upper bound of the output range + /// \return Mapped value + template + LIBRAPID_INLINE auto map(const V &val, const B1 &start1, const E1 &stop1, const B2 &start2, + const E2 &stop2) { + return start2 + (val - start1) * (stop2 - start2) / (stop1 - start1); + + // if constexpr (detail::ContainsArrayType::val) { + // return start2 + (val - start1) * (stop2 - start2) / (stop1 - start1); + // } else { + // using T = decltype((val - start1) * (stop2 - start2) / (stop1 - start1) + start2); + // return static_cast(start2) + (static_cast(stop2) - static_cast(start2)) * + // ((static_cast(val) - static_cast(start1)) / + // (static_cast(stop1) - static_cast(start1))); + // } + } + + template + LIBRAPID_INLINE auto mod(const T1 &val, const T2 &mod) { + if constexpr (std::is_floating_point_v || std::is_floating_point_v) { + return std::fmod(val, mod); + } else { + return val % mod; + } + } + + /// Return the floor of a given value + /// \tparam T Data type + /// \param val Input value + /// \return Floor of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr T floor(T val) { + return std::floor(val); + } + + /// Return the ceiling of a given value + /// \tparam T Data type + /// \param val Input value + /// \return Ceiling of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr T ceil(T val) { + return std::ceil(val); + } + + /// Return 
the square root of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the square root. + /// \tparam T Data type + /// \param val Input value + /// \return Square root of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto sqrt(T val) { + if constexpr (std::is_integral_v) { + return std::sqrt(static_cast(val)); + } else { + return std::sqrt(val); + } + } + + /// Return the hypotenuse of a right triangle given the lengths of the two legs. Note that, + /// for integer values, this function will cast the input values to a floating point type + /// before calculating the hypotenuse. + /// \tparam T Data type + /// \param leg1 Length of the first leg + /// \param leg2 Length of the second leg + /// \return Hypotenuse of the right triangle + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto hypot(T leg1, T leg2) { + if constexpr (std::is_integral_v) { + return std::hypot(static_cast(leg1), static_cast(leg2)); + } else { + return std::hypot(leg1, leg2); + } + } + + /// Return the cube root of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the cube root. + /// \tparam T Data type + /// \param val Input value + /// \return Cube root of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto cbrt(T val) { + if constexpr (std::is_integral_v) { + return std::cbrt(static_cast(val)); + } else { + return std::cbrt(val); + } + } + + /// Return the first number raised to the power of the second number. The return value will be + /// promoted to the larger of the two input types. 
+ /// \tparam T0 Data type of the first input value + /// \tparam T1 Data type of the second input value + /// \param val1 First input value + /// \param val2 Second input value + /// \return First input value raised to the power of the second input value + template< + typename T0, typename T1, + typename std::enable_if_t && std::is_fundamental_v, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto pow(T0 val1, T1 val2) { + if constexpr (std::is_integral_v && std::is_integral_v) { + return std::pow(static_cast(val1), static_cast(val2)); + } else if constexpr (std::is_integral_v) { + return std::pow(static_cast(val1), val2); + } else if constexpr (std::is_integral_v) { + return std::pow(val1, static_cast(val2)); + } else { + return std::pow(val1, val2); + } + } + + /// Return the exponential of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the exponential. + /// \tparam T Data type + /// \param val Input value + /// \return Exponential of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto exp(T val) { + if constexpr (std::is_integral_v) { + return std::exp(static_cast(val)); + } else { + return std::exp(val); + } + } + + /// Return 2 raised to a given power. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the exponential. + /// \tparam T Data type + /// \param val Input value + /// \return 2 raised to the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto exp2(T val) { + if constexpr (std::is_integral_v) { + return std::exp2(static_cast(val)); + } else { + return std::exp2(val); + } + } + + // Return 10 raised to a given power. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the exponential. 
+ /// \tparam T Data type + /// \param val Input value + /// \return 10 raised to the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto exp10(T val) { + // C++ standard does not implement exp10 + + if constexpr (std::is_integral_v) { + return std::pow(10.0, static_cast(val)); + } else { + return std::pow(10.0, val); + } + } + + /// Return the natural logarithm of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the logarithm. + /// \tparam T Data type + /// \param val Input value + /// \return Natural logarithm of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto log(T val) { + if constexpr (std::is_integral_v) { + return std::log(static_cast(val)); + } else { + return std::log(val); + } + } + + /// Return the logarithm base-10 of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the logarithm. + /// \tparam T Data type + /// \param val Input value + /// \return Logarithm base-10 of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto log10(T val) { + if constexpr (std::is_integral_v) { + return std::log10(static_cast(val)); + } else { + return std::log10(val); + } + } + + /// Return the logarithm base-2 of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the logarithm. + /// \tparam T Data type + /// \param val Input value + /// \return Logarithm base-2 of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto log2(T val) { + if constexpr (std::is_integral_v) { + return std::log2(static_cast(val)); + } else { + return std::log2(val); + } + } + + /// Return the sine of a given value. 
Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the sine. + /// \tparam T Data type + /// \param val Input value + /// \return Sine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto sin(T val) { + if constexpr (std::is_integral_v) { + return std::sin(static_cast(val)); + } else { + return std::sin(val); + } + } + + /// Return the cosine of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the cosine. + /// \tparam T Data type + /// \param val Input value + /// \return Cosine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto cos(T val) { + if constexpr (std::is_integral_v) { + return std::cos(static_cast(val)); + } else { + return std::cos(val); + } + } + + /// Return the tangent of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the tangent. + /// \tparam T Data type + /// \param val Input value + /// \return Tangent of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto tan(T val) { + if constexpr (std::is_integral_v) { + return std::tan(static_cast(val)); + } else { + return std::tan(val); + } + } + + /// Return the arcsine of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the arcsine. + /// \tparam T Data type + /// \param val Input value + /// \return Arcsine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto asin(T val) { + if constexpr (std::is_integral_v) { + return std::asin(static_cast(val)); + } else { + return std::asin(val); + } + } + + /// Return the arccosine of a given value. 
Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the arccosine. + /// \tparam T Data type + /// \param val Input value + /// \return Arccosine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto acos(T val) { + if constexpr (std::is_integral_v) { + return std::acos(static_cast(val)); + } else { + return std::acos(val); + } + } + + /// Return the arctangent of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before calculating the arctangent. + /// \tparam T Data type + /// \param val Input value + /// \return Arctangent of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto atan(T val) { + if constexpr (std::is_integral_v) { + return std::atan(static_cast(val)); + } else { + return std::atan(val); + } + } + + /// Return the angle formed by a given y and x offset. This is often more useful than using + /// atan, since it gives more usable outputs. Note that, for integer values, this function + /// will cast the input values to a floating point type before calculating the angle. + /// \tparam TY Data type of the y offset + /// \tparam TX Data type of the x offset + /// \param dy Y offset + /// \param dx X offset + /// \return Angle formed by the given offsets + template< + typename TY, typename TX, + typename std::enable_if_t && std::is_fundamental_v, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto atan2(TY dy, TX dx) { + if constexpr (std::is_integral_v || std::is_integral_v) { + return std::atan2(static_cast(dy), static_cast(dx)); + } else { + return std::atan2(dy, dx); + } + } + + /// Return the hyperbolic sin of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before computing the result. 
+ /// \tparam T Data type + /// \param val Input value + /// \return Hyperbolic sine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto sinh(T val) { + if constexpr (std::is_integral_v) { + return std::sinh(static_cast(val)); + } else { + return std::sinh(val); + } + } + + /// Return the hyperbolic cosine of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before computing the result. + /// \tparam T Data type + /// \param val Input value + /// \return Hyperbolic cosine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto cosh(T val) { + if constexpr (std::is_integral_v) { + return std::cosh(static_cast(val)); + } else { + return std::cosh(val); + } + } + + /// Return the hyperbolic tangent of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before computing the result. + /// \tparam T Data type + /// \param val Input value + /// \return Hyperbolic tangent of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto tanh(T val) { + if constexpr (std::is_integral_v) { + return std::tanh(static_cast(val)); + } else { + return std::tanh(val); + } + } + + /// Return the hyperbolic arcsine of a given value. Note that, for integer values, this function + /// will cast the input value to a floating point type before computing the result. + /// \tparam T Data type + /// \param val Input value + /// \return Hyperbolic arcsine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto asinh(T val) { + if constexpr (std::is_integral_v) { + return std::asinh(static_cast(val)); + } else { + return std::asinh(val); + } + } + + /// Return the hyperbolic arccosine of a given value. 
Note that, for integer values, this + /// function will cast the input value to a floating point type before computing the result. + /// \tparam T Data type + /// \param val Input value + /// \return Hyperbolic arccosine of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto acosh(T val) { + if constexpr (std::is_integral_v) { + return std::acosh(static_cast(val)); + } else { + return std::acosh(val); + } + } + + /// Return the hyperbolic arctangent of a given value. Note that, for integer values, this + /// function will cast the input value to a floating point type before computing the result. + /// \tparam T Data type + /// \param val Input value + /// \return Hyperbolic arctangent of the input value + template, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto atanh(T val) { + if constexpr (std::is_integral_v) { + return std::atanh(static_cast(val)); + } else { + return std::atanh(val); + } + } } // namespace librapid #endif // LIBRAPID_MATH_CORE_MATH_HPP diff --git a/librapid/include/librapid/math/fastMath.hpp b/librapid/include/librapid/math/fastMath.hpp index 290a8e43..407640b3 100644 --- a/librapid/include/librapid/math/fastMath.hpp +++ b/librapid/include/librapid/math/fastMath.hpp @@ -2,7 +2,7 @@ #define LIBRAPID_MATH_FAST_MATH_HPP namespace librapid::fastmath { - double pow10(int64_t exponent); + double pow10(int64_t exponent); } // namespace librapid::fastmath #endif // LIBRAPID_MATH_FAST_MATH_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/half.hpp b/librapid/include/librapid/math/half.hpp index da98f3fd..c5f836f3 100644 --- a/librapid/include/librapid/math/half.hpp +++ b/librapid/include/librapid/math/half.hpp @@ -8,755 +8,755 @@ // namespace librapid { - namespace detail { - union float16_t { - uint16_t m_bits; - struct { - uint16_t m_frac : 10; - uint16_t m_exp : 5; - uint16_t m_sign : 1; - } m_ieee; - }; - - union float32_t { - uint32_t m_bits; - struct { - 
uint32_t m_frac : 23; - uint32_t m_exp : 8; - uint32_t m_sign : 1; - } m_ieee; - float m_float; - }; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t - uint32Sels(uint32_t test, uint32_t a, uint32_t b) noexcept { - const uint32_t mask = (((std::int32_t)test) >> 31); - const uint32_t sel_a = (a & mask); - const uint32_t sel_b = (b & ~mask); - const uint32_t result = (sel_a | sel_b); - return (result); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t - uint32Selb(uint32_t mask, uint32_t a, uint32_t b) noexcept { - const uint32_t sel_a = (a & mask); - const uint32_t sel_b = (b & ~mask); - const uint32_t result = (sel_a | sel_b); - return (result); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t - uint16Sels(uint16_t test, uint16_t a, uint16_t b) noexcept { - const uint16_t mask = (((int16_t)test) >> 15); - const uint16_t sel_a = (a & mask); - const uint16_t sel_b = (b & ~mask); - const uint16_t result = (sel_a | sel_b); - return (result); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t - uint32Cntlz(uint32_t x) noexcept { + namespace detail { + union float16_t { + uint16_t m_bits; + struct { + uint16_t m_frac : 10; + uint16_t m_exp : 5; + uint16_t m_sign : 1; + } m_ieee; + }; + + union float32_t { + uint32_t m_bits; + struct { + uint32_t m_frac : 23; + uint32_t m_exp : 8; + uint32_t m_sign : 1; + } m_ieee; + float m_float; + }; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t + uint32Sels(uint32_t test, uint32_t a, uint32_t b) noexcept { + const uint32_t mask = (((std::int32_t)test) >> 31); + const uint32_t sel_a = (a & mask); + const uint32_t sel_b = (b & ~mask); + const uint32_t result = (sel_a | sel_b); + return (result); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t + uint32Selb(uint32_t mask, uint32_t a, uint32_t b) noexcept { + const uint32_t sel_a = (a & mask); + const uint32_t sel_b = (b & ~mask); + const uint32_t result = (sel_a | sel_b); + 
return (result); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t + uint16Sels(uint16_t test, uint16_t a, uint16_t b) noexcept { + const uint16_t mask = (((int16_t)test) >> 15); + const uint16_t sel_a = (a & mask); + const uint16_t sel_b = (b & ~mask); + const uint16_t result = (sel_a | sel_b); + return (result); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t + uint32Cntlz(uint32_t x) noexcept { #if defined(LIBRAPID_GNU_CXX) - uint32_t is_x_nez_msb = (-x); - uint32_t nlz = __builtin_clz(x); - uint32_t result = _uint32_sels(is_x_nez_msb, nlz, 0x00000020); - return (result); + uint32_t is_x_nez_msb = (-x); + uint32_t nlz = __builtin_clz(x); + uint32_t result = _uint32_sels(is_x_nez_msb, nlz, 0x00000020); + return (result); #else - const uint32_t x0 = (x >> 1); - const uint32_t x1 = (x | x0); - const uint32_t x2 = (x1 >> 2); - const uint32_t x3 = (x1 | x2); - const uint32_t x4 = (x3 >> 4); - const uint32_t x5 = (x3 | x4); - const uint32_t x6 = (x5 >> 8); - const uint32_t x7 = (x5 | x6); - const uint32_t x8 = (x7 >> 16); - const uint32_t x9 = (x7 | x8); - const uint32_t xA = (~x9); - const uint32_t xB = (xA >> 1); - const uint32_t xC = (xB & 0x55555555); - const uint32_t xD = (xA - xC); - const uint32_t xE = (xD & 0x33333333); - const uint32_t xF = (xD >> 2); - const uint32_t x10 = (xF & 0x33333333); - const uint32_t x11 = (xE + x10); - const uint32_t x12 = (x11 >> 4); - const uint32_t x13 = (x11 + x12); - const uint32_t x14 = (x13 & 0x0f0f0f0f); - const uint32_t x15 = (x14 >> 8); - const uint32_t x16 = (x14 + x15); - const uint32_t x17 = (x16 >> 16); - const uint32_t x18 = (x16 + x17); - const uint32_t x19 = (x18 & 0x0000003f); - return (x19); + const uint32_t x0 = (x >> 1); + const uint32_t x1 = (x | x0); + const uint32_t x2 = (x1 >> 2); + const uint32_t x3 = (x1 | x2); + const uint32_t x4 = (x3 >> 4); + const uint32_t x5 = (x3 | x4); + const uint32_t x6 = (x5 >> 8); + const uint32_t x7 = (x5 | x6); + const uint32_t x8 = (x7 
>> 16); + const uint32_t x9 = (x7 | x8); + const uint32_t xA = (~x9); + const uint32_t xB = (xA >> 1); + const uint32_t xC = (xB & 0x55555555); + const uint32_t xD = (xA - xC); + const uint32_t xE = (xD & 0x33333333); + const uint32_t xF = (xD >> 2); + const uint32_t x10 = (xF & 0x33333333); + const uint32_t x11 = (xE + x10); + const uint32_t x12 = (x11 >> 4); + const uint32_t x13 = (x11 + x12); + const uint32_t x14 = (x13 & 0x0f0f0f0f); + const uint32_t x15 = (x14 >> 8); + const uint32_t x16 = (x14 + x15); + const uint32_t x17 = (x16 >> 16); + const uint32_t x18 = (x16 + x17); + const uint32_t x19 = (x18 & 0x0000003f); + return (x19); #endif - } + } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t - uint16Cntlz(uint16_t x) noexcept { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t + uint16Cntlz(uint16_t x) noexcept { #if defined(LIBRAPID_GNU_CXX) - uint16_t nlz32 = (uint16_t)_uint32_cntlz((uint32_t)x); - uint32_t nlz = (nlz32 - 16); - return (nlz); + uint16_t nlz32 = (uint16_t)_uint32_cntlz((uint32_t)x); + uint32_t nlz = (nlz32 - 16); + return (nlz); #else - const uint16_t x0 = (x >> 1); - const uint16_t x1 = (x | x0); - const uint16_t x2 = (x1 >> 2); - const uint16_t x3 = (x1 | x2); - const uint16_t x4 = (x3 >> 4); - const uint16_t x5 = (x3 | x4); - const uint16_t x6 = (x5 >> 8); - const uint16_t x7 = (x5 | x6); - const uint16_t x8 = (~x7); - const uint16_t x9 = ((x8 >> 1) & 0x5555); - const uint16_t xA = (x8 - x9); - const uint16_t xB = (xA & 0x3333); - const uint16_t xC = ((xA >> 2) & 0x3333); - const uint16_t xD = (xB + xC); - const uint16_t xE = (xD >> 4); - const uint16_t xF = ((xD + xE) & 0x0f0f); - const uint16_t x10 = (xF >> 8); - const uint16_t x11 = ((xF + x10) & 0x001f); - return (x11); + const uint16_t x0 = (x >> 1); + const uint16_t x1 = (x | x0); + const uint16_t x2 = (x1 >> 2); + const uint16_t x3 = (x1 | x2); + const uint16_t x4 = (x3 >> 4); + const uint16_t x5 = (x3 | x4); + const uint16_t x6 = (x5 >> 8); + const 
uint16_t x7 = (x5 | x6); + const uint16_t x8 = (~x7); + const uint16_t x9 = ((x8 >> 1) & 0x5555); + const uint16_t xA = (x8 - x9); + const uint16_t xB = (xA & 0x3333); + const uint16_t xC = ((xA >> 2) & 0x3333); + const uint16_t xD = (xB + xC); + const uint16_t xE = (xD >> 4); + const uint16_t xF = ((xD + xE) & 0x0f0f); + const uint16_t x10 = (xF >> 8); + const uint16_t x11 = ((xF + x10) & 0x001f); + return (x11); #endif - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t - floatToHalf(uint32_t f) noexcept { - const uint32_t one = (0x00000001); - const uint32_t f_s_mask = (0x80000000); - const uint32_t f_e_mask = (0x7f800000); - const uint32_t f_m_mask = (0x007fffff); - const uint32_t f_m_hidden_bit = (0x00800000); - const uint32_t f_m_round_bit = (0x00001000); - const uint32_t f_snan_mask = (0x7fc00000); - const uint32_t f_e_pos = (0x00000017); - const uint32_t h_e_pos = (0x0000000a); - const uint32_t h_e_mask = (0x00007c00); - const uint32_t h_snan_mask = (0x00007e00); - const uint32_t h_e_mask_value = (0x0000001f); - const uint32_t f_h_s_pos_offset = (0x00000010); - const uint32_t f_h_bias_offset = (0x00000070); - const uint32_t f_h_m_pos_offset = (0x0000000d); - const uint32_t h_nan_min = (0x00007c01); - const uint32_t f_h_e_biased_flag = (0x0000008f); - const uint32_t f_s = (f & f_s_mask); - const uint32_t f_e = (f & f_e_mask); - const uint16_t h_s = (f_s >> f_h_s_pos_offset); - const uint32_t f_m = (f & f_m_mask); - const uint16_t f_e_amount = (f_e >> f_e_pos); - const uint32_t f_e_half_bias = (f_e_amount - f_h_bias_offset); - const uint32_t f_snan = (f & f_snan_mask); - const uint32_t f_m_round_mask = (f_m & f_m_round_bit); - const uint32_t f_m_round_offset = (f_m_round_mask << one); - const uint32_t f_m_rounded = (f_m + f_m_round_offset); - const uint32_t f_m_denorm_sa = (one - f_e_half_bias); - const uint32_t f_m_with_hidden = (f_m_rounded | f_m_hidden_bit); - const uint32_t f_m_denorm = (f_m_with_hidden >> f_m_denorm_sa); - const uint32_t 
h_m_denorm = (f_m_denorm >> f_h_m_pos_offset); - const uint32_t f_m_rounded_overflow = (f_m_rounded & f_m_hidden_bit); - const uint32_t m_nan = (f_m >> f_h_m_pos_offset); - const uint32_t h_em_nan = (h_e_mask | m_nan); - const uint32_t h_e_norm_overflow_offset = (f_e_half_bias + 1); - const uint32_t h_e_norm_overflow = (h_e_norm_overflow_offset << h_e_pos); - const uint32_t h_e_norm = (f_e_half_bias << h_e_pos); - const uint32_t h_m_norm = (f_m_rounded >> f_h_m_pos_offset); - const uint32_t h_em_norm = (h_e_norm | h_m_norm); - const uint32_t is_h_ndenorm_msb = (f_h_bias_offset - f_e_amount); - const uint32_t is_f_e_flagged_msb = (f_h_e_biased_flag - f_e_half_bias); - const uint32_t is_h_denorm_msb = (~is_h_ndenorm_msb); - const uint32_t is_f_m_eqz_msb = (f_m - 1); - const uint32_t is_h_nan_eqz_msb = (m_nan - 1); - const uint32_t is_f_inf_msb = (is_f_e_flagged_msb & is_f_m_eqz_msb); - const uint32_t is_f_nan_underflow_msb = (is_f_e_flagged_msb & is_h_nan_eqz_msb); - const uint32_t is_e_overflow_msb = (h_e_mask_value - f_e_half_bias); - const uint32_t is_h_inf_msb = (is_e_overflow_msb | is_f_inf_msb); - const uint32_t is_f_nsnan_msb = (f_snan - f_snan_mask); - const uint32_t is_m_norm_overflow_msb = (-((int32_t)f_m_rounded_overflow)); - const uint32_t is_f_snan_msb = (~is_f_nsnan_msb); - const uint32_t h_em_overflow_result = - uint32Sels(is_m_norm_overflow_msb, h_e_norm_overflow, h_em_norm); - const uint32_t h_em_nan_result = - uint32Sels(is_f_e_flagged_msb, h_em_nan, h_em_overflow_result); - const uint32_t h_em_nan_underflow_result = - uint32Sels(is_f_nan_underflow_msb, h_nan_min, h_em_nan_result); - const uint32_t h_em_inf_result = - uint32Sels(is_h_inf_msb, h_e_mask, h_em_nan_underflow_result); - const uint32_t h_em_denorm_result = - uint32Sels(is_h_denorm_msb, h_m_denorm, h_em_inf_result); - const uint32_t h_em_snan_result = - uint32Sels(is_f_snan_msb, h_snan_mask, h_em_denorm_result); - const uint32_t h_result = (h_s | h_em_snan_result); - return 
(uint16_t)(h_result); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t - halfToFloat(uint16_t h) noexcept { - const uint32_t h_e_mask = (0x00007c00); - const uint32_t h_m_mask = (0x000003ff); - const uint32_t h_s_mask = (0x00008000); - const uint32_t h_f_s_pos_offset = (0x00000010); - const uint32_t h_f_e_pos_offset = (0x0000000d); - const uint32_t h_f_bias_offset = (0x0001c000); - const uint32_t f_e_mask = (0x7f800000); - const uint32_t f_m_mask = (0x007fffff); - const uint32_t h_f_e_denorm_bias = (0x0000007e); - const uint32_t h_f_m_denorm_sa_bias = (0x00000008); - const uint32_t f_e_pos = (0x00000017); - const uint32_t h_e_mask_minus_one = (0x00007bff); - const uint32_t h_e = (h & h_e_mask); - const uint32_t h_m = (h & h_m_mask); - const uint32_t h_s = (h & h_s_mask); - const uint32_t h_e_f_bias = (h_e + h_f_bias_offset); - const uint32_t h_m_nlz = uint32Cntlz(h_m); - const uint32_t f_s = (h_s << h_f_s_pos_offset); - const uint32_t f_e = (h_e_f_bias << h_f_e_pos_offset); - const uint32_t f_m = (h_m << h_f_e_pos_offset); - const uint32_t f_em = (f_e | f_m); - const uint32_t h_f_m_sa = (h_m_nlz - h_f_m_denorm_sa_bias); - const uint32_t f_e_denorm_unpacked = (h_f_e_denorm_bias - h_f_m_sa); - const uint32_t h_f_m = (h_m << h_f_m_sa); - const uint32_t f_m_denorm = (h_f_m & f_m_mask); - const uint32_t f_e_denorm = (f_e_denorm_unpacked << f_e_pos); - const uint32_t f_em_denorm = (f_e_denorm | f_m_denorm); - const uint32_t f_em_nan = (f_e_mask | f_m); - const uint32_t is_e_eqz_msb = (h_e - 1); - const uint32_t is_m_nez_msb = (-((int32_t)h_m)); - const uint32_t is_e_flagged_msb = (h_e_mask_minus_one - h_e); - const uint32_t is_zero_msb = (is_e_eqz_msb & ~is_m_nez_msb); - const uint32_t is_inf_msb = (is_e_flagged_msb & ~is_m_nez_msb); - const uint32_t is_denorm_msb = (is_m_nez_msb & is_e_eqz_msb); - const uint32_t is_nan_msb = (is_e_flagged_msb & is_m_nez_msb); - const uint32_t is_zero = (((std::int32_t)is_zero_msb) >> 31); - const uint32_t 
f_zero_result = (f_em & ~is_zero); - const uint32_t f_denorm_result = uint32Sels(is_denorm_msb, f_em_denorm, f_zero_result); - const uint32_t f_inf_result = uint32Sels(is_inf_msb, f_e_mask, f_denorm_result); - const uint32_t f_nan_result = uint32Sels(is_nan_msb, f_em_nan, f_inf_result); - const uint32_t f_result = (f_s | f_nan_result); - return (f_result); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t halfAdd(uint16_t x, - uint16_t y) noexcept { - constexpr uint16_t one = (0x0001); - constexpr uint16_t msb_to_lsb_sa = (0x000f); - constexpr uint16_t h_s_mask = (0x8000); - constexpr uint16_t h_e_mask = (0x7c00); - constexpr uint16_t h_m_mask = (0x03ff); - constexpr uint16_t h_m_msb_mask = (0x2000); - constexpr uint16_t h_m_msb_sa = (0x000d); - constexpr uint16_t h_m_hidden = (0x0400); - constexpr uint16_t h_e_pos = (0x000a); - constexpr uint16_t h_e_bias_minus_one = (0x000e); - constexpr uint16_t h_m_grs_carry = (0x4000); - constexpr uint16_t h_m_grs_carry_pos = (0x000e); - constexpr uint16_t h_grs_size = (0x0003); - constexpr uint16_t h_snan = (0xfe00); - constexpr uint16_t h_e_mask_minus_one = (0x7bff); - const uint16_t h_grs_round_carry = (one << h_grs_size); - const uint16_t h_grs_round_mask = (h_grs_round_carry - one); - const uint16_t x_e = (x & h_e_mask); - const uint16_t y_e = (y & h_e_mask); - const uint16_t is_y_e_larger_msb = (x_e - y_e); - const uint16_t a = uint16Sels(is_y_e_larger_msb, y, x); - const uint16_t a_s = (a & h_s_mask); - const uint16_t a_e = (a & h_e_mask); - const uint16_t a_m_no_hidden_bit = (a & h_m_mask); - const uint16_t a_em_no_hidden_bit = (a_e | a_m_no_hidden_bit); - const uint16_t b = uint16Sels(is_y_e_larger_msb, x, y); - const uint16_t b_s = (b & h_s_mask); - const uint16_t b_e = (b & h_e_mask); - const uint16_t b_m_no_hidden_bit = (b & h_m_mask); - const uint16_t b_em_no_hidden_bit = (b_e | b_m_no_hidden_bit); - const uint16_t is_diff_sign_msb = (a_s ^ b_s); - const uint16_t is_a_inf_msb = 
(h_e_mask_minus_one - a_em_no_hidden_bit); - const uint16_t is_b_inf_msb = (h_e_mask_minus_one - b_em_no_hidden_bit); - const uint16_t is_undenorm_msb = (a_e - 1); - const uint16_t is_undenorm = (((int16_t)is_undenorm_msb) >> 15); - const uint16_t is_both_inf_msb = (is_a_inf_msb & is_b_inf_msb); - const uint16_t is_invalid_inf_op_msb = (is_both_inf_msb & b_s); - const uint16_t is_a_e_nez_msb = (-a_e); - const uint16_t is_b_e_nez_msb = (-b_e); - const uint16_t is_a_e_nez = (((int16_t)is_a_e_nez_msb) >> 15); - const uint16_t is_b_e_nez = (((int16_t)is_b_e_nez_msb) >> 15); - const uint16_t a_m_hidden_bit = (is_a_e_nez & h_m_hidden); - const uint16_t b_m_hidden_bit = (is_b_e_nez & h_m_hidden); - const uint16_t a_m_no_grs = (a_m_no_hidden_bit | a_m_hidden_bit); - const uint16_t b_m_no_grs = (b_m_no_hidden_bit | b_m_hidden_bit); - const uint16_t diff_e = (a_e - b_e); - const uint16_t a_e_unbias = (a_e - h_e_bias_minus_one); - const uint16_t a_m = (a_m_no_grs << h_grs_size); - const uint16_t a_e_biased = (a_e >> h_e_pos); - const uint16_t m_sa_unbias = (a_e_unbias >> h_e_pos); - const uint16_t m_sa_default = (diff_e >> h_e_pos); - const uint16_t m_sa_unbias_mask = (is_a_e_nez_msb & ~is_b_e_nez_msb); - const uint16_t m_sa = uint16Sels(m_sa_unbias_mask, m_sa_unbias, m_sa_default); - const uint16_t b_m_no_sticky = (b_m_no_grs << h_grs_size); - const uint16_t sh_m = (b_m_no_sticky >> m_sa); - const uint16_t sticky_overflow = (one << m_sa); - const uint16_t sticky_mask = (sticky_overflow - 1); - const uint16_t sticky_collect = (b_m_no_sticky & sticky_mask); - const uint16_t is_sticky_set_msb = (-sticky_collect); - const uint16_t sticky = (is_sticky_set_msb >> msb_to_lsb_sa); - const uint16_t b_m = (sh_m | sticky); - const uint16_t is_c_m_ab_pos_msb = (b_m - a_m); - const uint16_t c_inf = (a_s | h_e_mask); - const uint16_t c_m_sum = (a_m + b_m); - const uint16_t c_m_diff_ab = (a_m - b_m); - const uint16_t c_m_diff_ba = (b_m - a_m); - const uint16_t c_m_smag_diff = 
uint16Sels(is_c_m_ab_pos_msb, c_m_diff_ab, c_m_diff_ba); - const uint16_t c_s_diff = uint16Sels(is_c_m_ab_pos_msb, a_s, b_s); - const uint16_t c_s = uint16Sels(is_diff_sign_msb, c_s_diff, a_s); - const uint16_t c_m_smag_diff_nlz = uint16Cntlz(c_m_smag_diff); - const uint16_t diff_norm_sa = (c_m_smag_diff_nlz - one); - const uint16_t is_diff_denorm_msb = (a_e_biased - diff_norm_sa); - const uint16_t is_diff_denorm = (((int16_t)is_diff_denorm_msb) >> 15); - const uint16_t is_a_or_b_norm_msb = (-a_e_biased); - const uint16_t diff_denorm_sa = (a_e_biased - 1); - const uint16_t c_m_diff_denorm = (c_m_smag_diff << diff_denorm_sa); - const uint16_t c_m_diff_norm = (c_m_smag_diff << diff_norm_sa); - const uint16_t c_e_diff_norm = (a_e_biased - diff_norm_sa); - const uint16_t c_m_diff_ab_norm = - uint16Sels(is_diff_denorm_msb, c_m_diff_denorm, c_m_diff_norm); - const uint16_t c_e_diff_ab_norm = (c_e_diff_norm & ~is_diff_denorm); - const uint16_t c_m_diff = - uint16Sels(is_a_or_b_norm_msb, c_m_diff_ab_norm, c_m_smag_diff); - const uint16_t c_e_diff = uint16Sels(is_a_or_b_norm_msb, c_e_diff_ab_norm, a_e_biased); - const uint16_t is_diff_eqz_msb = (c_m_diff - 1); - const uint16_t is_diff_exactly_zero_msb = (is_diff_sign_msb & is_diff_eqz_msb); - const uint16_t is_diff_exactly_zero = (((int16_t)is_diff_exactly_zero_msb) >> 15); - const uint16_t c_m_added = uint16Sels(is_diff_sign_msb, c_m_diff, c_m_sum); - const uint16_t c_e_added = uint16Sels(is_diff_sign_msb, c_e_diff, a_e_biased); - const uint16_t c_m_carry = (c_m_added & h_m_grs_carry); - const uint16_t is_c_m_carry_msb = (-c_m_carry); - const uint16_t c_e_hidden_offset = ((c_m_added & h_m_grs_carry) >> h_m_grs_carry_pos); - const uint16_t c_m_sub_hidden = (c_m_added >> one); - const uint16_t c_m_no_hidden = uint16Sels(is_c_m_carry_msb, c_m_sub_hidden, c_m_added); - const uint16_t c_e_no_hidden = (c_e_added + c_e_hidden_offset); - const uint16_t c_m_no_hidden_msb = (c_m_no_hidden & h_m_msb_mask); - const uint16_t 
undenorm_m_msb_odd = (c_m_no_hidden_msb >> h_m_msb_sa); - const uint16_t undenorm_fix_e = (is_undenorm & undenorm_m_msb_odd); - const uint16_t c_e_fixed = (c_e_no_hidden + undenorm_fix_e); - const uint16_t c_m_round_amount = (c_m_no_hidden & h_grs_round_mask); - const uint16_t c_m_rounded = (c_m_no_hidden + c_m_round_amount); - const uint16_t c_m_round_overflow = - ((c_m_rounded & h_m_grs_carry) >> h_m_grs_carry_pos); - const uint16_t c_e_rounded = (c_e_fixed + c_m_round_overflow); - const uint16_t c_m_no_grs = ((c_m_rounded >> h_grs_size) & h_m_mask); - const uint16_t c_e = (c_e_rounded << h_e_pos); - const uint16_t c_em = (c_e | c_m_no_grs); - const uint16_t c_normal = (c_s | c_em); - const uint16_t c_inf_result = uint16Sels(is_a_inf_msb, c_inf, c_normal); - const uint16_t c_zero_result = (c_inf_result & ~is_diff_exactly_zero); - const uint16_t c_result = uint16Sels(is_invalid_inf_op_msb, h_snan, c_zero_result); - return (c_result); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t halfMul(uint16_t x, - uint16_t y) noexcept { - const uint32_t one = (0x00000001); - const uint32_t h_s_mask = (0x00008000); - const uint32_t h_e_mask = (0x00007c00); - const uint32_t h_m_mask = (0x000003ff); - const uint32_t h_m_hidden = (0x00000400); - const uint32_t h_e_pos = (0x0000000a); - const uint32_t h_e_bias = (0x0000000f); - const uint32_t h_m_bit_count = (0x0000000a); - const uint32_t h_m_bit_half_count = (0x00000005); - const uint32_t h_nan_min = (0x00007c01); - const uint32_t h_e_mask_minus_one = (0x00007bff); - const uint32_t h_snan = (0x0000fe00); - const uint32_t m_round_overflow_bit = (0x00000020); - const uint32_t m_hidden_bit = (0x00100000); - const uint32_t a_s = (x & h_s_mask); - const uint32_t b_s = (y & h_s_mask); - const uint32_t c_s = (a_s ^ b_s); - const uint32_t x_e = (x & h_e_mask); - const uint32_t x_e_eqz_msb = (x_e - 1); - const uint32_t a = uint32Sels(x_e_eqz_msb, y, x); - const uint32_t b = uint32Sels(x_e_eqz_msb, x, y); - const 
uint32_t a_e = (a & h_e_mask); - const uint32_t b_e = (b & h_e_mask); - const uint32_t a_m = (a & h_m_mask); - const uint32_t b_m = (b & h_m_mask); - const uint32_t a_e_amount = (a_e >> h_e_pos); - const uint32_t b_e_amount = (b_e >> h_e_pos); - const uint32_t a_m_with_hidden = (a_m | h_m_hidden); - const uint32_t b_m_with_hidden = (b_m | h_m_hidden); - const uint32_t c_m_normal = (a_m_with_hidden * b_m_with_hidden); - const uint32_t c_m_denorm_biased = (a_m_with_hidden * b_m); - const uint32_t c_e_denorm_unbias_e = (h_e_bias - a_e_amount); - const uint32_t c_m_denorm_round_amount = (c_m_denorm_biased & h_m_mask); - const uint32_t c_m_denorm_rounded = (c_m_denorm_biased + c_m_denorm_round_amount); - const uint32_t c_m_denorm_inplace = (c_m_denorm_rounded >> h_m_bit_count); - const uint32_t c_m_denorm_unbiased = (c_m_denorm_inplace >> c_e_denorm_unbias_e); - const uint32_t c_m_denorm = (c_m_denorm_unbiased & h_m_mask); - const uint32_t c_e_amount_biased = (a_e_amount + b_e_amount); - const uint32_t c_e_amount_unbiased = (c_e_amount_biased - h_e_bias); - const uint32_t is_c_e_unbiased_underflow = (((std::int32_t)c_e_amount_unbiased) >> 31); - const uint32_t c_e_underflow_half_sa = (-((int32_t)c_e_amount_unbiased)); - const uint32_t c_e_underflow_sa = (c_e_underflow_half_sa << one); - const uint32_t c_m_underflow = (c_m_normal >> c_e_underflow_sa); - const uint32_t c_e_underflow_added = (c_e_amount_unbiased & ~is_c_e_unbiased_underflow); - const uint32_t c_m_underflow_added = - uint32Selb(is_c_e_unbiased_underflow, c_m_underflow, c_m_normal); - const uint32_t is_mul_overflow_test = (c_e_underflow_added & m_round_overflow_bit); - const uint32_t is_mul_overflow_msb = (-((int32_t)is_mul_overflow_test)); - const uint32_t c_e_norm_radix_corrected = (c_e_underflow_added + 1); - const uint32_t c_m_norm_radix_corrected = (c_m_underflow_added >> one); - const uint32_t c_m_norm_hidden_bit = (c_m_norm_radix_corrected & m_hidden_bit); - const uint32_t is_c_m_norm_no_hidden_msb = 
(c_m_norm_hidden_bit - 1); - const uint32_t c_m_norm_lo = (c_m_norm_radix_corrected >> h_m_bit_half_count); - const uint32_t c_m_norm_lo_nlz = - static_cast(uint16Cntlz((uint16_t)c_m_norm_lo)); - const uint32_t is_c_m_hidden_nunderflow_msb = - (c_m_norm_lo_nlz - c_e_norm_radix_corrected); - const uint32_t is_c_m_hidden_underflow_msb = (~is_c_m_hidden_nunderflow_msb); - const uint32_t is_c_m_hidden_underflow = - (((std::int32_t)is_c_m_hidden_underflow_msb) >> 31); - const uint32_t c_m_hidden_underflow_normalized_sa = (c_m_norm_lo_nlz >> one); - const uint32_t c_m_hidden_underflow_normalized = - (c_m_norm_radix_corrected << c_m_hidden_underflow_normalized_sa); - const uint32_t c_m_hidden_normalized = (c_m_norm_radix_corrected << c_m_norm_lo_nlz); - const uint32_t c_e_hidden_normalized = (c_e_norm_radix_corrected - c_m_norm_lo_nlz); - const uint32_t c_e_hidden = (c_e_hidden_normalized & ~is_c_m_hidden_underflow); - const uint32_t c_m_hidden = uint32Sels( - is_c_m_hidden_underflow_msb, c_m_hidden_underflow_normalized, c_m_hidden_normalized); - const uint32_t c_m_normalized = - uint32Sels(is_c_m_norm_no_hidden_msb, c_m_hidden, c_m_norm_radix_corrected); - const uint32_t c_e_normalized = - uint32Sels(is_c_m_norm_no_hidden_msb, c_e_hidden, c_e_norm_radix_corrected); - const uint32_t c_m_norm_round_amount = (c_m_normalized & h_m_mask); - const uint32_t c_m_norm_rounded = (c_m_normalized + c_m_norm_round_amount); - const uint32_t is_round_overflow_test = (c_e_normalized & m_round_overflow_bit); - const uint32_t is_round_overflow_msb = (-((int32_t)is_round_overflow_test)); - const uint32_t c_m_norm_inplace = (c_m_norm_rounded >> h_m_bit_count); - const uint32_t c_m = (c_m_norm_inplace & h_m_mask); - const uint32_t c_e_norm_inplace = (c_e_normalized << h_e_pos); - const uint32_t c_e = (c_e_norm_inplace & h_e_mask); - const uint32_t c_em_nan = (h_e_mask | a_m); - const uint32_t c_nan = (a_s | c_em_nan); - const uint32_t c_denorm = (c_s | c_m_denorm); - const uint32_t c_inf = 
(c_s | h_e_mask); - const uint32_t c_em_norm = (c_e | c_m); - const uint32_t is_a_e_flagged_msb = (h_e_mask_minus_one - a_e); - const uint32_t is_b_e_flagged_msb = (h_e_mask_minus_one - b_e); - const uint32_t is_a_e_eqz_msb = (a_e - 1); - const uint32_t is_a_m_eqz_msb = (a_m - 1); - const uint32_t is_b_e_eqz_msb = (b_e - 1); - const uint32_t is_b_m_eqz_msb = (b_m - 1); - const uint32_t is_b_eqz_msb = (is_b_e_eqz_msb & is_b_m_eqz_msb); - const uint32_t is_a_eqz_msb = (is_a_e_eqz_msb & is_a_m_eqz_msb); - const uint32_t is_c_nan_via_a_msb = (is_a_e_flagged_msb & ~is_b_e_flagged_msb); - const uint32_t is_c_nan_via_b_msb = (is_b_e_flagged_msb & ~is_b_m_eqz_msb); - const uint32_t is_c_nan_msb = (is_c_nan_via_a_msb | is_c_nan_via_b_msb); - const uint32_t is_c_denorm_msb = (is_b_e_eqz_msb & ~is_a_e_flagged_msb); - const uint32_t is_a_inf_msb = (is_a_e_flagged_msb & is_a_m_eqz_msb); - const uint32_t is_c_snan_msb = (is_a_inf_msb & is_b_eqz_msb); - const uint32_t is_c_nan_min_via_a_msb = (is_a_e_flagged_msb & is_b_eqz_msb); - const uint32_t is_c_nan_min_via_b_msb = (is_b_e_flagged_msb & is_a_eqz_msb); - const uint32_t is_c_nan_min_msb = (is_c_nan_min_via_a_msb | is_c_nan_min_via_b_msb); - const uint32_t is_c_inf_msb = (is_a_e_flagged_msb | is_b_e_flagged_msb); - const uint32_t is_overflow_msb = (is_round_overflow_msb | is_mul_overflow_msb); - const uint32_t c_em_overflow_result = uint32Sels(is_overflow_msb, h_e_mask, c_em_norm); - const uint32_t c_common_result = (c_s | c_em_overflow_result); - const uint32_t c_zero_result = uint32Sels(is_b_eqz_msb, c_s, c_common_result); - const uint32_t c_nan_result = uint32Sels(is_c_nan_msb, c_nan, c_zero_result); - const uint32_t c_nan_min_result = uint32Sels(is_c_nan_min_msb, h_nan_min, c_nan_result); - const uint32_t c_inf_result = uint32Sels(is_c_inf_msb, c_inf, c_nan_min_result); - const uint32_t c_denorm_result = uint32Sels(is_c_denorm_msb, c_denorm, c_inf_result); - const uint32_t c_result = uint32Sels(is_c_snan_msb, h_snan, 
c_denorm_result); - return (uint16_t)(c_result); - } - - constexpr inline uint16_t halfNeg(uint16_t h) noexcept { return h ^ 0x8000; } - - constexpr inline uint16_t halfSub(uint16_t ha, uint16_t hb) noexcept { - return halfAdd(ha, halfNeg(hb)); - } - } // namespace detail - - class half { - public: - half() noexcept = default; - half(const half &) = default; - half(half &&) = default; - - LIBRAPID_ALWAYS_INLINE half(float f) noexcept; - - template - LIBRAPID_ALWAYS_INLINE explicit half(T d) noexcept; - - half &operator=(const half &) = default; - half &operator=(half &&) = default; - - template - LIBRAPID_ALWAYS_INLINE half &operator=(T d) noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static half fromBits(uint16_t bits) noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator float() const noexcept; - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator T() const noexcept; - - LIBRAPID_ALWAYS_INLINE half &operator+=(const half &rhs) noexcept; - LIBRAPID_ALWAYS_INLINE half &operator-=(const half &rhs) noexcept; - LIBRAPID_ALWAYS_INLINE half &operator*=(const half &rhs) noexcept; - LIBRAPID_ALWAYS_INLINE half &operator/=(const half &rhs) noexcept; - - LIBRAPID_ALWAYS_INLINE half &operator--() noexcept; - LIBRAPID_ALWAYS_INLINE half operator--(int) noexcept; - LIBRAPID_ALWAYS_INLINE half &operator++() noexcept; - LIBRAPID_ALWAYS_INLINE half operator++(int) noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator-() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator+() const noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t data() const noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t &data() noexcept; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE std::string - str(const std::string &format = "{}") const; - - // static half infinity; - // static half max; - // static half maxSubnormal; - // static half min; - // static half minPositive; - // 
static half minPositiveSubnormal; - // static half nan; - // static half negativeInfinity; - // static half epsilon; - // - // static half one; - // static half negativeOne; - // static half two; - // static half negativeTwo; - // static half half_; - // static half negativeHalf; - // static half zero; - // static half negativeZero; - // static half e; - // static half pi; - - private: - detail::float16_t m_value; - }; - - half::half(float f) noexcept { - detail::float32_t tmp; - tmp.m_float = f; - m_value.m_bits = detail::floatToHalf(tmp.m_bits); - } - - template - half::half(T d) noexcept : half(static_cast(d)) {} - - template - half &half::operator=(T d) noexcept { - *this = half(d); - return *this; - } - - half half::fromBits(uint16_t bits) noexcept { - half h; - h.m_value.m_bits = bits; - return h; - } - - half::operator float() const noexcept { - detail::float32_t tmp; - tmp.m_bits = detail::halfToFloat(m_value.m_bits); - return tmp.m_float; - } - - template - LIBRAPID_NODISCARD half::operator T() const noexcept { - return static_cast(static_cast(*this)); - } - - LIBRAPID_ALWAYS_INLINE half &half::operator+=(const half &rhs) noexcept { - m_value.m_bits = detail::halfAdd(m_value.m_bits, rhs.m_value.m_bits); - return *this; - } - - LIBRAPID_ALWAYS_INLINE half &half::operator-=(const half &rhs) noexcept { - m_value.m_bits = detail::halfSub(m_value.m_bits, rhs.m_value.m_bits); - return *this; - } - - LIBRAPID_ALWAYS_INLINE half &half::operator*=(const half &rhs) noexcept { - m_value.m_bits = detail::halfMul(m_value.m_bits, rhs.m_value.m_bits); - return *this; - } - - LIBRAPID_ALWAYS_INLINE half &half::operator/=(const half &rhs) noexcept { - *this = static_cast(*this) / static_cast(rhs); - return *this; - } - - LIBRAPID_ALWAYS_INLINE half &half::operator--() noexcept { - *this -= half::fromBits(static_cast(0x3c00)); - return *this; - } - - LIBRAPID_ALWAYS_INLINE half half::operator--(int) noexcept { - half tmp(*this); - tmp -= half::fromBits(static_cast(0x3c00)); 
- return tmp; - } - - LIBRAPID_ALWAYS_INLINE half &half::operator++() noexcept { - *this += half::fromBits(static_cast(0x3c00)); - return *this; - } - - LIBRAPID_ALWAYS_INLINE half half::operator++(int) noexcept { - half tmp(*this); - tmp += half::fromBits(static_cast(0x3c00)); - return tmp; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half half::operator-() const noexcept { - return half::fromBits((m_value.m_bits & 0x7fff) | (m_value.m_bits ^ 0x8000)); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half half::operator+() const noexcept { - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t half::data() const noexcept { - return m_value; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t &half::data() noexcept { - return m_value; - } - - std::string half::str(const std::string &format) const { - // return fmt::vformat(format, fmt::make_wformat_args(detail::halfToFloat(m_value.m_bits))); - - return std::vformat(format, std::make_format_args(detail::halfToFloat(m_value.m_bits))); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator+(const half &lhs, - const half &rhs) noexcept { - half tmp(lhs); - tmp += rhs; - return tmp; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator-(const half &lhs, - const half &rhs) noexcept { - half tmp(lhs); - tmp -= rhs; - return tmp; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator*(const half &lhs, - const half &rhs) noexcept { - half tmp(lhs); - tmp *= rhs; - return tmp; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator/(const half &lhs, - const half &rhs) noexcept { - half tmp(lhs); - tmp /= rhs; - return tmp; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator<(const half &lhs, - const half &rhs) noexcept { - auto const &l_ieee = lhs.data().m_ieee; - auto const &r_ieee = rhs.data().m_ieee; - - if (l_ieee.m_sign == 1) { - if (r_ieee.m_sign == 0) return true; - if (l_ieee.m_exp > r_ieee.m_exp) return true; - if 
(l_ieee.m_exp < r_ieee.m_exp) return false; - if (l_ieee.m_frac > r_ieee.m_frac) return true; - return false; - } - - if (r_ieee.m_sign == 1) return false; - if (l_ieee.m_exp > r_ieee.m_exp) return false; - if (l_ieee.m_exp < r_ieee.m_exp) return true; - if (l_ieee.m_frac >= r_ieee.m_frac) return false; - return true; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator==(const half &lhs, - const half &rhs) noexcept { - return lhs.data().m_bits == rhs.data().m_bits; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator!=(const half &lhs, - const half &rhs) noexcept { - return lhs.data().m_bits != rhs.data().m_bits; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator<=(const half &lhs, - const half &rhs) noexcept { - return (lhs < rhs) || (lhs == rhs); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator>(const half &lhs, - const half &rhs) noexcept { - return !(lhs <= rhs); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator>=(const half &lhs, - const half &rhs) noexcept { - return !(lhs < rhs); - } - - namespace typetraits { - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = half; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "half"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = false; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t + floatToHalf(uint32_t f) noexcept { + const uint32_t one = (0x00000001); + const uint32_t f_s_mask = (0x80000000); + const uint32_t f_e_mask = (0x7f800000); + const uint32_t f_m_mask = (0x007fffff); + const uint32_t f_m_hidden_bit = (0x00800000); + const uint32_t f_m_round_bit = (0x00001000); + const uint32_t f_snan_mask = (0x7fc00000); + const 
uint32_t f_e_pos = (0x00000017); + const uint32_t h_e_pos = (0x0000000a); + const uint32_t h_e_mask = (0x00007c00); + const uint32_t h_snan_mask = (0x00007e00); + const uint32_t h_e_mask_value = (0x0000001f); + const uint32_t f_h_s_pos_offset = (0x00000010); + const uint32_t f_h_bias_offset = (0x00000070); + const uint32_t f_h_m_pos_offset = (0x0000000d); + const uint32_t h_nan_min = (0x00007c01); + const uint32_t f_h_e_biased_flag = (0x0000008f); + const uint32_t f_s = (f & f_s_mask); + const uint32_t f_e = (f & f_e_mask); + const uint16_t h_s = (f_s >> f_h_s_pos_offset); + const uint32_t f_m = (f & f_m_mask); + const uint16_t f_e_amount = (f_e >> f_e_pos); + const uint32_t f_e_half_bias = (f_e_amount - f_h_bias_offset); + const uint32_t f_snan = (f & f_snan_mask); + const uint32_t f_m_round_mask = (f_m & f_m_round_bit); + const uint32_t f_m_round_offset = (f_m_round_mask << one); + const uint32_t f_m_rounded = (f_m + f_m_round_offset); + const uint32_t f_m_denorm_sa = (one - f_e_half_bias); + const uint32_t f_m_with_hidden = (f_m_rounded | f_m_hidden_bit); + const uint32_t f_m_denorm = (f_m_with_hidden >> f_m_denorm_sa); + const uint32_t h_m_denorm = (f_m_denorm >> f_h_m_pos_offset); + const uint32_t f_m_rounded_overflow = (f_m_rounded & f_m_hidden_bit); + const uint32_t m_nan = (f_m >> f_h_m_pos_offset); + const uint32_t h_em_nan = (h_e_mask | m_nan); + const uint32_t h_e_norm_overflow_offset = (f_e_half_bias + 1); + const uint32_t h_e_norm_overflow = (h_e_norm_overflow_offset << h_e_pos); + const uint32_t h_e_norm = (f_e_half_bias << h_e_pos); + const uint32_t h_m_norm = (f_m_rounded >> f_h_m_pos_offset); + const uint32_t h_em_norm = (h_e_norm | h_m_norm); + const uint32_t is_h_ndenorm_msb = (f_h_bias_offset - f_e_amount); + const uint32_t is_f_e_flagged_msb = (f_h_e_biased_flag - f_e_half_bias); + const uint32_t is_h_denorm_msb = (~is_h_ndenorm_msb); + const uint32_t is_f_m_eqz_msb = (f_m - 1); + const uint32_t is_h_nan_eqz_msb = (m_nan - 1); + const uint32_t 
is_f_inf_msb = (is_f_e_flagged_msb & is_f_m_eqz_msb); + const uint32_t is_f_nan_underflow_msb = (is_f_e_flagged_msb & is_h_nan_eqz_msb); + const uint32_t is_e_overflow_msb = (h_e_mask_value - f_e_half_bias); + const uint32_t is_h_inf_msb = (is_e_overflow_msb | is_f_inf_msb); + const uint32_t is_f_nsnan_msb = (f_snan - f_snan_mask); + const uint32_t is_m_norm_overflow_msb = (-((int32_t)f_m_rounded_overflow)); + const uint32_t is_f_snan_msb = (~is_f_nsnan_msb); + const uint32_t h_em_overflow_result = + uint32Sels(is_m_norm_overflow_msb, h_e_norm_overflow, h_em_norm); + const uint32_t h_em_nan_result = + uint32Sels(is_f_e_flagged_msb, h_em_nan, h_em_overflow_result); + const uint32_t h_em_nan_underflow_result = + uint32Sels(is_f_nan_underflow_msb, h_nan_min, h_em_nan_result); + const uint32_t h_em_inf_result = + uint32Sels(is_h_inf_msb, h_e_mask, h_em_nan_underflow_result); + const uint32_t h_em_denorm_result = + uint32Sels(is_h_denorm_msb, h_m_denorm, h_em_inf_result); + const uint32_t h_em_snan_result = + uint32Sels(is_f_snan_msb, h_snan_mask, h_em_denorm_result); + const uint32_t h_result = (h_s | h_em_snan_result); + return (uint16_t)(h_result); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint32_t + halfToFloat(uint16_t h) noexcept { + const uint32_t h_e_mask = (0x00007c00); + const uint32_t h_m_mask = (0x000003ff); + const uint32_t h_s_mask = (0x00008000); + const uint32_t h_f_s_pos_offset = (0x00000010); + const uint32_t h_f_e_pos_offset = (0x0000000d); + const uint32_t h_f_bias_offset = (0x0001c000); + const uint32_t f_e_mask = (0x7f800000); + const uint32_t f_m_mask = (0x007fffff); + const uint32_t h_f_e_denorm_bias = (0x0000007e); + const uint32_t h_f_m_denorm_sa_bias = (0x00000008); + const uint32_t f_e_pos = (0x00000017); + const uint32_t h_e_mask_minus_one = (0x00007bff); + const uint32_t h_e = (h & h_e_mask); + const uint32_t h_m = (h & h_m_mask); + const uint32_t h_s = (h & h_s_mask); + const uint32_t h_e_f_bias = (h_e + h_f_bias_offset); 
+ const uint32_t h_m_nlz = uint32Cntlz(h_m); + const uint32_t f_s = (h_s << h_f_s_pos_offset); + const uint32_t f_e = (h_e_f_bias << h_f_e_pos_offset); + const uint32_t f_m = (h_m << h_f_e_pos_offset); + const uint32_t f_em = (f_e | f_m); + const uint32_t h_f_m_sa = (h_m_nlz - h_f_m_denorm_sa_bias); + const uint32_t f_e_denorm_unpacked = (h_f_e_denorm_bias - h_f_m_sa); + const uint32_t h_f_m = (h_m << h_f_m_sa); + const uint32_t f_m_denorm = (h_f_m & f_m_mask); + const uint32_t f_e_denorm = (f_e_denorm_unpacked << f_e_pos); + const uint32_t f_em_denorm = (f_e_denorm | f_m_denorm); + const uint32_t f_em_nan = (f_e_mask | f_m); + const uint32_t is_e_eqz_msb = (h_e - 1); + const uint32_t is_m_nez_msb = (-((int32_t)h_m)); + const uint32_t is_e_flagged_msb = (h_e_mask_minus_one - h_e); + const uint32_t is_zero_msb = (is_e_eqz_msb & ~is_m_nez_msb); + const uint32_t is_inf_msb = (is_e_flagged_msb & ~is_m_nez_msb); + const uint32_t is_denorm_msb = (is_m_nez_msb & is_e_eqz_msb); + const uint32_t is_nan_msb = (is_e_flagged_msb & is_m_nez_msb); + const uint32_t is_zero = (((std::int32_t)is_zero_msb) >> 31); + const uint32_t f_zero_result = (f_em & ~is_zero); + const uint32_t f_denorm_result = uint32Sels(is_denorm_msb, f_em_denorm, f_zero_result); + const uint32_t f_inf_result = uint32Sels(is_inf_msb, f_e_mask, f_denorm_result); + const uint32_t f_nan_result = uint32Sels(is_nan_msb, f_em_nan, f_inf_result); + const uint32_t f_result = (f_s | f_nan_result); + return (f_result); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t halfAdd(uint16_t x, + uint16_t y) noexcept { + constexpr uint16_t one = (0x0001); + constexpr uint16_t msb_to_lsb_sa = (0x000f); + constexpr uint16_t h_s_mask = (0x8000); + constexpr uint16_t h_e_mask = (0x7c00); + constexpr uint16_t h_m_mask = (0x03ff); + constexpr uint16_t h_m_msb_mask = (0x2000); + constexpr uint16_t h_m_msb_sa = (0x000d); + constexpr uint16_t h_m_hidden = (0x0400); + constexpr uint16_t h_e_pos = (0x000a); + 
constexpr uint16_t h_e_bias_minus_one = (0x000e); + constexpr uint16_t h_m_grs_carry = (0x4000); + constexpr uint16_t h_m_grs_carry_pos = (0x000e); + constexpr uint16_t h_grs_size = (0x0003); + constexpr uint16_t h_snan = (0xfe00); + constexpr uint16_t h_e_mask_minus_one = (0x7bff); + const uint16_t h_grs_round_carry = (one << h_grs_size); + const uint16_t h_grs_round_mask = (h_grs_round_carry - one); + const uint16_t x_e = (x & h_e_mask); + const uint16_t y_e = (y & h_e_mask); + const uint16_t is_y_e_larger_msb = (x_e - y_e); + const uint16_t a = uint16Sels(is_y_e_larger_msb, y, x); + const uint16_t a_s = (a & h_s_mask); + const uint16_t a_e = (a & h_e_mask); + const uint16_t a_m_no_hidden_bit = (a & h_m_mask); + const uint16_t a_em_no_hidden_bit = (a_e | a_m_no_hidden_bit); + const uint16_t b = uint16Sels(is_y_e_larger_msb, x, y); + const uint16_t b_s = (b & h_s_mask); + const uint16_t b_e = (b & h_e_mask); + const uint16_t b_m_no_hidden_bit = (b & h_m_mask); + const uint16_t b_em_no_hidden_bit = (b_e | b_m_no_hidden_bit); + const uint16_t is_diff_sign_msb = (a_s ^ b_s); + const uint16_t is_a_inf_msb = (h_e_mask_minus_one - a_em_no_hidden_bit); + const uint16_t is_b_inf_msb = (h_e_mask_minus_one - b_em_no_hidden_bit); + const uint16_t is_undenorm_msb = (a_e - 1); + const uint16_t is_undenorm = (((int16_t)is_undenorm_msb) >> 15); + const uint16_t is_both_inf_msb = (is_a_inf_msb & is_b_inf_msb); + const uint16_t is_invalid_inf_op_msb = (is_both_inf_msb & b_s); + const uint16_t is_a_e_nez_msb = (-a_e); + const uint16_t is_b_e_nez_msb = (-b_e); + const uint16_t is_a_e_nez = (((int16_t)is_a_e_nez_msb) >> 15); + const uint16_t is_b_e_nez = (((int16_t)is_b_e_nez_msb) >> 15); + const uint16_t a_m_hidden_bit = (is_a_e_nez & h_m_hidden); + const uint16_t b_m_hidden_bit = (is_b_e_nez & h_m_hidden); + const uint16_t a_m_no_grs = (a_m_no_hidden_bit | a_m_hidden_bit); + const uint16_t b_m_no_grs = (b_m_no_hidden_bit | b_m_hidden_bit); + const uint16_t diff_e = (a_e - b_e); + 
const uint16_t a_e_unbias = (a_e - h_e_bias_minus_one); + const uint16_t a_m = (a_m_no_grs << h_grs_size); + const uint16_t a_e_biased = (a_e >> h_e_pos); + const uint16_t m_sa_unbias = (a_e_unbias >> h_e_pos); + const uint16_t m_sa_default = (diff_e >> h_e_pos); + const uint16_t m_sa_unbias_mask = (is_a_e_nez_msb & ~is_b_e_nez_msb); + const uint16_t m_sa = uint16Sels(m_sa_unbias_mask, m_sa_unbias, m_sa_default); + const uint16_t b_m_no_sticky = (b_m_no_grs << h_grs_size); + const uint16_t sh_m = (b_m_no_sticky >> m_sa); + const uint16_t sticky_overflow = (one << m_sa); + const uint16_t sticky_mask = (sticky_overflow - 1); + const uint16_t sticky_collect = (b_m_no_sticky & sticky_mask); + const uint16_t is_sticky_set_msb = (-sticky_collect); + const uint16_t sticky = (is_sticky_set_msb >> msb_to_lsb_sa); + const uint16_t b_m = (sh_m | sticky); + const uint16_t is_c_m_ab_pos_msb = (b_m - a_m); + const uint16_t c_inf = (a_s | h_e_mask); + const uint16_t c_m_sum = (a_m + b_m); + const uint16_t c_m_diff_ab = (a_m - b_m); + const uint16_t c_m_diff_ba = (b_m - a_m); + const uint16_t c_m_smag_diff = uint16Sels(is_c_m_ab_pos_msb, c_m_diff_ab, c_m_diff_ba); + const uint16_t c_s_diff = uint16Sels(is_c_m_ab_pos_msb, a_s, b_s); + const uint16_t c_s = uint16Sels(is_diff_sign_msb, c_s_diff, a_s); + const uint16_t c_m_smag_diff_nlz = uint16Cntlz(c_m_smag_diff); + const uint16_t diff_norm_sa = (c_m_smag_diff_nlz - one); + const uint16_t is_diff_denorm_msb = (a_e_biased - diff_norm_sa); + const uint16_t is_diff_denorm = (((int16_t)is_diff_denorm_msb) >> 15); + const uint16_t is_a_or_b_norm_msb = (-a_e_biased); + const uint16_t diff_denorm_sa = (a_e_biased - 1); + const uint16_t c_m_diff_denorm = (c_m_smag_diff << diff_denorm_sa); + const uint16_t c_m_diff_norm = (c_m_smag_diff << diff_norm_sa); + const uint16_t c_e_diff_norm = (a_e_biased - diff_norm_sa); + const uint16_t c_m_diff_ab_norm = + uint16Sels(is_diff_denorm_msb, c_m_diff_denorm, c_m_diff_norm); + const uint16_t 
c_e_diff_ab_norm = (c_e_diff_norm & ~is_diff_denorm); + const uint16_t c_m_diff = + uint16Sels(is_a_or_b_norm_msb, c_m_diff_ab_norm, c_m_smag_diff); + const uint16_t c_e_diff = uint16Sels(is_a_or_b_norm_msb, c_e_diff_ab_norm, a_e_biased); + const uint16_t is_diff_eqz_msb = (c_m_diff - 1); + const uint16_t is_diff_exactly_zero_msb = (is_diff_sign_msb & is_diff_eqz_msb); + const uint16_t is_diff_exactly_zero = (((int16_t)is_diff_exactly_zero_msb) >> 15); + const uint16_t c_m_added = uint16Sels(is_diff_sign_msb, c_m_diff, c_m_sum); + const uint16_t c_e_added = uint16Sels(is_diff_sign_msb, c_e_diff, a_e_biased); + const uint16_t c_m_carry = (c_m_added & h_m_grs_carry); + const uint16_t is_c_m_carry_msb = (-c_m_carry); + const uint16_t c_e_hidden_offset = ((c_m_added & h_m_grs_carry) >> h_m_grs_carry_pos); + const uint16_t c_m_sub_hidden = (c_m_added >> one); + const uint16_t c_m_no_hidden = uint16Sels(is_c_m_carry_msb, c_m_sub_hidden, c_m_added); + const uint16_t c_e_no_hidden = (c_e_added + c_e_hidden_offset); + const uint16_t c_m_no_hidden_msb = (c_m_no_hidden & h_m_msb_mask); + const uint16_t undenorm_m_msb_odd = (c_m_no_hidden_msb >> h_m_msb_sa); + const uint16_t undenorm_fix_e = (is_undenorm & undenorm_m_msb_odd); + const uint16_t c_e_fixed = (c_e_no_hidden + undenorm_fix_e); + const uint16_t c_m_round_amount = (c_m_no_hidden & h_grs_round_mask); + const uint16_t c_m_rounded = (c_m_no_hidden + c_m_round_amount); + const uint16_t c_m_round_overflow = + ((c_m_rounded & h_m_grs_carry) >> h_m_grs_carry_pos); + const uint16_t c_e_rounded = (c_e_fixed + c_m_round_overflow); + const uint16_t c_m_no_grs = ((c_m_rounded >> h_grs_size) & h_m_mask); + const uint16_t c_e = (c_e_rounded << h_e_pos); + const uint16_t c_em = (c_e | c_m_no_grs); + const uint16_t c_normal = (c_s | c_em); + const uint16_t c_inf_result = uint16Sels(is_a_inf_msb, c_inf, c_normal); + const uint16_t c_zero_result = (c_inf_result & ~is_diff_exactly_zero); + const uint16_t c_result = 
uint16Sels(is_invalid_inf_op_msb, h_snan, c_zero_result); + return (c_result); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr uint16_t halfMul(uint16_t x, + uint16_t y) noexcept { + const uint32_t one = (0x00000001); + const uint32_t h_s_mask = (0x00008000); + const uint32_t h_e_mask = (0x00007c00); + const uint32_t h_m_mask = (0x000003ff); + const uint32_t h_m_hidden = (0x00000400); + const uint32_t h_e_pos = (0x0000000a); + const uint32_t h_e_bias = (0x0000000f); + const uint32_t h_m_bit_count = (0x0000000a); + const uint32_t h_m_bit_half_count = (0x00000005); + const uint32_t h_nan_min = (0x00007c01); + const uint32_t h_e_mask_minus_one = (0x00007bff); + const uint32_t h_snan = (0x0000fe00); + const uint32_t m_round_overflow_bit = (0x00000020); + const uint32_t m_hidden_bit = (0x00100000); + const uint32_t a_s = (x & h_s_mask); + const uint32_t b_s = (y & h_s_mask); + const uint32_t c_s = (a_s ^ b_s); + const uint32_t x_e = (x & h_e_mask); + const uint32_t x_e_eqz_msb = (x_e - 1); + const uint32_t a = uint32Sels(x_e_eqz_msb, y, x); + const uint32_t b = uint32Sels(x_e_eqz_msb, x, y); + const uint32_t a_e = (a & h_e_mask); + const uint32_t b_e = (b & h_e_mask); + const uint32_t a_m = (a & h_m_mask); + const uint32_t b_m = (b & h_m_mask); + const uint32_t a_e_amount = (a_e >> h_e_pos); + const uint32_t b_e_amount = (b_e >> h_e_pos); + const uint32_t a_m_with_hidden = (a_m | h_m_hidden); + const uint32_t b_m_with_hidden = (b_m | h_m_hidden); + const uint32_t c_m_normal = (a_m_with_hidden * b_m_with_hidden); + const uint32_t c_m_denorm_biased = (a_m_with_hidden * b_m); + const uint32_t c_e_denorm_unbias_e = (h_e_bias - a_e_amount); + const uint32_t c_m_denorm_round_amount = (c_m_denorm_biased & h_m_mask); + const uint32_t c_m_denorm_rounded = (c_m_denorm_biased + c_m_denorm_round_amount); + const uint32_t c_m_denorm_inplace = (c_m_denorm_rounded >> h_m_bit_count); + const uint32_t c_m_denorm_unbiased = (c_m_denorm_inplace >> c_e_denorm_unbias_e); + const 
uint32_t c_m_denorm = (c_m_denorm_unbiased & h_m_mask); + const uint32_t c_e_amount_biased = (a_e_amount + b_e_amount); + const uint32_t c_e_amount_unbiased = (c_e_amount_biased - h_e_bias); + const uint32_t is_c_e_unbiased_underflow = (((std::int32_t)c_e_amount_unbiased) >> 31); + const uint32_t c_e_underflow_half_sa = (-((int32_t)c_e_amount_unbiased)); + const uint32_t c_e_underflow_sa = (c_e_underflow_half_sa << one); + const uint32_t c_m_underflow = (c_m_normal >> c_e_underflow_sa); + const uint32_t c_e_underflow_added = (c_e_amount_unbiased & ~is_c_e_unbiased_underflow); + const uint32_t c_m_underflow_added = + uint32Selb(is_c_e_unbiased_underflow, c_m_underflow, c_m_normal); + const uint32_t is_mul_overflow_test = (c_e_underflow_added & m_round_overflow_bit); + const uint32_t is_mul_overflow_msb = (-((int32_t)is_mul_overflow_test)); + const uint32_t c_e_norm_radix_corrected = (c_e_underflow_added + 1); + const uint32_t c_m_norm_radix_corrected = (c_m_underflow_added >> one); + const uint32_t c_m_norm_hidden_bit = (c_m_norm_radix_corrected & m_hidden_bit); + const uint32_t is_c_m_norm_no_hidden_msb = (c_m_norm_hidden_bit - 1); + const uint32_t c_m_norm_lo = (c_m_norm_radix_corrected >> h_m_bit_half_count); + const uint32_t c_m_norm_lo_nlz = + static_cast(uint16Cntlz((uint16_t)c_m_norm_lo)); + const uint32_t is_c_m_hidden_nunderflow_msb = + (c_m_norm_lo_nlz - c_e_norm_radix_corrected); + const uint32_t is_c_m_hidden_underflow_msb = (~is_c_m_hidden_nunderflow_msb); + const uint32_t is_c_m_hidden_underflow = + (((std::int32_t)is_c_m_hidden_underflow_msb) >> 31); + const uint32_t c_m_hidden_underflow_normalized_sa = (c_m_norm_lo_nlz >> one); + const uint32_t c_m_hidden_underflow_normalized = + (c_m_norm_radix_corrected << c_m_hidden_underflow_normalized_sa); + const uint32_t c_m_hidden_normalized = (c_m_norm_radix_corrected << c_m_norm_lo_nlz); + const uint32_t c_e_hidden_normalized = (c_e_norm_radix_corrected - c_m_norm_lo_nlz); + const uint32_t c_e_hidden = 
(c_e_hidden_normalized & ~is_c_m_hidden_underflow); + const uint32_t c_m_hidden = uint32Sels( + is_c_m_hidden_underflow_msb, c_m_hidden_underflow_normalized, c_m_hidden_normalized); + const uint32_t c_m_normalized = + uint32Sels(is_c_m_norm_no_hidden_msb, c_m_hidden, c_m_norm_radix_corrected); + const uint32_t c_e_normalized = + uint32Sels(is_c_m_norm_no_hidden_msb, c_e_hidden, c_e_norm_radix_corrected); + const uint32_t c_m_norm_round_amount = (c_m_normalized & h_m_mask); + const uint32_t c_m_norm_rounded = (c_m_normalized + c_m_norm_round_amount); + const uint32_t is_round_overflow_test = (c_e_normalized & m_round_overflow_bit); + const uint32_t is_round_overflow_msb = (-((int32_t)is_round_overflow_test)); + const uint32_t c_m_norm_inplace = (c_m_norm_rounded >> h_m_bit_count); + const uint32_t c_m = (c_m_norm_inplace & h_m_mask); + const uint32_t c_e_norm_inplace = (c_e_normalized << h_e_pos); + const uint32_t c_e = (c_e_norm_inplace & h_e_mask); + const uint32_t c_em_nan = (h_e_mask | a_m); + const uint32_t c_nan = (a_s | c_em_nan); + const uint32_t c_denorm = (c_s | c_m_denorm); + const uint32_t c_inf = (c_s | h_e_mask); + const uint32_t c_em_norm = (c_e | c_m); + const uint32_t is_a_e_flagged_msb = (h_e_mask_minus_one - a_e); + const uint32_t is_b_e_flagged_msb = (h_e_mask_minus_one - b_e); + const uint32_t is_a_e_eqz_msb = (a_e - 1); + const uint32_t is_a_m_eqz_msb = (a_m - 1); + const uint32_t is_b_e_eqz_msb = (b_e - 1); + const uint32_t is_b_m_eqz_msb = (b_m - 1); + const uint32_t is_b_eqz_msb = (is_b_e_eqz_msb & is_b_m_eqz_msb); + const uint32_t is_a_eqz_msb = (is_a_e_eqz_msb & is_a_m_eqz_msb); + const uint32_t is_c_nan_via_a_msb = (is_a_e_flagged_msb & ~is_b_e_flagged_msb); + const uint32_t is_c_nan_via_b_msb = (is_b_e_flagged_msb & ~is_b_m_eqz_msb); + const uint32_t is_c_nan_msb = (is_c_nan_via_a_msb | is_c_nan_via_b_msb); + const uint32_t is_c_denorm_msb = (is_b_e_eqz_msb & ~is_a_e_flagged_msb); + const uint32_t is_a_inf_msb = (is_a_e_flagged_msb & 
is_a_m_eqz_msb); + const uint32_t is_c_snan_msb = (is_a_inf_msb & is_b_eqz_msb); + const uint32_t is_c_nan_min_via_a_msb = (is_a_e_flagged_msb & is_b_eqz_msb); + const uint32_t is_c_nan_min_via_b_msb = (is_b_e_flagged_msb & is_a_eqz_msb); + const uint32_t is_c_nan_min_msb = (is_c_nan_min_via_a_msb | is_c_nan_min_via_b_msb); + const uint32_t is_c_inf_msb = (is_a_e_flagged_msb | is_b_e_flagged_msb); + const uint32_t is_overflow_msb = (is_round_overflow_msb | is_mul_overflow_msb); + const uint32_t c_em_overflow_result = uint32Sels(is_overflow_msb, h_e_mask, c_em_norm); + const uint32_t c_common_result = (c_s | c_em_overflow_result); + const uint32_t c_zero_result = uint32Sels(is_b_eqz_msb, c_s, c_common_result); + const uint32_t c_nan_result = uint32Sels(is_c_nan_msb, c_nan, c_zero_result); + const uint32_t c_nan_min_result = uint32Sels(is_c_nan_min_msb, h_nan_min, c_nan_result); + const uint32_t c_inf_result = uint32Sels(is_c_inf_msb, c_inf, c_nan_min_result); + const uint32_t c_denorm_result = uint32Sels(is_c_denorm_msb, c_denorm, c_inf_result); + const uint32_t c_result = uint32Sels(is_c_snan_msb, h_snan, c_denorm_result); + return (uint16_t)(c_result); + } + + constexpr inline uint16_t halfNeg(uint16_t h) noexcept { return h ^ 0x8000; } + + constexpr inline uint16_t halfSub(uint16_t ha, uint16_t hb) noexcept { + return halfAdd(ha, halfNeg(hb)); + } + } // namespace detail + + class half { + public: + half() noexcept = default; + half(const half &) = default; + half(half &&) = default; + + LIBRAPID_ALWAYS_INLINE half(float f) noexcept; + + template + LIBRAPID_ALWAYS_INLINE explicit half(T d) noexcept; + + half &operator=(const half &) = default; + half &operator=(half &&) = default; + + template + LIBRAPID_ALWAYS_INLINE half &operator=(T d) noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static half fromBits(uint16_t bits) noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE explicit operator float() const noexcept; + + template + LIBRAPID_NODISCARD 
LIBRAPID_ALWAYS_INLINE explicit operator T() const noexcept; + + LIBRAPID_ALWAYS_INLINE half &operator+=(const half &rhs) noexcept; + LIBRAPID_ALWAYS_INLINE half &operator-=(const half &rhs) noexcept; + LIBRAPID_ALWAYS_INLINE half &operator*=(const half &rhs) noexcept; + LIBRAPID_ALWAYS_INLINE half &operator/=(const half &rhs) noexcept; + + LIBRAPID_ALWAYS_INLINE half &operator--() noexcept; + LIBRAPID_ALWAYS_INLINE half operator--(int) noexcept; + LIBRAPID_ALWAYS_INLINE half &operator++() noexcept; + LIBRAPID_ALWAYS_INLINE half operator++(int) noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator-() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator+() const noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t data() const noexcept; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t &data() noexcept; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE std::string + str(const std::string &format = "{}") const; + + // static half infinity; + // static half max; + // static half maxSubnormal; + // static half min; + // static half minPositive; + // static half minPositiveSubnormal; + // static half nan; + // static half negativeInfinity; + // static half epsilon; + // + // static half one; + // static half negativeOne; + // static half two; + // static half negativeTwo; + // static half half_; + // static half negativeHalf; + // static half zero; + // static half negativeZero; + // static half e; + // static half pi; + + private: + detail::float16_t m_value; + }; + + half::half(float f) noexcept { + detail::float32_t tmp; + tmp.m_float = f; + m_value.m_bits = detail::floatToHalf(tmp.m_bits); + } + + template + half::half(T d) noexcept : half(static_cast(d)) {} + + template + half &half::operator=(T d) noexcept { + *this = half(d); + return *this; + } + + half half::fromBits(uint16_t bits) noexcept { + half h; + h.m_value.m_bits = bits; + return h; + } + + half::operator float() const noexcept { + 
detail::float32_t tmp; + tmp.m_bits = detail::halfToFloat(m_value.m_bits); + return tmp.m_float; + } + + template + LIBRAPID_NODISCARD half::operator T() const noexcept { + return static_cast(static_cast(*this)); + } + + LIBRAPID_ALWAYS_INLINE half &half::operator+=(const half &rhs) noexcept { + m_value.m_bits = detail::halfAdd(m_value.m_bits, rhs.m_value.m_bits); + return *this; + } + + LIBRAPID_ALWAYS_INLINE half &half::operator-=(const half &rhs) noexcept { + m_value.m_bits = detail::halfSub(m_value.m_bits, rhs.m_value.m_bits); + return *this; + } + + LIBRAPID_ALWAYS_INLINE half &half::operator*=(const half &rhs) noexcept { + m_value.m_bits = detail::halfMul(m_value.m_bits, rhs.m_value.m_bits); + return *this; + } + + LIBRAPID_ALWAYS_INLINE half &half::operator/=(const half &rhs) noexcept { + *this = static_cast(*this) / static_cast(rhs); + return *this; + } + + LIBRAPID_ALWAYS_INLINE half &half::operator--() noexcept { + *this -= half::fromBits(static_cast(0x3c00)); + return *this; + } + + LIBRAPID_ALWAYS_INLINE half half::operator--(int) noexcept { + half tmp(*this); + tmp -= half::fromBits(static_cast(0x3c00)); + return tmp; + } + + LIBRAPID_ALWAYS_INLINE half &half::operator++() noexcept { + *this += half::fromBits(static_cast(0x3c00)); + return *this; + } + + LIBRAPID_ALWAYS_INLINE half half::operator++(int) noexcept { + half tmp(*this); + tmp += half::fromBits(static_cast(0x3c00)); + return tmp; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half half::operator-() const noexcept { + return half::fromBits((m_value.m_bits & 0x7fff) | (m_value.m_bits ^ 0x8000)); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half half::operator+() const noexcept { + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t half::data() const noexcept { + return m_value; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t &half::data() noexcept { + return m_value; + } + + std::string half::str(const std::string &format) const { + 
// return fmt::vformat(format, fmt::make_wformat_args(detail::halfToFloat(m_value.m_bits))); + + return std::vformat(format, std::make_format_args(detail::halfToFloat(m_value.m_bits))); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator+(const half &lhs, + const half &rhs) noexcept { + half tmp(lhs); + tmp += rhs; + return tmp; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator-(const half &lhs, + const half &rhs) noexcept { + half tmp(lhs); + tmp -= rhs; + return tmp; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator*(const half &lhs, + const half &rhs) noexcept { + half tmp(lhs); + tmp *= rhs; + return tmp; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator/(const half &lhs, + const half &rhs) noexcept { + half tmp(lhs); + tmp /= rhs; + return tmp; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator<(const half &lhs, + const half &rhs) noexcept { + auto const &l_ieee = lhs.data().m_ieee; + auto const &r_ieee = rhs.data().m_ieee; + + if (l_ieee.m_sign == 1) { + if (r_ieee.m_sign == 0) return true; + if (l_ieee.m_exp > r_ieee.m_exp) return true; + if (l_ieee.m_exp < r_ieee.m_exp) return false; + if (l_ieee.m_frac > r_ieee.m_frac) return true; + return false; + } + + if (r_ieee.m_sign == 1) return false; + if (l_ieee.m_exp > r_ieee.m_exp) return false; + if (l_ieee.m_exp < r_ieee.m_exp) return true; + if (l_ieee.m_frac >= r_ieee.m_frac) return false; + return true; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator==(const half &lhs, + const half &rhs) noexcept { + return lhs.data().m_bits == rhs.data().m_bits; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator!=(const half &lhs, + const half &rhs) noexcept { + return lhs.data().m_bits != rhs.data().m_bits; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator<=(const half &lhs, + const half &rhs) noexcept { + return (lhs < rhs) || (lhs == rhs); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator>(const 
half &lhs, + const half &rhs) noexcept { + return !(lhs <= rhs); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool operator>=(const half &lhs, + const half &rhs) noexcept { + return !(lhs < rhs); + } + + namespace typetraits { + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = half; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "half"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = false; #if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_16F; - static constexpr int64_t cudaPacketWidth = 1; + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_16F; + static constexpr int64_t cudaPacketWidth = 1; #endif - static constexpr bool canAlign = true; - static constexpr bool canMemcpy = true; - - LIMIT_IMPL(infinity) { return half::fromBits(static_cast(0x7c00)); } - LIMIT_IMPL(max) { return half::fromBits(static_cast(0x7bff)); } - LIMIT_IMPL(maxSubnormal) { return half::fromBits(static_cast(0x3ff)); } - LIMIT_IMPL(min) { return half::fromBits(static_cast(0xfbff)); } - LIMIT_IMPL(minPositive) { return half::fromBits(static_cast(0x400)); } - LIMIT_IMPL(minPositiveSubnormal) { return half::fromBits(static_cast(0x1)); } - LIMIT_IMPL(nan) { return half::fromBits(static_cast(0x7e00)); } - LIMIT_IMPL(negativeInfinity) { return half::fromBits(static_cast(0xfc00)); } - LIMIT_IMPL(epsilon) { return half::fromBits(static_cast(0x1400)); } - - LIMIT_IMPL(one) { return half::fromBits(static_cast(0x3c00)); } - LIMIT_IMPL(negativeOne) { return half::fromBits(static_cast(0x4000)); } - LIMIT_IMPL(two) { return half::fromBits(static_cast(0x4000)); } - LIMIT_IMPL(negativeTwo) { return 
half::fromBits(static_cast(0xc000)); } - LIMIT_IMPL(half_) { return half::fromBits(static_cast(0x3800)); } - LIMIT_IMPL(negativeHalf) { return half::fromBits(static_cast(0x3b00)); } - LIMIT_IMPL(zero) { return half::fromBits(static_cast(0x0)); } - LIMIT_IMPL(negativeZero) { return half::fromBits(static_cast(0x8000)); } - LIMIT_IMPL(e) { return half::fromBits(static_cast(0x4170)); } - LIMIT_IMPL(pi) { return half::fromBits(static_cast(0x4248)); } - }; - } // namespace typetraits + static constexpr bool canAlign = true; + static constexpr bool canMemcpy = true; + + LIMIT_IMPL(infinity) { return half::fromBits(static_cast(0x7c00)); } + LIMIT_IMPL(max) { return half::fromBits(static_cast(0x7bff)); } + LIMIT_IMPL(maxSubnormal) { return half::fromBits(static_cast(0x3ff)); } + LIMIT_IMPL(min) { return half::fromBits(static_cast(0xfbff)); } + LIMIT_IMPL(minPositive) { return half::fromBits(static_cast(0x400)); } + LIMIT_IMPL(minPositiveSubnormal) { return half::fromBits(static_cast(0x1)); } + LIMIT_IMPL(nan) { return half::fromBits(static_cast(0x7e00)); } + LIMIT_IMPL(negativeInfinity) { return half::fromBits(static_cast(0xfc00)); } + LIMIT_IMPL(epsilon) { return half::fromBits(static_cast(0x1400)); } + + LIMIT_IMPL(one) { return half::fromBits(static_cast(0x3c00)); } + LIMIT_IMPL(negativeOne) { return half::fromBits(static_cast(0x4000)); } + LIMIT_IMPL(two) { return half::fromBits(static_cast(0x4000)); } + LIMIT_IMPL(negativeTwo) { return half::fromBits(static_cast(0xc000)); } + LIMIT_IMPL(half_) { return half::fromBits(static_cast(0x3800)); } + LIMIT_IMPL(negativeHalf) { return half::fromBits(static_cast(0x3b00)); } + LIMIT_IMPL(zero) { return half::fromBits(static_cast(0x0)); } + LIMIT_IMPL(negativeZero) { return half::fromBits(static_cast(0x8000)); } + LIMIT_IMPL(e) { return half::fromBits(static_cast(0x4170)); } + LIMIT_IMPL(pi) { return half::fromBits(static_cast(0x4248)); } + }; + } // namespace typetraits } // namespace librapid 
LIBRAPID_SIMPLE_IO_IMPL_NO_TEMPLATE(librapid::half); diff --git a/librapid/include/librapid/math/multiprec.hpp b/librapid/include/librapid/math/multiprec.hpp index 28192d06..e51bf95e 100644 --- a/librapid/include/librapid/math/multiprec.hpp +++ b/librapid/include/librapid/math/multiprec.hpp @@ -4,907 +4,907 @@ #if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - /// Multiprecision integer type - using mpz = mpz_class; - /// Multiprecision floating point type - using mpf = mpf_class; - /// Multiprecision rational type - using mpq = mpq_class; - /// Multiprecision floating point type with greater functionality - using mpfr = mpfr::mpreal; - - /// Convert a multiprecision integer type to a string with a given base - /// \param val The value to convert - /// \param base The base to convert to - /// \return The converted value - std::string str(const mpz &val, int64_t digits = -1, int base = 10); - - /// Convert a multiprecision floating point type to a string with a given base - /// \param val The value to convert - /// \param base The base to convert to - /// \return The converted value - std::string str(const mpf &val, int64_t digits = -1, int base = 10); - - /// Convert a multiprecision rational type to a string with a given base - /// \param val The value to convert - /// \param base The base to convert to - /// \return The converted value - std::string str(const mpq &val, int64_t digits = -1, int base = 10); - - /// Convert a multiprecision floating point type to a string with a given base - /// \param val The value to convert - /// \param base The base to convert to - /// \return The converted value - std::string str(const mpfr &val, int64_t digits = -1, int base = 10); - - /// Multiprecision integer to multiprecision integer cast - /// \param other The value to cast - /// \return The cast value - mpz toMpz(const mpz &other); - - /// Multiprecision floating point to multiprecision integer cast - /// \param other The value to cast - /// \return The cast 
value - mpz toMpz(const mpf &other); - - /// Multiprecision rational to multiprecision integer cast - /// \param other The value to cast - /// \return The cast value - mpz toMpz(const mpq &other); - - /// Multiprecision floating point to multiprecision integer cast - /// \param other The value to cast - /// \return The cast value - mpz toMpz(const mpfr &other); - - /// Multiprecision integer to multiprecision floating point cast - /// \param other The value to cast - /// \return The cast value - mpf toMpf(const mpz &other); - - /// Multiprecision floating point to multiprecision floating point cast - /// \param other The value to cast - /// \return The cast value - mpf toMpf(const mpf &other); - - /// Multiprecision rational to multiprecision floating point cast - /// \param other The value to cast - /// \return The cast value - mpf toMpf(const mpq &other); - - /// Multiprecision floating point to multiprecision floating point cast - /// \param other The value to cast - /// \return The cast value - mpf toMpf(const mpfr &other); - - /// Multiprecision integer to multiprecision rational cast - /// \param other The value to cast - /// \return The cast value - - mpq toMpq(const mpz &other); - - /// Multiprecision floating point to multiprecision rational cast - /// \param other The value to cast - /// \return The cast value - mpq toMpq(const mpf &other); - - /// Multiprecision rational to multiprecision rational cast - /// \param other The value to cast - /// \return The cast value - mpq toMpq(const mpq &other); - - /// Multiprecision floating point to multiprecision rational cast - /// \param other The value to cast - /// \return The cast value - mpq toMpq(const mpfr &other); - - /// Multiprecision integer to multiprecision floating point cast - /// \param other The value to cast - /// \return The cast value - mpfr toMpfr(const mpz &other); - - /// Multiprecision floating point to multiprecision floating point cast - /// \param other The value to cast - /// \return 
The cast value - mpfr toMpfr(const mpf &other); - - /// Multiprecision rational to multiprecision floating point cast - /// \param other The value to cast - /// \return The cast value - mpfr toMpfr(const mpq &other); - - /// Multiprecision floating point to multiprecision floating point cast - /// \param other The value to cast - /// \return The cast value - mpfr toMpfr(const mpfr &other); - - // Trigonometric Functionality for mpf - - /// Sine of a multiprecision floating point value: \f$ \sin (x) \f$ - /// \param val The value to take the sine of - /// \return The sine of the value - mpfr sin(const mpfr &val); - - /// Cosine of a multiprecision floating point value: \f$ \cos (x) \f$ - /// \param val The value to take the cosine of - /// \return The cosine of the value - mpfr cos(const mpfr &val); - - /// Tangent of a multiprecision floating point value: \f$ \tan (x) \f$ - /// \param val The value to take the tangent of - /// \return The tangent of the value - mpfr tan(const mpfr &val); - - /// Arcsine of a multiprecision floating point value: \f$ \sin^{-1} (x) \f$ - /// \param val The value to take the arcsine of - /// \return The arcsine of the value - /// \see sin(const mpfr &) - mpfr asin(const mpfr &val); - - /// Arccosine of a multiprecision floating point value: \f$ \cos^{-1} (x) \f$ - /// \param val The value to take the arccosine of - /// \return The arccosine of the value - /// \see cos(const mpfr &) - mpfr acos(const mpfr &val); - - /// Arctangent of a multiprecision floating point value: \f$ \tan^{-1} (x) \f$ - /// \param val The value to take the arctangent of - /// \return The arctangent of the value - /// \see tan(const mpfr &) - mpfr atan(const mpfr &val); - - /// Atan2 of a multiprecision floating point value: \f$ \tan^{-1}\left(rac{y}{x}\right) \f$ - /// \param dy The y value - /// \param dx The x value - /// \return The atan2 of the value - mpfr atan2(const mpfr &dy, const mpfr &dx); - - /// Cosec of a multiprecision floating point value: \f$ 
\csc (x) \f$ - /// \param val The value to take the cosec of - /// \return The cosec of the value - mpfr csc(const mpfr &val); - - /// Secant of a multiprecision floating point value: \f$ \sec (x) \f$ - /// \param val The value to take the secant of - /// \return The secant of the value - mpfr sec(const mpfr &val); - - /// Cotangent of a multiprecision floating point value: \f$ \cot (x) \f$ - /// \param val The value to take the cotangent of - /// \return The cotangent of the value - mpfr cot(const mpfr &val); - - /// Arccosec of a multiprecision floating point value: \f$ \csc^{-1} (x) \f$ - /// \param val The value to take the arccosec of - /// \return The arccosec of the value - mpfr acsc(const mpfr &val); - - /// Arcsecant of a multiprecision floating point value: \f$ \sec^{-1} (x) \f$ - /// \param val The value to take the arcsecant of - /// \return The arcsecant of the value - mpfr asec(const mpfr &val); - - /// Arccotangent of a multiprecision floating point value: \f$ \cot^{-1} (x) \f$ - /// \param val The value to take the arccotangent of - /// \return The arccotangent of the value - mpfr acot(const mpfr &val); - - // Hyperbolic Functionality for mpf - - /// Hyperbolic sine of a multiprecision floating point value: \f$ \sinh (x) \f$ - /// \param val The value to take the hyperbolic sine of - /// \return The hyperbolic sine of the value - mpfr sinh(const mpfr &val); - - /// Hyperbolic cosine of a multiprecision floating point value: \f$ \cosh (x) \f$ - /// \param val The value to take the hyperbolic cosine of - /// \return The hyperbolic cosine of the value - mpfr cosh(const mpfr &val); - - /// Hyperbolic tangent of a multiprecision floating point value: \f$ \tanh (x) \f$ - /// \param val The value to take the hyperbolic tangent of - /// \return The hyperbolic tangent of the value - mpfr tanh(const mpfr &val); - - /// Hyperbolic arcsine of a multiprecision floating point value: \f$ \sinh^{-1} (x) \f$ - /// \param val The value to take the hyperbolic arcsine 
of - /// \return The hyperbolic arcsine of the value - mpfr asinh(const mpfr &val); - - /// Hyperbolic arccosine of a multiprecision floating point value: \f$ \cosh^{-1} (x) \f$ - /// \param val The value to take the hyperbolic arccosine of - /// \return The hyperbolic arccosine of the value - mpfr acosh(const mpfr &val); - - /// Hyperbolic arctangent of a multiprecision floating point value: \f$ \tanh^{-1} (x) \f$ - /// \param val The value to take the hyperbolic arctangent of - /// \return The hyperbolic arctangent of the value - mpfr atanh(const mpfr &val); - - /// Hyperbolic cosec of a multiprecision floating point value: \f$ csch(x) \f$ - /// \param val The value to take the hyperbolic cosec of - /// \return The hyperbolic cosec of the value - mpfr csch(const mpfr &val); - - /// Hyperbolic secant of a multiprecision floating point value: \f$ sech(x) \f$ - /// \param val The value to take the hyperbolic secant of - /// \return The hyperbolic secant of the value - mpfr sech(const mpfr &val); - - /// Hyperbolic cotangent of a multiprecision floating point value: \f$ coth(x) \f$ - /// \param val The value to take the hyperbolic cotangent of - /// \return The hyperbolic cotangent of the value - mpfr coth(const mpfr &val); - - /// Hyperbolic arccosec of a multiprecision floating point value: \f$ csch^{-1}(x) \f$ - /// \param val The value to take the hyperbolic arccosec of - /// \return The hyperbolic arccosec of the value - mpfr acsch(const mpfr &val); - - /// Hyperbolic arcsecant of a multiprecision floating point value: \f$ sech^{-1}(x) \f$ - /// \param val The value to take the hyperbolic arcsecant of - /// \return The hyperbolic arcsecant of the value - mpfr asech(const mpfr &val); - - /// Hyperbolic arccotangent of a multiprecision floating point value: \f$ coth^{-1}(x) - /// \f$ \param val The value to take the hyperbolic arccotangent of \return The hyperbolic - /// \return arccotangent of the value - mpfr acoth(const mpfr &val); - - /// Absolute value of a 
multiprecision floating point value: \f$ |x| \f$ - /// \param val The value to take the absolute value of - /// \return Absolute value - mpfr abs(const mpfr &val); - - /// Return true if two values are close to each other - /// \tparam T The type of the tolerance - /// \param val1 The first value - /// \param val2 The second value - /// \param tolerance The tolerance - /// \return True if the values are close to each other - // template - // bool isClose(const mpfr &val1, const mpfr &val2, const T &tolerance = 1e-6) { - // return ::librapid::abs(val1 - val2) < tolerance; - // } - - /// Absolute value of a multiprecision integer value: \f$ |x| \f$ - /// \param val The value to take the absolute value of - /// \return Absolute value - mpz abs(const mpz &val); - - /// Absolute value of a multiprecision rational value: \f$ |x| \f$ - /// \param val The value to take the absolute value of - /// \return Absolute value - mpq abs(const mpq &val); - - /// Absolute value of a multiprecision floating point value: \f$ |x| \f$ - /// \param val The value to take the absolute value of - /// \return Absolute value - mpf abs(const mpf &val); - - /// Square root of a multiprecision floating point value: \f$ \sqrt{x} \f$ - /// \param val The value to take the square root of - /// \return The square root of the value - mpfr sqrt(const mpfr &val); - - /// Cube root of a multiprecision floating point value: \f$ \sqrt[3]{x} \f$ - /// \param val The value to take the cube root of - /// \return The cube root of the value - mpfr cbrt(const mpfr &val); - - /// Raise a multiprecision floating point value to a power: \f$ x^y \f$ - /// \param base The value to raise to a power - /// \param pow The power to raise the value to - mpfr pow(const mpfr &base, const mpfr &pow); - - /// Exponential of a multiprecision floating point value: \f$ e^x \f$ - /// \param val The value to take the exponential of - /// \return The exponential of the value - mpfr exp(const mpfr &val); - - /// Raise 2 to the power 
of a multiprecision floating point value: \f$ 2^x \f$ - /// \param val The value to raise 2 to the power of - /// \return 2 raised to the power of the value - mpfr exp2(const mpfr &val); - - /// Raise 10 to the power of a multiprecision floating point value: \f$ 10^x \f$ - /// \param val The value to raise 10 to the power of - /// \return 10 raised to the power of the value - mpfr exp10(const mpfr &val); - - /// ldexp of a multiprecision floating point value: \f$ x \times 2^exp \f$ - /// \param val The value to take the ldexp of - /// \param exponent The exponent to multiply the value by - /// \return The ldexp of the value - mpfr ldexp(const mpfr &val, int exponent); - - /// Logarithm of a multiprecision floating point value: \f$ \log (x) \f$ - /// \param val The value to take the logarithm of - /// \return The logarithm of the value - mpfr log(const mpfr &val); - - /// Logarithm of a multiprecision floating point value with a given base: \f$ \log_b (x) \f$ - /// \param val The value to take the logarithm of - /// \param base The base to take the logarithm with - /// \return The logarithm of the value with the given base - mpfr log(const mpfr &val, const mpfr &base); - - /// Logarithm of a multiprecision floating point value with base 2: \f$ \log_2 (x) \f$ - /// \param val The value to take the logarithm of - /// \return The logarithm of the value with base 2 - mpfr log2(const mpfr &val); - - /// Logarithm of a multiprecision floating point value with base 10: \f$ \log_{10} (x) \f$ - /// \param val The value to take the logarithm of - /// \return The logarithm of the value with base 10 - mpfr log10(const mpfr &val); - - /// Floor of a multiprecision floating point value: \f$ \lfloor x \rfloor \f$ - /// \param val The value to take the floor of - /// \return The floor of the value - mpfr floor(const mpfr &val); - - /// Ceiling of a multiprecision floating point value: \f$ \lceil x \rceil \f$ - /// \param val The value to take the ceiling of - /// \return The 
ceiling of the value - mpfr ceil(const mpfr &val); - - /// Floating point modulus of a multiprecision floating point value: \f$ x \bmod y \f$ - /// \param val The value to take the modulus of - /// \param mod The modulus to take the value by - /// \return The modulus of the value - mpfr mod(const mpfr &val, const mpfr &mod); - - /// Hypotenuse of a multiprecision floating point value: \f$ \sqrt{a^2 + b^2} \f$ - /// \param a The first value to take the hypotenuse of - /// \param b The second value to take the hypotenuse of - /// \return The hypotenuse of the values - mpfr hypot(const mpfr &a, const mpfr &b); - - /// Calculate and return \f$ \pi \f$ with LibRapid's current precision - /// \return \f$ \pi \f$ - /// \see prec - LIBRAPID_ALWAYS_INLINE mpfr constPi() { return ::mpfr::const_pi(); } - - /// Calculate and return \f$ \gamma \f$ with LibRapid's current precision, where \f$ \gamma \f$ - /// is the Euler-Mascheroni constant - /// \return \f$ \gamma \f$ - /// \see prec - LIBRAPID_ALWAYS_INLINE mpfr constEulerMascheroni() { return ::mpfr::const_euler(); } - - /// Calculate and return \f$ \log_e(2) \f$ with LibRapid's current precision - /// \return \f$ \log_e(2) \f$ - /// \see prec - LIBRAPID_ALWAYS_INLINE mpfr constLog2() { return ::mpfr::const_log2(); } - - /// Calculate and return Catalan's constant \f$ \gamma \f$ with LibRapid's current precision - /// \return \f$ \gamma \f$ - /// \see prec - LIBRAPID_ALWAYS_INLINE mpfr constCatalan() { return ::mpfr::const_catalan(); } - - /// Evaluates to true if the given type is a multiprecision value - /// \tparam T - template - struct IsMultiprecision : public std::false_type {}; - - template<> - struct IsMultiprecision : public std::true_type {}; - - template<> - struct IsMultiprecision : public std::true_type {}; - - template<> - struct IsMultiprecision : public std::true_type {}; - - template<> - struct IsMultiprecision : public std::true_type {}; - - /// Set the number of base 10 digits to store accurately - /// 
\param dig10 - inline void prec(int64_t dig10) { - int64_t dig2 = ::mpfr::digits2bits((int)dig10); - mpf_set_default_prec(dig2); - mpfr::mpreal::set_default_prec((mpfr_prec_t)dig2); - } - - /// Set the number of bits used to represent each number - /// \param dig2 - inline void prec2(int64_t dig2) { - mpf_set_default_prec(dig2); - mpfr::mpreal::set_default_prec((mpfr_prec_t)dig2); - } - - /// Returns true if the passed value is not a number (NaN) - /// Note: MPIR does not support NaN, so chances are it'll have errored already... - /// \tparam A - /// \tparam B - /// \param val The value to check - /// \return True if the value is NaN - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isNaN(const __gmp_expr &val) noexcept { - return false; - } - - /// Returns true if the passed value is not a number (NaN) - /// \param val The value to check - /// \return True if the value is NaN - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isNaN(const mpfr &val) noexcept { - return ::mpfr::isnan(val); - } - - /// Returns true if the passed value is finite. - /// Note: MPIR does not support Inf, so we can probably just return true - /// \tparam A - /// \tparam B - /// \param val The value to check - /// \return True if the value is finite - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isFinite(const __gmp_expr &val) noexcept { - return true; - } - - /// Returns true if the passed value is finite. - /// \param val The value to check - /// \return True if the value is finite - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isFinite(const mpfr &val) noexcept { - return ::mpfr::isfinite(val); - } - - /// Returns true if the passed value is infinite. 
- /// Note: MPIR does not support Inf, so we can probably just return false - /// \tparam A - /// \tparam B - /// \param val The value to check - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isInf(const __gmp_expr &val) noexcept { - return false; - } - - /// Returns true if the passed value is infinite. - /// \param val The value to check - /// \return True if the value is infinite - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isInf(const mpfr &val) noexcept { - return ::mpfr::isinf(val); - } - - /// Copy the sign of a value to another value - /// \param mag The magnitude of the returned value - /// \param sign The sign of the returned value - /// \return ( ) - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE mpfr copySign(const mpfr &mag, - const mpfr &sign) noexcept { - return ::mpfr::copysign(mag, sign); - } - - /// Copy the sign of a value to another value - /// \tparam A - /// \tparam B - /// \return ( ) - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE __gmp_expr - copySign(const __gmp_expr &mag, const __gmp_expr &sign) noexcept { - if (sign >= 0 && mag >= 0) return mag; - if (sign >= 0 && mag < 0) return -mag; - if (sign < 0 && mag >= 0) return -mag; - if (sign < 0 && mag < 0) return mag; - return 0; // Should never get here - } - - /// Extract the sign bit of a value - /// \tparam A - /// \tparam B - /// \param val The value to extract the sign bit from - /// \return The sign bit of the value - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const __gmp_expr &val) noexcept { - return val < 0 || val == -0.0; // I have no idea if this works - } - - /// Extract the sign bit of a value - /// \param val The value to extract the sign bit from - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const mpfr &val) noexcept { - return ::mpfr::signbit(val); - } - - /// Multiply a value by 2 raised to the power of an exponent - /// \return x * 2^exp - template<> - LIBRAPID_NODISCARD 
LIBRAPID_ALWAYS_INLINE mpfr ldexp(const mpfr &x, - const int64_t exp) noexcept { - return ::mpfr::ldexp(x, static_cast(exp)); - } - - /// Multiply a value by 2 raised to the power of an exponent - /// \tparam x The value - /// \tparam exp The exponent - /// \return x * 2^exp - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE __gmp_expr ldexp(const __gmp_expr &x, - const int64_t exp) noexcept { - return x << exp; - } - - namespace typetraits { - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = mpz; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "mpz"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = true; - static constexpr bool allowVectorisation = false; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64I; -# endif - - static constexpr bool canAlign = false; - static constexpr bool canMemcpy = false; - - LIMIT_IMPL(min) { return NUM_LIM(min); } - LIMIT_IMPL(max) { return NUM_LIM(max); } - LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = mpq; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "mpq"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - 
static constexpr bool allowVectorisation = false; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; -# endif - - static constexpr bool canAlign = false; - static constexpr bool canMemcpy = false; - - LIMIT_IMPL(min) { return NUM_LIM(min); } - LIMIT_IMPL(max) { return NUM_LIM(max); } - LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = mpf; - using Packet = std::false_type; - using Backend = backend::CPU; - static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "mpf"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = false; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; -# endif - - static constexpr bool canAlign = false; - static constexpr bool canMemcpy = false; - - LIMIT_IMPL(min) { return NUM_LIM(min); } - LIMIT_IMPL(max) { return NUM_LIM(max); } - LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - - template<> - struct TypeInfo { - static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; - using Scalar = mpfr; - using Packet = std::false_type; - using Backend = backend::CPU; - 
static constexpr int64_t packetWidth = 1; - static constexpr char name[] = "mpfr"; - static constexpr bool supportsArithmetic = true; - static constexpr bool supportsLogical = true; - static constexpr bool supportsBinary = false; - static constexpr bool allowVectorisation = false; - -# if defined(LIBRAPID_HAS_CUDA) - static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; -# endif - - static constexpr bool canAlign = false; - static constexpr bool canMemcpy = false; - - LIMIT_IMPL(min) { return NUM_LIM(min); } - LIMIT_IMPL(max) { return NUM_LIM(max); } - LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } - LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } - LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); } - LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } - LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } - LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } - }; - } // namespace typetraits + /// Multiprecision integer type + using mpz = mpz_class; + /// Multiprecision floating point type + using mpf = mpf_class; + /// Multiprecision rational type + using mpq = mpq_class; + /// Multiprecision floating point type with greater functionality + using mpfr = mpfr::mpreal; + + /// Convert a multiprecision integer type to a string with a given base + /// \param val The value to convert + /// \param base The base to convert to + /// \return The converted value + std::string str(const mpz &val, int64_t digits = -1, int base = 10); + + /// Convert a multiprecision floating point type to a string with a given base + /// \param val The value to convert + /// \param base The base to convert to + /// \return The converted value + std::string str(const mpf &val, int64_t digits = -1, int base = 10); + + /// Convert a multiprecision rational type to a string with a given base + /// \param val The value to convert + /// \param base The base to convert to + /// \return The converted value + std::string str(const mpq &val, int64_t digits = -1, int 
base = 10); + + /// Convert a multiprecision floating point type to a string with a given base + /// \param val The value to convert + /// \param base The base to convert to + /// \return The converted value + std::string str(const mpfr &val, int64_t digits = -1, int base = 10); + + /// Multiprecision integer to multiprecision integer cast + /// \param other The value to cast + /// \return The cast value + mpz toMpz(const mpz &other); + + /// Multiprecision floating point to multiprecision integer cast + /// \param other The value to cast + /// \return The cast value + mpz toMpz(const mpf &other); + + /// Multiprecision rational to multiprecision integer cast + /// \param other The value to cast + /// \return The cast value + mpz toMpz(const mpq &other); + + /// Multiprecision floating point to multiprecision integer cast + /// \param other The value to cast + /// \return The cast value + mpz toMpz(const mpfr &other); + + /// Multiprecision integer to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpf toMpf(const mpz &other); + + /// Multiprecision floating point to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpf toMpf(const mpf &other); + + /// Multiprecision rational to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpf toMpf(const mpq &other); + + /// Multiprecision floating point to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpf toMpf(const mpfr &other); + + /// Multiprecision integer to multiprecision rational cast + /// \param other The value to cast + /// \return The cast value + + mpq toMpq(const mpz &other); + + /// Multiprecision floating point to multiprecision rational cast + /// \param other The value to cast + /// \return The cast value + mpq toMpq(const mpf &other); + + /// Multiprecision rational to multiprecision rational 
cast + /// \param other The value to cast + /// \return The cast value + mpq toMpq(const mpq &other); + + /// Multiprecision floating point to multiprecision rational cast + /// \param other The value to cast + /// \return The cast value + mpq toMpq(const mpfr &other); + + /// Multiprecision integer to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpfr toMpfr(const mpz &other); + + /// Multiprecision floating point to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpfr toMpfr(const mpf &other); + + /// Multiprecision rational to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpfr toMpfr(const mpq &other); + + /// Multiprecision floating point to multiprecision floating point cast + /// \param other The value to cast + /// \return The cast value + mpfr toMpfr(const mpfr &other); + + // Trigonometric Functionality for mpf + + /// Sine of a multiprecision floating point value: \f$ \sin (x) \f$ + /// \param val The value to take the sine of + /// \return The sine of the value + mpfr sin(const mpfr &val); + + /// Cosine of a multiprecision floating point value: \f$ \cos (x) \f$ + /// \param val The value to take the cosine of + /// \return The cosine of the value + mpfr cos(const mpfr &val); + + /// Tangent of a multiprecision floating point value: \f$ \tan (x) \f$ + /// \param val The value to take the tangent of + /// \return The tangent of the value + mpfr tan(const mpfr &val); + + /// Arcsine of a multiprecision floating point value: \f$ \sin^{-1} (x) \f$ + /// \param val The value to take the arcsine of + /// \return The arcsine of the value + /// \see sin(const mpfr &) + mpfr asin(const mpfr &val); + + /// Arccosine of a multiprecision floating point value: \f$ \cos^{-1} (x) \f$ + /// \param val The value to take the arccosine of + /// \return The arccosine of the value + /// \see cos(const mpfr &) 
+ mpfr acos(const mpfr &val); + + /// Arctangent of a multiprecision floating point value: \f$ \tan^{-1} (x) \f$ + /// \param val The value to take the arctangent of + /// \return The arctangent of the value + /// \see tan(const mpfr &) + mpfr atan(const mpfr &val); + + /// Atan2 of a multiprecision floating point value: \f$ \tan^{-1}\left(rac{y}{x}\right) \f$ + /// \param dy The y value + /// \param dx The x value + /// \return The atan2 of the value + mpfr atan2(const mpfr &dy, const mpfr &dx); + + /// Cosec of a multiprecision floating point value: \f$ \csc (x) \f$ + /// \param val The value to take the cosec of + /// \return The cosec of the value + mpfr csc(const mpfr &val); + + /// Secant of a multiprecision floating point value: \f$ \sec (x) \f$ + /// \param val The value to take the secant of + /// \return The secant of the value + mpfr sec(const mpfr &val); + + /// Cotangent of a multiprecision floating point value: \f$ \cot (x) \f$ + /// \param val The value to take the cotangent of + /// \return The cotangent of the value + mpfr cot(const mpfr &val); + + /// Arccosec of a multiprecision floating point value: \f$ \csc^{-1} (x) \f$ + /// \param val The value to take the arccosec of + /// \return The arccosec of the value + mpfr acsc(const mpfr &val); + + /// Arcsecant of a multiprecision floating point value: \f$ \sec^{-1} (x) \f$ + /// \param val The value to take the arcsecant of + /// \return The arcsecant of the value + mpfr asec(const mpfr &val); + + /// Arccotangent of a multiprecision floating point value: \f$ \cot^{-1} (x) \f$ + /// \param val The value to take the arccotangent of + /// \return The arccotangent of the value + mpfr acot(const mpfr &val); + + // Hyperbolic Functionality for mpf + + /// Hyperbolic sine of a multiprecision floating point value: \f$ \sinh (x) \f$ + /// \param val The value to take the hyperbolic sine of + /// \return The hyperbolic sine of the value + mpfr sinh(const mpfr &val); + + /// Hyperbolic cosine of a 
multiprecision floating point value: \f$ \cosh (x) \f$ + /// \param val The value to take the hyperbolic cosine of + /// \return The hyperbolic cosine of the value + mpfr cosh(const mpfr &val); + + /// Hyperbolic tangent of a multiprecision floating point value: \f$ \tanh (x) \f$ + /// \param val The value to take the hyperbolic tangent of + /// \return The hyperbolic tangent of the value + mpfr tanh(const mpfr &val); + + /// Hyperbolic arcsine of a multiprecision floating point value: \f$ \sinh^{-1} (x) \f$ + /// \param val The value to take the hyperbolic arcsine of + /// \return The hyperbolic arcsine of the value + mpfr asinh(const mpfr &val); + + /// Hyperbolic arccosine of a multiprecision floating point value: \f$ \cosh^{-1} (x) \f$ + /// \param val The value to take the hyperbolic arccosine of + /// \return The hyperbolic arccosine of the value + mpfr acosh(const mpfr &val); + + /// Hyperbolic arctangent of a multiprecision floating point value: \f$ \tanh^{-1} (x) \f$ + /// \param val The value to take the hyperbolic arctangent of + /// \return The hyperbolic arctangent of the value + mpfr atanh(const mpfr &val); + + /// Hyperbolic cosec of a multiprecision floating point value: \f$ csch(x) \f$ + /// \param val The value to take the hyperbolic cosec of + /// \return The hyperbolic cosec of the value + mpfr csch(const mpfr &val); + + /// Hyperbolic secant of a multiprecision floating point value: \f$ sech(x) \f$ + /// \param val The value to take the hyperbolic secant of + /// \return The hyperbolic secant of the value + mpfr sech(const mpfr &val); + + /// Hyperbolic cotangent of a multiprecision floating point value: \f$ coth(x) \f$ + /// \param val The value to take the hyperbolic cotangent of + /// \return The hyperbolic cotangent of the value + mpfr coth(const mpfr &val); + + /// Hyperbolic arccosec of a multiprecision floating point value: \f$ csch^{-1}(x) \f$ + /// \param val The value to take the hyperbolic arccosec of + /// \return The hyperbolic 
arccosec of the value + mpfr acsch(const mpfr &val); + + /// Hyperbolic arcsecant of a multiprecision floating point value: \f$ sech^{-1}(x) \f$ + /// \param val The value to take the hyperbolic arcsecant of + /// \return The hyperbolic arcsecant of the value + mpfr asech(const mpfr &val); + + /// Hyperbolic arccotangent of a multiprecision floating point value: \f$ coth^{-1}(x) + /// \f$ \param val The value to take the hyperbolic arccotangent of \return The hyperbolic + /// \return arccotangent of the value + mpfr acoth(const mpfr &val); + + /// Absolute value of a multiprecision floating point value: \f$ |x| \f$ + /// \param val The value to take the absolute value of + /// \return Absolute value + mpfr abs(const mpfr &val); + + /// Return true if two values are close to each other + /// \tparam T The type of the tolerance + /// \param val1 The first value + /// \param val2 The second value + /// \param tolerance The tolerance + /// \return True if the values are close to each other + // template + // bool isClose(const mpfr &val1, const mpfr &val2, const T &tolerance = 1e-6) { + // return ::librapid::abs(val1 - val2) < tolerance; + // } + + /// Absolute value of a multiprecision integer value: \f$ |x| \f$ + /// \param val The value to take the absolute value of + /// \return Absolute value + mpz abs(const mpz &val); + + /// Absolute value of a multiprecision rational value: \f$ |x| \f$ + /// \param val The value to take the absolute value of + /// \return Absolute value + mpq abs(const mpq &val); + + /// Absolute value of a multiprecision floating point value: \f$ |x| \f$ + /// \param val The value to take the absolute value of + /// \return Absolute value + mpf abs(const mpf &val); + + /// Square root of a multiprecision floating point value: \f$ \sqrt{x} \f$ + /// \param val The value to take the square root of + /// \return The square root of the value + mpfr sqrt(const mpfr &val); + + /// Cube root of a multiprecision floating point value: \f$ \sqrt[3]{x} 
\f$ + /// \param val The value to take the cube root of + /// \return The cube root of the value + mpfr cbrt(const mpfr &val); + + /// Raise a multiprecision floating point value to a power: \f$ x^y \f$ + /// \param base The value to raise to a power + /// \param pow The power to raise the value to + mpfr pow(const mpfr &base, const mpfr &pow); + + /// Exponential of a multiprecision floating point value: \f$ e^x \f$ + /// \param val The value to take the exponential of + /// \return The exponential of the value + mpfr exp(const mpfr &val); + + /// Raise 2 to the power of a multiprecision floating point value: \f$ 2^x \f$ + /// \param val The value to raise 2 to the power of + /// \return 2 raised to the power of the value + mpfr exp2(const mpfr &val); + + /// Raise 10 to the power of a multiprecision floating point value: \f$ 10^x \f$ + /// \param val The value to raise 10 to the power of + /// \return 10 raised to the power of the value + mpfr exp10(const mpfr &val); + + /// ldexp of a multiprecision floating point value: \f$ x \times 2^exp \f$ + /// \param val The value to take the ldexp of + /// \param exponent The exponent to multiply the value by + /// \return The ldexp of the value + mpfr ldexp(const mpfr &val, int exponent); + + /// Logarithm of a multiprecision floating point value: \f$ \log (x) \f$ + /// \param val The value to take the logarithm of + /// \return The logarithm of the value + mpfr log(const mpfr &val); + + /// Logarithm of a multiprecision floating point value with a given base: \f$ \log_b (x) \f$ + /// \param val The value to take the logarithm of + /// \param base The base to take the logarithm with + /// \return The logarithm of the value with the given base + mpfr log(const mpfr &val, const mpfr &base); + + /// Logarithm of a multiprecision floating point value with base 2: \f$ \log_2 (x) \f$ + /// \param val The value to take the logarithm of + /// \return The logarithm of the value with base 2 + mpfr log2(const mpfr &val); + + /// 
Logarithm of a multiprecision floating point value with base 10: \f$ \log_{10} (x) \f$ + /// \param val The value to take the logarithm of + /// \return The logarithm of the value with base 10 + mpfr log10(const mpfr &val); + + /// Floor of a multiprecision floating point value: \f$ \lfloor x \rfloor \f$ + /// \param val The value to take the floor of + /// \return The floor of the value + mpfr floor(const mpfr &val); + + /// Ceiling of a multiprecision floating point value: \f$ \lceil x \rceil \f$ + /// \param val The value to take the ceiling of + /// \return The ceiling of the value + mpfr ceil(const mpfr &val); + + /// Floating point modulus of a multiprecision floating point value: \f$ x \bmod y \f$ + /// \param val The value to take the modulus of + /// \param mod The modulus to take the value by + /// \return The modulus of the value + mpfr mod(const mpfr &val, const mpfr &mod); + + /// Hypotenuse of a multiprecision floating point value: \f$ \sqrt{a^2 + b^2} \f$ + /// \param a The first value to take the hypotenuse of + /// \param b The second value to take the hypotenuse of + /// \return The hypotenuse of the values + mpfr hypot(const mpfr &a, const mpfr &b); + + /// Calculate and return \f$ \pi \f$ with LibRapid's current precision + /// \return \f$ \pi \f$ + /// \see prec + LIBRAPID_ALWAYS_INLINE mpfr constPi() { return ::mpfr::const_pi(); } + + /// Calculate and return \f$ \gamma \f$ with LibRapid's current precision, where \f$ \gamma \f$ + /// is the Euler-Mascheroni constant + /// \return \f$ \gamma \f$ + /// \see prec + LIBRAPID_ALWAYS_INLINE mpfr constEulerMascheroni() { return ::mpfr::const_euler(); } + + /// Calculate and return \f$ \log_e(2) \f$ with LibRapid's current precision + /// \return \f$ \log_e(2) \f$ + /// \see prec + LIBRAPID_ALWAYS_INLINE mpfr constLog2() { return ::mpfr::const_log2(); } + + /// Calculate and return Catalan's constant \f$ \gamma \f$ with LibRapid's current precision + /// \return \f$ \gamma \f$ + /// \see prec + 
LIBRAPID_ALWAYS_INLINE mpfr constCatalan() { return ::mpfr::const_catalan(); } + + /// Evaluates to true if the given type is a multiprecision value + /// \tparam T + template + struct IsMultiprecision : public std::false_type {}; + + template<> + struct IsMultiprecision : public std::true_type {}; + + template<> + struct IsMultiprecision : public std::true_type {}; + + template<> + struct IsMultiprecision : public std::true_type {}; + + template<> + struct IsMultiprecision : public std::true_type {}; + + /// Set the number of base 10 digits to store accurately + /// \param dig10 + inline void prec(int64_t dig10) { + int64_t dig2 = ::mpfr::digits2bits((int)dig10); + mpf_set_default_prec(dig2); + mpfr::mpreal::set_default_prec((mpfr_prec_t)dig2); + } + + /// Set the number of bits used to represent each number + /// \param dig2 + inline void prec2(int64_t dig2) { + mpf_set_default_prec(dig2); + mpfr::mpreal::set_default_prec((mpfr_prec_t)dig2); + } + + /// Returns true if the passed value is not a number (NaN) + /// Note: MPIR does not support NaN, so chances are it'll have errored already... + /// \tparam A + /// \tparam B + /// \param val The value to check + /// \return True if the value is NaN + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isNaN(const __gmp_expr &val) noexcept { + return false; + } + + /// Returns true if the passed value is not a number (NaN) + /// \param val The value to check + /// \return True if the value is NaN + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isNaN(const mpfr &val) noexcept { + return ::mpfr::isnan(val); + } + + /// Returns true if the passed value is finite. 
+ /// Note: MPIR does not support Inf, so we can probably just return true + /// \tparam A + /// \tparam B + /// \param val The value to check + /// \return True if the value is finite + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isFinite(const __gmp_expr &val) noexcept { + return true; + } + + /// Returns true if the passed value is finite. + /// \param val The value to check + /// \return True if the value is finite + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isFinite(const mpfr &val) noexcept { + return ::mpfr::isfinite(val); + } + + /// Returns true if the passed value is infinite. + /// Note: MPIR does not support Inf, so we can probably just return false + /// \tparam A + /// \tparam B + /// \param val The value to check + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isInf(const __gmp_expr &val) noexcept { + return false; + } + + /// Returns true if the passed value is infinite. + /// \param val The value to check + /// \return True if the value is infinite + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isInf(const mpfr &val) noexcept { + return ::mpfr::isinf(val); + } + + /// Copy the sign of a value to another value + /// \param mag The magnitude of the returned value + /// \param sign The sign of the returned value + /// \return ( ) + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE mpfr copySign(const mpfr &mag, + const mpfr &sign) noexcept { + return ::mpfr::copysign(mag, sign); + } + + /// Copy the sign of a value to another value + /// \tparam A + /// \tparam B + /// \return ( ) + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE __gmp_expr + copySign(const __gmp_expr &mag, const __gmp_expr &sign) noexcept { + if (sign >= 0 && mag >= 0) return mag; + if (sign >= 0 && mag < 0) return -mag; + if (sign < 0 && mag >= 0) return -mag; + if (sign < 0 && mag < 0) return mag; + return 0; // Should never get here + } + + /// Extract the sign bit of a value + /// \tparam A + /// \tparam B + /// 
\param val The value to extract the sign bit from + /// \return The sign bit of the value + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const __gmp_expr &val) noexcept { + return val < 0 || val == -0.0; // I have no idea if this works + } + + /// Extract the sign bit of a value + /// \param val The value to extract the sign bit from + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const mpfr &val) noexcept { + return ::mpfr::signbit(val); + } + + /// Multiply a value by 2 raised to the power of an exponent + /// \return x * 2^exp + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE mpfr ldexp(const mpfr &x, + const int64_t exp) noexcept { + return ::mpfr::ldexp(x, static_cast(exp)); + } + + /// Multiply a value by 2 raised to the power of an exponent + /// \tparam x The value + /// \tparam exp The exponent + /// \return x * 2^exp + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE __gmp_expr ldexp(const __gmp_expr &x, + const int64_t exp) noexcept { + return x << exp; + } + + namespace typetraits { + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = mpz; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "mpz"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = true; + static constexpr bool allowVectorisation = false; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64I; +# endif + + static constexpr bool canAlign = false; + static constexpr bool canMemcpy = false; + + LIMIT_IMPL(min) { return NUM_LIM(min); } + LIMIT_IMPL(max) { return NUM_LIM(max); } + LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); 
} + LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = mpq; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "mpq"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = false; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; +# endif + + static constexpr bool canAlign = false; + static constexpr bool canMemcpy = false; + + LIMIT_IMPL(min) { return NUM_LIM(min); } + LIMIT_IMPL(max) { return NUM_LIM(max); } + LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = mpf; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "mpf"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = false; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; +# endif + + static constexpr bool canAlign = false; + static constexpr bool canMemcpy = false; + + LIMIT_IMPL(min) { 
return NUM_LIM(min); } + LIMIT_IMPL(max) { return NUM_LIM(max); } + LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + + template<> + struct TypeInfo { + static constexpr detail::LibRapidType type = detail::LibRapidType::Scalar; + using Scalar = mpfr; + using Packet = std::false_type; + using Backend = backend::CPU; + static constexpr int64_t packetWidth = 1; + static constexpr char name[] = "mpfr"; + static constexpr bool supportsArithmetic = true; + static constexpr bool supportsLogical = true; + static constexpr bool supportsBinary = false; + static constexpr bool allowVectorisation = false; + +# if defined(LIBRAPID_HAS_CUDA) + static constexpr cudaDataType_t CudaType = cudaDataType_t::CUDA_R_64F; +# endif + + static constexpr bool canAlign = false; + static constexpr bool canMemcpy = false; + + LIMIT_IMPL(min) { return NUM_LIM(min); } + LIMIT_IMPL(max) { return NUM_LIM(max); } + LIMIT_IMPL(epsilon) { return NUM_LIM(epsilon); } + LIMIT_IMPL(roundError) { return NUM_LIM(round_error); } + LIMIT_IMPL(denormMin) { return NUM_LIM(denorm_min); } + LIMIT_IMPL(infinity) { return NUM_LIM(infinity); } + LIMIT_IMPL(quietNaN) { return NUM_LIM(quiet_NaN); } + LIMIT_IMPL(signalingNaN) { return NUM_LIM(signaling_NaN); } + }; + } // namespace typetraits } // namespace librapid // Provide {fmt} printing capabilities -# ifdef FMT_API +# ifdef FMT_API template<> struct fmt::formatter { - detail::dynamic_format_specs specs_; - - template - constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; - auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); - return end; - } - - template - inline auto format(const mpz_class &num, 
FormatContext &ctx) { - try { - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision < 0 ? 10 : specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } - } + detail::dynamic_format_specs specs_; + + template + constexpr auto parse(ParseContext &ctx) { + auto type = ::fmt::detail::type_constant::value; + auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + return end; + } + + template + inline auto format(const mpz_class &num, FormatContext &ctx) { + try { + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision < 0 ? 10 : specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), ss.str()); + } catch (std::exception &e) { + return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); + } + } }; template<> struct fmt::formatter { - detail::dynamic_format_specs specs_; - - template - constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; - auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); - return end; - } - - template - inline auto format(const mpf_class &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } - } + detail::dynamic_format_specs specs_; + + template + constexpr auto parse(ParseContext &ctx) { + auto type = ::fmt::detail::type_constant::value; + auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + return end; + } + + template + inline auto format(const mpf_class &num, FormatContext &ctx) { + try 
{ + if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), ss.str()); + } catch (std::exception &e) { + return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); + } + } }; template struct fmt::formatter<__gmp_expr> { - detail::dynamic_format_specs specs_; - - template - constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; - auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); - return end; - } - - template - inline auto format(const __gmp_expr &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } - } + detail::dynamic_format_specs specs_; + + template + constexpr auto parse(ParseContext &ctx) { + auto type = ::fmt::detail::type_constant::value; + auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + return end; + } + + template + inline auto format(const __gmp_expr &num, FormatContext &ctx) { + try { + if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), ss.str()); + } catch (std::exception &e) { + return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); + } + } }; template<> struct fmt::formatter { - detail::dynamic_format_specs specs_; - - template - constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; - auto end = 
::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); - return end; - } - - template - inline auto format(const mpq_class &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } - } + detail::dynamic_format_specs specs_; + + template + constexpr auto parse(ParseContext &ctx) { + auto type = ::fmt::detail::type_constant::value; + auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + return end; + } + + template + inline auto format(const mpq_class &num, FormatContext &ctx) { + try { + if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), ss.str()); + } catch (std::exception &e) { + return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); + } + } }; template<> struct fmt::formatter { - detail::dynamic_format_specs specs_; - - template - constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; - auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); - return end; - } - - template - inline auto format(const librapid::mpfr &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } - } + detail::dynamic_format_specs specs_; + + template + constexpr auto 
parse(ParseContext &ctx) { + auto type = ::fmt::detail::type_constant::value; + auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + return end; + } + + template + inline auto format(const librapid::mpfr &num, FormatContext &ctx) { + try { + if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), ss.str()); + } catch (std::exception &e) { + return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); + } + } }; -# endif // FMT_API +# endif // FMT_API -# if defined(SCN_SCN_H) +# if defined(SCN_SCN_H) namespace scn { - SCN_BEGIN_NAMESPACE - - template<> - struct scanner : public detail::string_scanner { - template - error scan(librapid::mpz &val, Context &ctx) { - if (set_parser.enabled()) { - bool loc = (common_options & localized) != 0; - bool mb = - (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && - detail::is_multichar_type(typename Context::char_type {}); - std::string tmp; - auto ret = do_scan(ctx, tmp, pred {ctx, set_parser, loc, mb}); - val = librapid::mpz(tmp); - return ret; - } - - auto e = skip_range_whitespace(ctx, false); - if (!e) { return e; } - - auto is_space_pred = detail::make_is_space_predicate( - ctx.locale(), (common_options & localized) != 0, field_width); - std::string tmp; - auto ret = do_scan(ctx, tmp, is_space_pred); - val = librapid::mpz(tmp); - return ret; - } - }; - - template<> - struct scanner : public detail::string_scanner { - template - error scan(librapid::mpf &val, Context &ctx) { - if (set_parser.enabled()) { - bool loc = (common_options & localized) != 0; - bool mb = - (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && - detail::is_multichar_type(typename Context::char_type {}); - std::string tmp; - auto ret = do_scan(ctx, tmp, pred {ctx, set_parser, loc, mb}); - val = 
librapid::mpf(tmp); - return ret; - } - - auto e = skip_range_whitespace(ctx, false); - if (!e) { return e; } - - auto is_space_pred = detail::make_is_space_predicate( - ctx.locale(), (common_options & localized) != 0, field_width); - std::string tmp; - auto ret = do_scan(ctx, tmp, is_space_pred); - val = librapid::mpf(tmp); - return ret; - } - }; - - template<> - struct scanner : public detail::string_scanner { - template - error scan(librapid::mpq &val, Context &ctx) { - if (set_parser.enabled()) { - bool loc = (common_options & localized) != 0; - bool mb = - (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && - detail::is_multichar_type(typename Context::char_type {}); - std::string tmp; - auto ret = do_scan(ctx, tmp, pred {ctx, set_parser, loc, mb}); - val = librapid::mpq(tmp); - return ret; - } - - auto e = skip_range_whitespace(ctx, false); - if (!e) { return e; } - - auto is_space_pred = detail::make_is_space_predicate( - ctx.locale(), (common_options & localized) != 0, field_width); - std::string tmp; - auto ret = do_scan(ctx, tmp, is_space_pred); - val = librapid::mpq(tmp); - return ret; - } - }; - - template<> - struct scanner : public detail::string_scanner { - template - error scan(librapid::mpfr &val, Context &ctx) { - if (set_parser.enabled()) { - bool loc = (common_options & localized) != 0; - bool mb = - (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && - detail::is_multichar_type(typename Context::char_type {}); - std::string tmp; - auto ret = do_scan(ctx, tmp, pred {ctx, set_parser, loc, mb}); - val = librapid::mpfr(tmp); - return ret; - } - - auto e = skip_range_whitespace(ctx, false); - if (!e) { return e; } - - auto is_space_pred = detail::make_is_space_predicate( - ctx.locale(), (common_options & localized) != 0, field_width); - std::string tmp; - auto ret = do_scan(ctx, tmp, is_space_pred); - val = librapid::mpfr(tmp); - return ret; - } - }; - - SCN_END_NAMESPACE + SCN_BEGIN_NAMESPACE + + 
template<> + struct scanner : public detail::string_scanner { + template + error scan(librapid::mpz &val, Context &ctx) { + if (set_parser.enabled()) { + bool loc = (common_options & localized) != 0; + bool mb = + (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && + detail::is_multichar_type(typename Context::char_type {}); + std::string tmp; + auto ret = do_scan(ctx, tmp, pred {ctx, set_parser, loc, mb}); + val = librapid::mpz(tmp); + return ret; + } + + auto e = skip_range_whitespace(ctx, false); + if (!e) { return e; } + + auto is_space_pred = detail::make_is_space_predicate( + ctx.locale(), (common_options & localized) != 0, field_width); + std::string tmp; + auto ret = do_scan(ctx, tmp, is_space_pred); + val = librapid::mpz(tmp); + return ret; + } + }; + + template<> + struct scanner : public detail::string_scanner { + template + error scan(librapid::mpf &val, Context &ctx) { + if (set_parser.enabled()) { + bool loc = (common_options & localized) != 0; + bool mb = + (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && + detail::is_multichar_type(typename Context::char_type {}); + std::string tmp; + auto ret = do_scan(ctx, tmp, pred {ctx, set_parser, loc, mb}); + val = librapid::mpf(tmp); + return ret; + } + + auto e = skip_range_whitespace(ctx, false); + if (!e) { return e; } + + auto is_space_pred = detail::make_is_space_predicate( + ctx.locale(), (common_options & localized) != 0, field_width); + std::string tmp; + auto ret = do_scan(ctx, tmp, is_space_pred); + val = librapid::mpf(tmp); + return ret; + } + }; + + template<> + struct scanner : public detail::string_scanner { + template + error scan(librapid::mpq &val, Context &ctx) { + if (set_parser.enabled()) { + bool loc = (common_options & localized) != 0; + bool mb = + (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && + detail::is_multichar_type(typename Context::char_type {}); + std::string tmp; + auto ret = do_scan(ctx, tmp, 
pred {ctx, set_parser, loc, mb}); + val = librapid::mpq(tmp); + return ret; + } + + auto e = skip_range_whitespace(ctx, false); + if (!e) { return e; } + + auto is_space_pred = detail::make_is_space_predicate( + ctx.locale(), (common_options & localized) != 0, field_width); + std::string tmp; + auto ret = do_scan(ctx, tmp, is_space_pred); + val = librapid::mpq(tmp); + return ret; + } + }; + + template<> + struct scanner : public detail::string_scanner { + template + error scan(librapid::mpfr &val, Context &ctx) { + if (set_parser.enabled()) { + bool loc = (common_options & localized) != 0; + bool mb = + (loc || set_parser.get_option(detail::set_parser_type::flag::use_ranges)) && + detail::is_multichar_type(typename Context::char_type {}); + std::string tmp; + auto ret = do_scan(ctx, tmp, pred {ctx, set_parser, loc, mb}); + val = librapid::mpfr(tmp); + return ret; + } + + auto e = skip_range_whitespace(ctx, false); + if (!e) { return e; } + + auto is_space_pred = detail::make_is_space_predicate( + ctx.locale(), (common_options & localized) != 0, field_width); + std::string tmp; + auto ret = do_scan(ctx, tmp, is_space_pred); + val = librapid::mpfr(tmp); + return ret; + } + }; + + SCN_END_NAMESPACE } // namespace scn -# endif // SCN_SCN_H -#endif // LIBRAPID_USE_MULTIPREC +# endif // SCN_SCN_H +#endif // LIBRAPID_USE_MULTIPREC -#endif // LIBRAPID_MATH_MULTIPREC_HPP \ No newline at end of file +#endif // LIBRAPID_MATH_MULTIPREC_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/random.hpp b/librapid/include/librapid/math/random.hpp index 88265e94..7b29136b 100644 --- a/librapid/include/librapid/math/random.hpp +++ b/librapid/include/librapid/math/random.hpp @@ -2,74 +2,74 @@ #define LIBRAPID_MATH_RANDOM_HPP namespace librapid { - template - LIBRAPID_NODISCARD LIBRAPID_INLINE T random(T lower = 0, T upper = 1) { - // Random floating point value in range [lower, upper) + template + LIBRAPID_NODISCARD LIBRAPID_INLINE T random(T lower = 0, T upper 
= 1) { + // Random floating point value in range [lower, upper) - static std::uniform_real_distribution distribution(0., 1.); - static std::mt19937 generator((uint32_t) global::randomSeed); + static std::uniform_real_distribution distribution(0., 1.); + static std::mt19937 generator((uint32_t)global::randomSeed); - if (global::reseed) { - generator.seed((uint32_t) global::randomSeed); - global::reseed = false; - } + if (global::reseed) { + generator.seed((uint32_t)global::randomSeed); + global::reseed = false; + } - return (T)(lower + (upper - lower) * distribution(generator)); - } + return (T)(lower + (upper - lower) * distribution(generator)); + } - LIBRAPID_NODISCARD LIBRAPID_INLINE int64_t randint(int64_t lower, int64_t upper) { - // Random integral value in range [lower, upper] - return (int64_t)random((double)(lower - (lower < 0 ? 1 : 0)), (double)upper + 1); - } + LIBRAPID_NODISCARD LIBRAPID_INLINE int64_t randint(int64_t lower, int64_t upper) { + // Random integral value in range [lower, upper] + return (int64_t)random((double)(lower - (lower < 0 ? 
1 : 0)), (double)upper + 1); + } - LIBRAPID_NODISCARD LIBRAPID_INLINE double trueRandomEntropy() { - static std::random_device rd; - return rd.entropy(); - } + LIBRAPID_NODISCARD LIBRAPID_INLINE double trueRandomEntropy() { + static std::random_device rd; + return rd.entropy(); + } - template - LIBRAPID_NODISCARD LIBRAPID_INLINE double trueRandom(T lower = 0, T upper = 1) { - // Truly random value in range [lower, upper) - static std::random_device rd; - std::uniform_real_distribution dist((double)lower, (double)upper); - return dist(rd); - } + template + LIBRAPID_NODISCARD LIBRAPID_INLINE double trueRandom(T lower = 0, T upper = 1) { + // Truly random value in range [lower, upper) + static std::random_device rd; + std::uniform_real_distribution dist((double)lower, (double)upper); + return dist(rd); + } - LIBRAPID_NODISCARD LIBRAPID_INLINE int64_t trueRandint(int64_t lower, int64_t upper) { - // Truly random value in range [lower, upper) - return (int64_t)trueRandom((double)(lower - (lower < 0 ? 1 : 0)), (double)upper + 1); - } + LIBRAPID_NODISCARD LIBRAPID_INLINE int64_t trueRandint(int64_t lower, int64_t upper) { + // Truly random value in range [lower, upper) + return (int64_t)trueRandom((double)(lower - (lower < 0 ? 
1 : 0)), (double)upper + 1); + } - /** - * Adapted from - * https://docs.oracle.com/javase/6/docs/api/java/util/Random.html#nextGaussian() - */ - template - LIBRAPID_NODISCARD LIBRAPID_INLINE double randomGaussian() { - static double nextGaussian; - static bool hasNextGaussian = false; + /** + * Adapted from + * https://docs.oracle.com/javase/6/docs/api/java/util/Random.html#nextGaussian() + */ + template + LIBRAPID_NODISCARD LIBRAPID_INLINE double randomGaussian() { + static double nextGaussian; + static bool hasNextGaussian = false; - double res; - if (hasNextGaussian) { - hasNextGaussian = false; - res = nextGaussian; - } else { - double v1; - double v2; - double s; - do { - v1 = random(-1, 1); // between -1.0 and 1.0 - v2 = random(-1, 1); // between -1.0 and 1.0 - s = v1 * v1 + v2 * v2; - } while (s >= 1 || s == 0); - double multiplier = sqrt(-2 * ::librapid::log(s) / s); - nextGaussian = v2 * multiplier; - hasNextGaussian = true; - res = v1 * multiplier; - } + double res; + if (hasNextGaussian) { + hasNextGaussian = false; + res = nextGaussian; + } else { + double v1; + double v2; + double s; + do { + v1 = random(-1, 1); // between -1.0 and 1.0 + v2 = random(-1, 1); // between -1.0 and 1.0 + s = v1 * v1 + v2 * v2; + } while (s >= 1 || s == 0); + double multiplier = sqrt(-2 * ::librapid::log(s) / s); + nextGaussian = v2 * multiplier; + hasNextGaussian = true; + res = v1 * multiplier; + } - return static_cast(res); - } + return static_cast(res); + } } // namespace librapid #endif // LIBRAPID_MATH_RANDOM_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/round.hpp b/librapid/include/librapid/math/round.hpp index 7e474b32..2b6837a0 100644 --- a/librapid/include/librapid/math/round.hpp +++ b/librapid/include/librapid/math/round.hpp @@ -2,142 +2,142 @@ #define LIBRAPID_MATH_ROUND_HPP namespace librapid { - enum class RoundingMode { - // Rounding Mode Information: - // Bit mask: - // [0] -> Round up if difference >= 0.5 - // [1] -> Round up 
if difference < 0.5 - // [2] -> Round to nearest even - // [3] -> Round to nearest odd - // [4] -> Round only if difference == 0.5 - - UP = 0b00000011, - DOWN = 0b00000000, - TRUNC = 0b00000000, - HALF_EVEN = 0b00010100, - MATH = 0b00000001, - }; - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto round(T num, int64_t dp = 0, - RoundingMode mode = RoundingMode::MATH) { - using Scalar = typename typetraits::TypeInfo::Scalar; - - int8_t mode_ = static_cast(mode); - const double alpha = fastmath::pow10(dp); - const double beta = fastmath::pow10(-dp); - const double absNum = ::librapid::abs(static_cast(num) * alpha); - double y = ::librapid::floor(absNum); - double diff = absNum - y; - if (mode_ & (1 << 0) && diff >= 0.5) y += 1; - if (mode_ & (1 << 2)) { - auto integer = (uint64_t)y; - auto nearestEven = (integer & 1) ? (y + 1) : (double)integer; - if (mode_ & (1 << 4) && diff == 0.5) y = nearestEven; - } - - return static_cast(::librapid::copySign(y * beta, num)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto round(const Complex &num, int64_t dp = 0, - RoundingMode mode = RoundingMode::MATH) { - return Complex(round(real(num), dp, mode), round(imag(num), dp, mode)); - } + enum class RoundingMode { + // Rounding Mode Information: + // Bit mask: + // [0] -> Round up if difference >= 0.5 + // [1] -> Round up if difference < 0.5 + // [2] -> Round to nearest even + // [3] -> Round to nearest odd + // [4] -> Round only if difference == 0.5 + + UP = 0b00000011, + DOWN = 0b00000000, + TRUNC = 0b00000000, + HALF_EVEN = 0b00010100, + MATH = 0b00000001, + }; + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto round(T num, int64_t dp = 0, + RoundingMode mode = RoundingMode::MATH) { + using Scalar = typename typetraits::TypeInfo::Scalar; + + int8_t mode_ = static_cast(mode); + const double alpha = fastmath::pow10(dp); + const double beta = fastmath::pow10(-dp); + const double absNum = ::librapid::abs(static_cast(num) * alpha); + double y = 
::librapid::floor(absNum); + double diff = absNum - y; + if (mode_ & (1 << 0) && diff >= 0.5) y += 1; + if (mode_ & (1 << 2)) { + auto integer = (uint64_t)y; + auto nearestEven = (integer & 1) ? (y + 1) : (double)integer; + if (mode_ & (1 << 4) && diff == 0.5) y = nearestEven; + } + + return static_cast(::librapid::copySign(y * beta, num)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto round(const Complex &num, int64_t dp = 0, + RoundingMode mode = RoundingMode::MATH) { + return Complex(round(real(num), dp, mode), round(imag(num), dp, mode)); + } #if defined(LIBRAPID_USE_MULTIPREC) - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto round(const mpfr &num, int64_t dp, - RoundingMode mode) { - using Scalar = mpfr; - int8_t mode_ = static_cast(mode); - const Scalar alpha = ::librapid::exp10(mpfr(dp)); - const Scalar beta = ::librapid::exp10(mpfr(-dp)); - const Scalar absNum = ::librapid::abs(num * alpha); - Scalar y = ::librapid::floor(absNum); - Scalar diff = absNum - y; - if (mode_ & (1 << 0) && diff >= 0.5) y += 1; - if (mode_ & (1 << 2)) { - auto integer = (uint64_t)y; - auto nearestEven = (integer & 1) ? (y + 1) : (Scalar)integer; - if (mode_ & (1 << 4) && diff == 0.5) y = nearestEven; - } - return (num >= 0 ? y : -y) * beta; - } + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto round(const mpfr &num, int64_t dp, + RoundingMode mode) { + using Scalar = mpfr; + int8_t mode_ = static_cast(mode); + const Scalar alpha = ::librapid::exp10(mpfr(dp)); + const Scalar beta = ::librapid::exp10(mpfr(-dp)); + const Scalar absNum = ::librapid::abs(num * alpha); + Scalar y = ::librapid::floor(absNum); + Scalar diff = absNum - y; + if (mode_ & (1 << 0) && diff >= 0.5) y += 1; + if (mode_ & (1 << 2)) { + auto integer = (uint64_t)y; + auto nearestEven = (integer & 1) ? (y + 1) : (Scalar)integer; + if (mode_ & (1 << 4) && diff == 0.5) y = nearestEven; + } + return (num >= 0 ? 
y : -y) * beta; + } #endif - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T2 roundTo(T1 num, T2 val) { - if (num == static_cast(0)) return 0; - T2 rem = ::librapid::mod(::librapid::abs(static_cast(num)), val); - if (rem >= val / static_cast(2)) - return ::librapid::copySign((::librapid::abs(static_cast(num)) + val) - rem, num); - return ::librapid::copySign(static_cast(num) - rem, num); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundTo(const Complex &num, T2 val) { - return Complex(roundTo(real(num), val), roundTo(imag(num), val)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundTo(const Complex &num, - const Complex &val) { - return Complex(roundTo(real(num), real(val)), roundTo(imag(num), imag(val))); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T2 roundUpTo(T1 num, T2 val) { - T2 rem = ::librapid::mod(T2(num), val); - if (rem == T2(0)) return static_cast(num); - return (static_cast(num) + val) - rem; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundUpTo(const Complex &num, - T2 val) { - return Complex(roundUpTo(real(num), val), roundUpTo(imag(num), val)); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundUpTo(const Complex &num, - const Complex &val) { - return Complex(roundUpTo(real(num), real(val)), roundUpTo(imag(num), imag(val))); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T roundSigFig(T num, int64_t figs = 3) { - LIBRAPID_ASSERT(figs > 0, - "Cannot round to {} significant figures. 
Value must be greater than zero", - figs); - - using Scalar = std::conditional_t, double, T>; - - if (num == static_cast(0)) return static_cast(0); - - auto tmp = ::librapid::abs(static_cast(num)); - int64_t n = 0; - - const auto ten = static_cast(10); - const auto one = static_cast(1); - while (tmp > ten) { - tmp /= ten; - ++n; - } - - while (tmp < one) { - tmp *= ten; - --n; - } - - return ::librapid::copySign(static_cast(round(tmp, figs - 1) * fastmath::pow10(n)), num); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundSigFig(const Complex &num, - int64_t figs = 3) { - return Complex(roundSigFig(real(num), figs), roundSigFig(imag(num), figs)); - } + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T2 roundTo(T1 num, T2 val) { + if (num == static_cast(0)) return 0; + T2 rem = ::librapid::mod(::librapid::abs(static_cast(num)), val); + if (rem >= val / static_cast(2)) + return ::librapid::copySign((::librapid::abs(static_cast(num)) + val) - rem, num); + return ::librapid::copySign(static_cast(num) - rem, num); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundTo(const Complex &num, T2 val) { + return Complex(roundTo(real(num), val), roundTo(imag(num), val)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundTo(const Complex &num, + const Complex &val) { + return Complex(roundTo(real(num), real(val)), roundTo(imag(num), imag(val))); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T2 roundUpTo(T1 num, T2 val) { + T2 rem = ::librapid::mod(T2(num), val); + if (rem == T2(0)) return static_cast(num); + return (static_cast(num) + val) - rem; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundUpTo(const Complex &num, + T2 val) { + return Complex(roundUpTo(real(num), val), roundUpTo(imag(num), val)); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundUpTo(const Complex &num, + const Complex &val) { + return Complex(roundUpTo(real(num), 
real(val)), roundUpTo(imag(num), imag(val))); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T roundSigFig(T num, int64_t figs = 3) { + LIBRAPID_ASSERT(figs > 0, + "Cannot round to {} significant figures. Value must be greater than zero", + figs); + + using Scalar = std::conditional_t, double, T>; + + if (num == static_cast(0)) return static_cast(0); + + auto tmp = ::librapid::abs(static_cast(num)); + int64_t n = 0; + + const auto ten = static_cast(10); + const auto one = static_cast(1); + while (tmp > ten) { + tmp /= ten; + ++n; + } + + while (tmp < one) { + tmp *= ten; + --n; + } + + return ::librapid::copySign(static_cast(round(tmp, figs - 1) * fastmath::pow10(n)), num); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Complex roundSigFig(const Complex &num, + int64_t figs = 3) { + return Complex(roundSigFig(real(num), figs), roundSigFig(imag(num), figs)); + } } // namespace librapid #endif // LIBRAPID_MATH_ROUND_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/utilityFunctions.hpp b/librapid/include/librapid/math/utilityFunctions.hpp index e810243e..b0fd2afc 100644 --- a/librapid/include/librapid/math/utilityFunctions.hpp +++ b/librapid/include/librapid/math/utilityFunctions.hpp @@ -2,157 +2,157 @@ #define LIBRAPID_MATH_UTLIITY_FUNCTIONS_HPP namespace librapid { - /// \brief Limit a value to a specified range - /// - /// \f$ C(x, m, M) = \left\{ \begin{align*} x & \quad m \le x \le M \\ m & \quad x < m \\ M & - /// \quad x > M \end{align*}\right. \f$ - /// - /// If M < m, the values are swapped to make the function valid. - /// For example, `clamp(5, 10, 0)` still returns `5`. 
- /// - /// \tparam X Type of \p x - /// \tparam Lower Type of \p lowerLimit - /// \tparam Upper Type of \p upperLimit - /// \param x Value to limit - /// \param lowerLimit Lower bound (m) - /// \param upperLimit Upper bound (M) - /// \return \p x limited to the range [\p lowerLimit, \p upperLimit] - template::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar, - int> = 0, - typename ST = typetraits::ScalarReturnType> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST clamp(X x, Lower lowerLimit, Upper upperLimit) { - LIBRAPID_ASSERT(lowerLimit < upperLimit, "Lower limit must be below upper limit"); - if (x < lowerLimit) return static_cast(lowerLimit); - if (x > upperLimit) return static_cast(upperLimit); - return x; - } + /// \brief Limit a value to a specified range + /// + /// \f$ C(x, m, M) = \left\{ \begin{align*} x & \quad m \le x \le M \\ m & \quad x < m \\ M & + /// \quad x > M \end{align*}\right. \f$ + /// + /// If M < m, the values are swapped to make the function valid. + /// For example, `clamp(5, 10, 0)` still returns `5`. 
+ /// + /// \tparam X Type of \p x + /// \tparam Lower Type of \p lowerLimit + /// \tparam Upper Type of \p upperLimit + /// \param x Value to limit + /// \param lowerLimit Lower bound (m) + /// \param upperLimit Upper bound (M) + /// \return \p x limited to the range [\p lowerLimit, \p upperLimit] + template::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar, + int> = 0, + typename ST = typetraits::ScalarReturnType> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST clamp(X x, Lower lowerLimit, Upper upperLimit) { + LIBRAPID_ASSERT(lowerLimit < upperLimit, "Lower limit must be below upper limit"); + if (x < lowerLimit) return static_cast(lowerLimit); + if (x > upperLimit) return static_cast(upperLimit); + return x; + } - /// \brief Linearly interpolate between two values - /// - /// \f$ \mathrm{lerp}(t, L, U) = L+t\left( U-L \right) \f$ - /// - /// \tparam T Type of \p t - /// \tparam Lower Type of \p lower - /// \tparam Upper Type of \p upper - /// \param t Interpolation Percentage - /// \param lower Lower bound (L) - /// \param upper Upper bound (U) - /// \return - template::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar && - std::is_floating_point_v && std::is_floating_point_v && - std::is_floating_point_v, - int> = 0, - typename ST = typetraits::ScalarReturnType> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST lerp(T t, Lower lower, Upper upper) { - if (isNaN(t) || isNaN(lower) || isNaN(upper)) - return std::numeric_limits::quiet_NaN(); - else if ((t <= ST {0} && upper >= Upper {0}) || (lower >= Lower {0} && upper <= Upper {0})) - // ab <= 0 but product could overflow. 
+ /// \brief Linearly interpolate between two values + /// + /// \f$ \mathrm{lerp}(t, L, U) = L+t\left( U-L \right) \f$ + /// + /// \tparam T Type of \p t + /// \tparam Lower Type of \p lower + /// \tparam Upper Type of \p upper + /// \param t Interpolation Percentage + /// \param lower Lower bound (L) + /// \param upper Upper bound (U) + /// \return + template::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar && + std::is_floating_point_v && std::is_floating_point_v && + std::is_floating_point_v, + int> = 0, + typename ST = typetraits::ScalarReturnType> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST lerp(T t, Lower lower, Upper upper) { + if (isNaN(t) || isNaN(lower) || isNaN(upper)) + return std::numeric_limits::quiet_NaN(); + else if ((t <= ST {0} && upper >= Upper {0}) || (lower >= Lower {0} && upper <= Upper {0})) + // ab <= 0 but product could overflow. #ifndef FMA - return t * upper + (ST {1} - t) * lower; + return t * upper + (ST {1} - t) * lower; #else - return std::fma(t, upper, (_Float {1} - t) * upper); + return std::fma(t, upper, (_Float {1} - t) * upper); #endif - else if (t == ST {1}) - return upper; - else { // monotonic near t == 1. + else if (t == ST {1}) + return upper; + else { // monotonic near t == 1. #ifndef FMA - const auto x = lower + t * (upper - lower); + const auto x = lower + t * (upper - lower); #else - const auto x = std::fma(t, upper - lower, lower); + const auto x = std::fma(t, upper - lower, lower); #endif - return (t > ST {1}) == (upper > lower) ? max(upper, x) : min(upper, x); - } - } + return (t > ST {1}) == (upper > lower) ? max(upper, x) : min(upper, x); + } + } - /// \brief Linearly interpolate between two values - /// - /// \f$ \mathrm{lerp}(t, L, U) = L+t\left( U-L \right) \f$. The result is clamped to the - /// specified range. 
- /// - /// \tparam T Type of \p t - /// \tparam Lower Type of \p lower - /// \tparam Upper Type of \p upper - /// \param t Interpolation Percentage - /// \param lower Lower bound (L) - /// \param upper Upper bound (U) - /// \return - template::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar && - !std::is_floating_point_v || - !std::is_floating_point_v || !std::is_floating_point_v, - int> = 0, - typename ST = typetraits::ScalarReturnType> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST lerp(T t, Lower lower, Upper upper) { - if (isNaN(t) || isNaN(lower) || isNaN(upper)) return std::numeric_limits::quiet_NaN(); - return static_cast(lower) + (static_cast(upper) - static_cast(lower)) * t; - } + /// \brief Linearly interpolate between two values + /// + /// \f$ \mathrm{lerp}(t, L, U) = L+t\left( U-L \right) \f$. The result is clamped to the + /// specified range. + /// + /// \tparam T Type of \p t + /// \tparam Lower Type of \p lower + /// \tparam Upper Type of \p upper + /// \param t Interpolation Percentage + /// \param lower Lower bound (L) + /// \param upper Upper bound (U) + /// \return + template::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar && + !std::is_floating_point_v || + !std::is_floating_point_v || !std::is_floating_point_v, + int> = 0, + typename ST = typetraits::ScalarReturnType> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST lerp(T t, Lower lower, Upper upper) { + if (isNaN(t) || isNaN(lower) || isNaN(upper)) return std::numeric_limits::quiet_NaN(); + return static_cast(lower) + (static_cast(upper) - static_cast(lower)) * t; + } - /// \brief Smoothly interpolate between two values - /// - /// This smooth step implementation is based on Ken Perlin's algorithm. 
- /// \f$ S(x)= \begin{cases}0 & x \leq 0 \\ 6 x^5-15 x^4+10 x^3 & 0 \leq x \leq 1 \\ 1 & 1 \leq - /// x\end{cases} \f$ - /// - /// This function allows you to specify a lower and upper edge, which can be used to scale - /// the range of inputs. - /// - /// \tparam T Type of \p t - /// \tparam Lower Type of \p lowerEdge - /// \tparam Upper Type of \p upperEdge - /// \param t Value to smooth step - /// \param lowerEdge At t=lowerEdge, the function returns 0 - /// \param upperEdge At t=upperEdge, the function returns 1 - /// \return \p t interpolated between \p lowerEdge and \p upperEdge - template::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar && - typetraits::TypeInfo::type == detail::LibRapidType::Scalar, - int> = 0, - typename ST = typetraits::ScalarReturnType> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST smoothStep(T t, Lower lowerEdge = 0, - Upper upperEdge = 1) { - ST tt = clamp((t - lowerEdge) / (upperEdge - lowerEdge), 0.0, 1.0); - return tt * tt * tt * (tt * (tt * T(6) - ST(15)) + ST(10)); - } + /// \brief Smoothly interpolate between two values + /// + /// This smooth step implementation is based on Ken Perlin's algorithm. + /// \f$ S(x)= \begin{cases}0 & x \leq 0 \\ 6 x^5-15 x^4+10 x^3 & 0 \leq x \leq 1 \\ 1 & 1 \leq + /// x\end{cases} \f$ + /// + /// This function allows you to specify a lower and upper edge, which can be used to scale + /// the range of inputs. 
+ /// + /// \tparam T Type of \p t + /// \tparam Lower Type of \p lowerEdge + /// \tparam Upper Type of \p upperEdge + /// \param t Value to smooth step + /// \param lowerEdge At t=lowerEdge, the function returns 0 + /// \param upperEdge At t=upperEdge, the function returns 1 + /// \return \p t interpolated between \p lowerEdge and \p upperEdge + template::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar && + typetraits::TypeInfo::type == detail::LibRapidType::Scalar, + int> = 0, + typename ST = typetraits::ScalarReturnType> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ST smoothStep(T t, Lower lowerEdge = 0, + Upper upperEdge = 1) { + ST tt = clamp((t - lowerEdge) / (upperEdge - lowerEdge), 0.0, 1.0); + return tt * tt * tt * (tt * (tt * T(6) - ST(15)) + ST(10)); + } - /// \brief Compare the absolute and relative difference between two values, and return true if - /// they are close enough to be considered equal. - /// - /// \f$ \left| x-y \right| \leq \max\left( \mathrm{absTol}, \mathrm{relTol} \cdot \max\left( - /// \left| x \right|, \left| y \right| \right) \right) \f$ - /// - /// This is more precise than using an absolute tolerance alone, since it also takes into - /// account the magnitude of the values being compared. 
- /// - /// \tparam V1 Data type of the first value - /// \tparam V2 Data type of the second value - /// \tparam T Data type of the tolerance value - /// \tparam T Data type of the tolerance value - /// \param val1 First value - /// \param val2 Second value - /// \param absTol Absolute tolerance - /// \param relTol Relative tolerance - /// \return True if values are close - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool - isClose(const V1 &val1, const V2 &val2, const T &absTol = 1e-5, const T &relTol = 1e-5) { - return ::librapid::abs(val2 - val1) <= - ::librapid::max( - relTol * ::librapid::max(::librapid::abs(val1), ::librapid::abs(val2)), absTol); - } + /// \brief Compare the absolute and relative difference between two values, and return true if + /// they are close enough to be considered equal. + /// + /// \f$ \left| x-y \right| \leq \max\left( \mathrm{absTol}, \mathrm{relTol} \cdot \max\left( + /// \left| x \right|, \left| y \right| \right) \right) \f$ + /// + /// This is more precise than using an absolute tolerance alone, since it also takes into + /// account the magnitude of the values being compared. 
+ /// + /// \tparam V1 Data type of the first value + /// \tparam V2 Data type of the second value + /// \tparam T Data type of the tolerance value + /// \tparam T Data type of the tolerance value + /// \param val1 First value + /// \param val2 Second value + /// \param absTol Absolute tolerance + /// \param relTol Relative tolerance + /// \return True if values are close + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool + isClose(const V1 &val1, const V2 &val2, const T &absTol = 1e-5, const T &relTol = 1e-5) { + return ::librapid::abs(val2 - val1) <= + ::librapid::max( + relTol * ::librapid::max(::librapid::abs(val1), ::librapid::abs(val2)), absTol); + } } // namespace librapid #endif // LIBRAPID_MATH_UTLIITY_FUNCTIONS_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/vector.hpp b/librapid/include/librapid/math/vector.hpp index b696bcc8..8dc25aac 100644 --- a/librapid/include/librapid/math/vector.hpp +++ b/librapid/include/librapid/math/vector.hpp @@ -5,19 +5,19 @@ #include "vectorImpl.hpp" namespace librapid { - using Vec2i = Vector; - using Vec3i = Vector; - using Vec4i = Vector; - using Vec2f = Vector; - using Vec3f = Vector; - using Vec4f = Vector; - using Vec2d = Vector; - using Vec3d = Vector; - using Vec4d = Vector; + using Vec2i = Vector; + using Vec3i = Vector; + using Vec4i = Vector; + using Vec2f = Vector; + using Vec3f = Vector; + using Vec4f = Vector; + using Vec2d = Vector; + using Vec3d = Vector; + using Vec4d = Vector; - using Vec2 = Vector; - using Vec3 = Vector; - using Vec4 = Vector; + using Vec2 = Vector; + using Vec3 = Vector; + using Vec4 = Vector; } // namespace librapid #endif // LIBRAPID_MATH_VECTOR_OLD_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/vectorForward.hpp b/librapid/include/librapid/math/vectorForward.hpp index cacc514a..3d829f86 100644 --- a/librapid/include/librapid/math/vectorForward.hpp +++ b/librapid/include/librapid/math/vectorForward.hpp @@ -2,124 +2,124 @@ 
#define LIBRAPID_MATH_VECTOR_FORWARD_HPP namespace librapid { - namespace vectorDetail { - template - struct GenericVectorStorage; - - template - struct SimdVectorStorage; - - template - struct VectorStorageType { - using type = std::conditional_t<(typetraits::TypeInfo::packetWidth > 1), - SimdVectorStorage, GenericVectorStorage>; - }; - - template - auto vectorStorageTypeMerger() { - using Scalar0 = typename typetraits::TypeInfo::Scalar; - using Scalar1 = typename typetraits::TypeInfo::Scalar; - static constexpr size_t packetWidth0 = typetraits::TypeInfo::packetWidth; - static constexpr size_t packetWidth1 = typetraits::TypeInfo::packetWidth; - if constexpr (packetWidth0 > 1 && packetWidth1 > 1) { - return SimdVectorStorage {}; - } else { - return GenericVectorStorage {}; - } - } - - template - using VectorStorage = typename VectorStorageType::type; - - template - using VectorStorageMerger = decltype(vectorStorageTypeMerger()); - - template - class VectorBase { - public: - using Scalar = typename typetraits::TypeInfo::Scalar; - using IndexType = typename typetraits::TypeInfo::IndexType; - using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; - using GetType = typename typetraits::TypeInfo::GetType; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto &derived() const { - return static_cast(*this); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &derived() { - return static_cast(*this); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const { return derived(); } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE virtual IndexTypeConst - operator[](int64_t index) const { - return derived()[index]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE virtual IndexType operator[](int64_t index) { - return derived()[index]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst x() const { - return derived()[0]; - } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst y() const { - return derived()[1]; - } - 
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst z() const { - return derived()[2]; - } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst w() const { - return derived()[3]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType x() { return derived()[0]; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType y() { return derived()[1]; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType z() { return derived()[2]; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType w() { return derived()[3]; } - - LIBRAPID_NODISCARD virtual std::string str(const std::string &format) const { - return derived().str(format); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE virtual GetType _get(size_t index) const { - return derived()._get(index); - } - }; - } // namespace vectorDetail - - template - class Vector; - - namespace vectorDetail { - template - struct BinaryVecOp; - - template - struct UnaryVecOp; - - template - LIBRAPID_ALWAYS_INLINE void assignImpl(Vector &dst, - const BinaryVecOp &src, - std::index_sequence); - - template - LIBRAPID_ALWAYS_INLINE void assignImpl(Vector &dst, - const UnaryVecOp &src, - std::index_sequence); - - template - LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, - const BinaryVecOp &src); - - template - LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, const UnaryVecOp &src); - } // namespace vectorDetail - - template - class Vector; + namespace vectorDetail { + template + struct GenericVectorStorage; + + template + struct SimdVectorStorage; + + template + struct VectorStorageType { + using type = std::conditional_t<(typetraits::TypeInfo::packetWidth > 1), + SimdVectorStorage, GenericVectorStorage>; + }; + + template + auto vectorStorageTypeMerger() { + using Scalar0 = typename typetraits::TypeInfo::Scalar; + using Scalar1 = typename typetraits::TypeInfo::Scalar; + static constexpr size_t packetWidth0 = typetraits::TypeInfo::packetWidth; + static constexpr size_t packetWidth1 = typetraits::TypeInfo::packetWidth; + if 
constexpr (packetWidth0 > 1 && packetWidth1 > 1) { + return SimdVectorStorage {}; + } else { + return GenericVectorStorage {}; + } + } + + template + using VectorStorage = typename VectorStorageType::type; + + template + using VectorStorageMerger = decltype(vectorStorageTypeMerger()); + + template + class VectorBase { + public: + using Scalar = typename typetraits::TypeInfo::Scalar; + using IndexType = typename typetraits::TypeInfo::IndexType; + using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; + using GetType = typename typetraits::TypeInfo::GetType; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const auto &derived() const { + return static_cast(*this); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto &derived() { + return static_cast(*this); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto eval() const { return derived(); } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE virtual IndexTypeConst + operator[](int64_t index) const { + return derived()[index]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE virtual IndexType operator[](int64_t index) { + return derived()[index]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst x() const { + return derived()[0]; + } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst y() const { + return derived()[1]; + } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst z() const { + return derived()[2]; + } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst w() const { + return derived()[3]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType x() { return derived()[0]; } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType y() { return derived()[1]; } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType z() { return derived()[2]; } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType w() { return derived()[3]; } + + LIBRAPID_NODISCARD virtual std::string str(const std::string &format) const { + return derived().str(format); + } + + LIBRAPID_NODISCARD 
LIBRAPID_ALWAYS_INLINE virtual GetType _get(size_t index) const { + return derived()._get(index); + } + }; + } // namespace vectorDetail + + template + class Vector; + + namespace vectorDetail { + template + struct BinaryVecOp; + + template + struct UnaryVecOp; + + template + LIBRAPID_ALWAYS_INLINE void assignImpl(Vector &dst, + const BinaryVecOp &src, + std::index_sequence); + + template + LIBRAPID_ALWAYS_INLINE void assignImpl(Vector &dst, + const UnaryVecOp &src, + std::index_sequence); + + template + LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, + const BinaryVecOp &src); + + template + LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, const UnaryVecOp &src); + } // namespace vectorDetail + + template + class Vector; } // namespace librapid #endif // LIBRAPID_MATH_VECTOR_FORWARD_HPP \ No newline at end of file diff --git a/librapid/include/librapid/math/vectorImpl.hpp b/librapid/include/librapid/math/vectorImpl.hpp index a53f2dea..0c4986b0 100644 --- a/librapid/include/librapid/math/vectorImpl.hpp +++ b/librapid/include/librapid/math/vectorImpl.hpp @@ -4,979 +4,979 @@ #include "../simd/simd.hpp" // Required for SIMD operations namespace librapid { - namespace typetraits { - template - struct IsVector : std::false_type {}; - - template - struct IsVector> : std::true_type {}; - - template - struct IsVector> : std::true_type {}; - - template - struct IsVector> : std::true_type {}; - - template - struct IsVector> : std::true_type {}; - - template - struct IsVector> : std::true_type {}; - - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; - using Scalar = T; - using IndexType = T &; - using IndexTypeConst = const T &; - using GetType = const T &; - - using StorageType = vectorDetail::GenericVectorStorage; - - static constexpr size_t length = N; - }; - - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; - using Scalar = T; - using Packet = typename 
TypeInfo::Packet; - using IndexType = Scalar &; - using IndexTypeConst = const Scalar &; - using GetType = const Packet &; - - using StorageType = vectorDetail::SimdVectorStorage; - - static constexpr size_t packetWidth = TypeInfo::packetWidth; - static constexpr size_t length = - (N + TypeInfo::packetWidth - 1) / TypeInfo::packetWidth; - }; - - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; - using Scalar = ScalarType; - static constexpr size_t dims = NumDims; - using StorageType = vectorDetail::VectorStorage; - static constexpr size_t length = StorageType::length; - using IndexTypeConst = typename StorageType::IndexTypeConst; - using IndexType = typename StorageType::IndexType; - using GetType = typename StorageType::GetType; - }; - - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; - using ScalarLHS = typename typetraits::TypeInfo::Scalar; - using ScalarRHS = typename typetraits::TypeInfo::Scalar; - using Scalar = decltype(Op()(std::declval(), std::declval())); - using IndexTypeConst = Scalar; - using IndexType = Scalar; - using StorageType = typename vectorDetail::VectorStorageMerger; - static constexpr size_t dims = StorageType::dims; - static constexpr size_t length = StorageType::length; - using GetType = typename std::decay_t; - }; - - template - struct TypeInfo> { - static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; - using Scalar = typename typetraits::TypeInfo::Scalar; - using IndexTypeConst = Scalar; - using IndexType = Scalar; - using StorageType = typename vectorDetail::VectorStorage; - static constexpr size_t dims = StorageType::dims; - static constexpr size_t length = StorageType::length; - using GetType = typename std::decay_t; - }; - } // namespace typetraits - - namespace vectorDetail { - template - void vectorStorageAssigner(std::index_sequence, GenericVectorStorage &dst, - const Args &...args) { - 
((dst[Indices] = args), ...); - } - - template - void vectorStorageAssigner(std::index_sequence, SimdVectorStorage &dst, - const Args &...args) { - ((dst[Indices] = args), ...); - } - - template - void vectorStorageAssigner(std::index_sequence, GenericVectorStorage &dst, - const GenericVectorStorage &src) { - ((dst[Indices] = src[Indices]), ...); - } - - template - void vectorStorageAssigner(std::index_sequence, GenericVectorStorage &dst, - const SimdVectorStorage &src) { - ((dst[Indices] = src[Indices]), ...); - } - - template - void vectorStorageAssigner(std::index_sequence, SimdVectorStorage &dst, - const GenericVectorStorage &src) { - ((dst[Indices] = src[Indices]), ...); - } - - template - void vectorStorageAssigner(std::index_sequence, SimdVectorStorage &dst, - const SimdVectorStorage &src) { - ((dst[Indices] = src[Indices]), ...); - } - - template - struct GenericVectorStorage { - using Scalar = ScalarType; - static constexpr size_t dims = NumDims; - static constexpr size_t length = typetraits::TypeInfo::length; - using IndexType = typename typetraits::TypeInfo::IndexType; - using IndexTypeConst = - typename typetraits::TypeInfo::IndexTypeConst; - using GetType = typename typetraits::TypeInfo::GetType; - - // Scalar data[length] {}; - // std::array data {}; + namespace typetraits { + template + struct IsVector : std::false_type {}; + + template + struct IsVector> : std::true_type {}; + + template + struct IsVector> : std::true_type {}; + + template + struct IsVector> : std::true_type {}; + + template + struct IsVector> : std::true_type {}; + + template + struct IsVector> : std::true_type {}; + + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; + using Scalar = T; + using IndexType = T &; + using IndexTypeConst = const T &; + using GetType = const T &; + + using StorageType = vectorDetail::GenericVectorStorage; + + static constexpr size_t length = N; + }; + + template + struct TypeInfo> { + static 
constexpr detail::LibRapidType type = detail::LibRapidType::Vector; + using Scalar = T; + using Packet = typename TypeInfo::Packet; + using IndexType = Scalar &; + using IndexTypeConst = const Scalar &; + using GetType = const Packet &; + + using StorageType = vectorDetail::SimdVectorStorage; + + static constexpr size_t packetWidth = TypeInfo::packetWidth; + static constexpr size_t length = + (N + TypeInfo::packetWidth - 1) / TypeInfo::packetWidth; + }; + + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; + using Scalar = ScalarType; + static constexpr size_t dims = NumDims; + using StorageType = vectorDetail::VectorStorage; + static constexpr size_t length = StorageType::length; + using IndexTypeConst = typename StorageType::IndexTypeConst; + using IndexType = typename StorageType::IndexType; + using GetType = typename StorageType::GetType; + }; + + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; + using ScalarLHS = typename typetraits::TypeInfo::Scalar; + using ScalarRHS = typename typetraits::TypeInfo::Scalar; + using Scalar = decltype(Op()(std::declval(), std::declval())); + using IndexTypeConst = Scalar; + using IndexType = Scalar; + using StorageType = typename vectorDetail::VectorStorageMerger; + static constexpr size_t dims = StorageType::dims; + static constexpr size_t length = StorageType::length; + using GetType = typename std::decay_t; + }; + + template + struct TypeInfo> { + static constexpr detail::LibRapidType type = detail::LibRapidType::Vector; + using Scalar = typename typetraits::TypeInfo::Scalar; + using IndexTypeConst = Scalar; + using IndexType = Scalar; + using StorageType = typename vectorDetail::VectorStorage; + static constexpr size_t dims = StorageType::dims; + static constexpr size_t length = StorageType::length; + using GetType = typename std::decay_t; + }; + } // namespace typetraits + + namespace vectorDetail { + 
template + void vectorStorageAssigner(std::index_sequence, GenericVectorStorage &dst, + const Args &...args) { + ((dst[Indices] = args), ...); + } + + template + void vectorStorageAssigner(std::index_sequence, SimdVectorStorage &dst, + const Args &...args) { + ((dst[Indices] = args), ...); + } + + template + void vectorStorageAssigner(std::index_sequence, GenericVectorStorage &dst, + const GenericVectorStorage &src) { + ((dst[Indices] = src[Indices]), ...); + } + + template + void vectorStorageAssigner(std::index_sequence, GenericVectorStorage &dst, + const SimdVectorStorage &src) { + ((dst[Indices] = src[Indices]), ...); + } + + template + void vectorStorageAssigner(std::index_sequence, SimdVectorStorage &dst, + const GenericVectorStorage &src) { + ((dst[Indices] = src[Indices]), ...); + } + + template + void vectorStorageAssigner(std::index_sequence, SimdVectorStorage &dst, + const SimdVectorStorage &src) { + ((dst[Indices] = src[Indices]), ...); + } + + template + struct GenericVectorStorage { + using Scalar = ScalarType; + static constexpr size_t dims = NumDims; + static constexpr size_t length = typetraits::TypeInfo::length; + using IndexType = typename typetraits::TypeInfo::IndexType; + using IndexTypeConst = + typename typetraits::TypeInfo::IndexTypeConst; + using GetType = typename typetraits::TypeInfo::GetType; + + // Scalar data[length] {}; + // std::array data {}; #if defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) - alignas(LIBRAPID_DEFAULT_MEM_ALIGN) std::array data {}; + alignas(LIBRAPID_DEFAULT_MEM_ALIGN) std::array data {}; #else - // No memory alignment on Apple platforms or if it is disabled - std::array data {}; + // No memory alignment on Apple platforms or if it is disabled + std::array data {}; #endif - template - GenericVectorStorage(Args... 
args) : data {args...} {} - - template - GenericVectorStorage(const T &other) { - for (size_t i = 0; i < length; ++i) { data[i] = other[i]; } - } - - template - GenericVectorStorage(const std::initializer_list &other) { - LIBRAPID_ASSERT(other.size() <= dims, - "Initializer list for Vector is too long ({} > {})", - other.size(), - dims); - const size_t minDims = (other.size() < dims) ? other.size() : dims; - for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = *(other.begin() + i); } - } - - template - GenericVectorStorage(const std::vector &other) { - LIBRAPID_ASSERT(other.size() <= dims, - "Initializer list for Vector is too long ({} > {})", - other.size(), - dims); - const size_t minDims = (other.size() < dims) ? other.size() : dims; - for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = other[i]; } - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst - operator[](int64_t index) const { - LIBRAPID_ASSERT(index >= 0 && index < dims, - "Index {} out of bounds for Vector of length {}", - index, - length); - return data[index]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) { - LIBRAPID_ASSERT(index >= 0 && index < dims, - "Index {} out of bounds for Vector of length {}", - index, - length); - return data[index]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar sum() const { - Scalar sum = Scalar(0); - for (size_t i = 0; i < dims; ++i) { sum += data[i]; } - return sum; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar sum2() const { - Scalar sum = Scalar(0); - for (size_t i = 0; i < dims; ++i) { sum += data[i] * data[i]; } - return sum; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const Scalar &_get(size_t index) const { - LIBRAPID_ASSERT(index >= 0 && index < dims, - "Index {} out of bounds for Vector of length {}", - index, - length); - return data[index]; - } - - LIBRAPID_ALWAYS_INLINE void _set(size_t index, const Scalar &value) { - LIBRAPID_ASSERT(index >= 0 && index < dims, 
- "Index {} out of bounds for Vector of length {}", - index, - length); - data[index] = value; - } - }; - - template - struct SimdVectorStorage { - using Scalar = ScalarType; - static constexpr size_t dims = NumDims; - using Packet = typename typetraits::TypeInfo::Packet; - static constexpr size_t packetWidth = typetraits::TypeInfo::packetWidth; - static constexpr size_t length = (dims + packetWidth - 1) / packetWidth; - - using IndexType = typename typetraits::TypeInfo::IndexType; - using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; - using GetType = typename typetraits::TypeInfo::GetType; - - static_assert(typetraits::TypeInfo::packetWidth > 1, - "SimdVectorStorage can only be used with SIMD types"); + template + GenericVectorStorage(Args... args) : data {args...} {} + + template + GenericVectorStorage(const T &other) { + for (size_t i = 0; i < length; ++i) { data[i] = other[i]; } + } + + template + GenericVectorStorage(const std::initializer_list &other) { + LIBRAPID_ASSERT(other.size() <= dims, + "Initializer list for Vector is too long ({} > {})", + other.size(), + dims); + const size_t minDims = (other.size() < dims) ? other.size() : dims; + for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = *(other.begin() + i); } + } + + template + GenericVectorStorage(const std::vector &other) { + LIBRAPID_ASSERT(other.size() <= dims, + "Initializer list for Vector is too long ({} > {})", + other.size(), + dims); + const size_t minDims = (other.size() < dims) ? 
other.size() : dims; + for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = other[i]; } + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst + operator[](int64_t index) const { + LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + return data[index]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) { + LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + return data[index]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar sum() const { + Scalar sum = Scalar(0); + for (size_t i = 0; i < dims; ++i) { sum += data[i]; } + return sum; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Scalar sum2() const { + Scalar sum = Scalar(0); + for (size_t i = 0; i < dims; ++i) { sum += data[i] * data[i]; } + return sum; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const Scalar &_get(size_t index) const { + LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + return data[index]; + } + + LIBRAPID_ALWAYS_INLINE void _set(size_t index, const Scalar &value) { + LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + data[index] = value; + } + }; + + template + struct SimdVectorStorage { + using Scalar = ScalarType; + static constexpr size_t dims = NumDims; + using Packet = typename typetraits::TypeInfo::Packet; + static constexpr size_t packetWidth = typetraits::TypeInfo::packetWidth; + static constexpr size_t length = (dims + packetWidth - 1) / packetWidth; + + using IndexType = typename typetraits::TypeInfo::IndexType; + using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; + using GetType = typename typetraits::TypeInfo::GetType; + + static_assert(typetraits::TypeInfo::packetWidth > 1, + "SimdVectorStorage can only be used with SIMD types"); #if 
defined(LIBRAPID_NATIVE_ARCH) && !defined(LIBRAPID_APPLE) - alignas(LIBRAPID_DEFAULT_MEM_ALIGN) std::array data {}; + alignas(LIBRAPID_DEFAULT_MEM_ALIGN) std::array data {}; #else - // No memory alignment on Apple platforms or if it is disabled - std::array data {}; + // No memory alignment on Apple platforms or if it is disabled + std::array data {}; #endif - template - explicit SimdVectorStorage(Args... args) { - constexpr size_t minLength = (sizeof...(Args) < dims) ? sizeof...(Args) : dims; - vectorDetail::vectorStorageAssigner( - std::make_index_sequence(), *this, args...); - } - - template - SimdVectorStorage(const std::initializer_list &other) { - LIBRAPID_ASSERT(other.size() <= dims, - "Initializer list for Vector is too long ({} > {})", - other.size(), - dims); - const size_t minDims = (other.size() < dims) ? other.size() : dims; - for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = *(other.begin() + i); } - } - - template - SimdVectorStorage(const std::vector &other) { - LIBRAPID_ASSERT(other.size() <= dims, - "Initializer list for Vector is too long ({} > {})", - other.size(), - dims); - const size_t minDims = (other.size() < dims) ? 
other.size() : dims; - for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = other[i]; } - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst - operator[](int64_t index) const { - LIBRAPID_ASSERT(index >= 0 && index < dims, - "Index {} out of bounds for Vector of length {}", - index, - length); - const int64_t packetIndex = index / packetWidth; - const int64_t elementIndex = index % packetWidth; - return data[packetIndex].get(elementIndex); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) { - LIBRAPID_ASSERT(index >= 0 && index < dims, - "Index {} out of bounds for Vector of length {}", - index, - length); - const int64_t packetIndex = index / packetWidth; - const int64_t elementIndex = index % packetWidth; - return data[packetIndex][elementIndex]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum() const -> Scalar { - Packet sum = Packet(0); - for (size_t i = 0; i < length; ++i) { sum += data[i]; } - return sum.sum(); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum2() const -> Scalar { - Packet sum = Packet(0); - for (size_t i = 0; i < length; ++i) { sum += data[i] * data[i]; } - return sum.sum(); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const Packet &_get(size_t index) const { - LIBRAPID_ASSERT(index >= 0 && index < dims, - "Index {} out of bounds for Vector of length {}", - index, - length); - return data[index]; - } - - LIBRAPID_ALWAYS_INLINE void _set(size_t index, const Packet &value) { - LIBRAPID_ASSERT(index >= 0 && index < dims, - "Index {} out of bounds for Vector of length {}", - index, - length); - data[index] = value; - } - }; - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto - scalarSubscriptHelper(const T &val, size_t index) { - return val; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto - scalarSubscriptHelper(const Vector &val, size_t index) { - return val[index]; - } - - template - LIBRAPID_NODISCARD 
LIBRAPID_ALWAYS_INLINE constexpr auto - scalarSubscriptHelper(const BinaryVecOp &val, size_t index) { - return val[index]; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto - scalarSubscriptHelper(const UnaryVecOp &val, size_t index) { - return val[index]; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto scalarGetHelper(const T &val, - size_t index) { - return val; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper( - const Vector &val, size_t index) { - return val._get(index); - } - - template - LIBRAPID_NODISCARD - LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper(Vector &val, - size_t index) { - return val._get(index); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper( - const BinaryVecOp &val, size_t index) { - return val._get(index); - } - - template - LIBRAPID_NODISCARD - LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper(const UnaryVecOp &val, - size_t index) { - return val._get(index); - } - - template - struct VectorScalarStorageExtractor { - using type = std::false_type; - }; - - template - struct VectorScalarStorageExtractor> { - using type = typename typetraits::TypeInfo>::StorageType; - }; - - template - struct VectorScalarStorageExtractor> { - using type = typename typetraits::TypeInfo>::StorageType; - }; - - template - struct VectorScalarStorageExtractor> { - using type = typename typetraits::TypeInfo>::StorageType; - }; - - template - struct VectorScalarDimensionExtractor { - static constexpr size_t value = 0; - }; - - template - struct VectorScalarDimensionExtractor> { - static constexpr size_t value = NumDims; - }; - - template - struct VectorScalarDimensionExtractor> { - static constexpr size_t value = BinaryVecOp::dims; - }; - - template - struct VectorScalarDimensionExtractor> { - static constexpr size_t value = UnaryVecOp::dims; - }; - } // namespace vectorDetail - - template - class Vector : public 
vectorDetail::VectorBase> { - public: - using Scalar = ScalarType; - static constexpr size_t dims = NumDims; - using StorageType = vectorDetail::VectorStorage; - static constexpr size_t length = StorageType::length; - using IndexTypeConst = typename StorageType::IndexTypeConst; - using IndexType = typename StorageType::IndexType; - using GetType = typename StorageType::GetType; - - Vector() = default; - Vector(const Vector &other) = default; - Vector(Vector &&other) noexcept = default; - - template - explicit Vector(Args... args) : m_data {args...} {} - - template - Vector(const std::initializer_list &args) : m_data(args) {} - - template - explicit Vector(const std::vector &args) : m_data(args) {} - - template - explicit Vector(const Vector &other) { - *this = other.template cast(); - } - - template - explicit Vector(const vectorDetail::BinaryVecOp &other) { - vectorDetail::assign(*this, other); - } - - template - explicit Vector(const vectorDetail::UnaryVecOp &other) { - vectorDetail::assign(*this, other); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto zero() -> Vector { return Vector(); } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto one() -> Vector { - Vector ret; - for (size_t i = 0; i < dims; ++i) { ret[i] = Scalar(1); } - return ret; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto full(Scalar val) -> Vector { - Vector ret; - for (size_t i = 0; i < dims; ++i) { ret[i] = val; } - return ret; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto random(Scalar lower = 0, - Scalar upper = 1) { - Vector ret; - for (size_t i = 0; i < dims; ++i) { ret[i] = ::librapid::random(lower, upper); } - return ret; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto - random(const Vector &lower = Vector::zero(), const Vector &upper = Vector::one()) { - Vector ret; - for (size_t i = 0; i < dims; ++i) { - ret[i] = ::librapid::random(lower[i], upper[i]); - } - return ret; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE 
static auto fromPolar(Scalar r, Scalar theta) { - return Vector(::librapid::cos(theta) * r, ::librapid::sin(theta) * r); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto fromPolar(Scalar r, Scalar theta, - Scalar phi) { - return Vector(::librapid::cos(theta) * ::librapid::cos(phi) * r, - ::librapid::sin(theta) * ::librapid::cos(phi) * r, - ::librapid::sin(phi) * r); - } - - auto operator=(const Vector &other) -> Vector & = default; - auto operator=(Vector &&other) noexcept -> Vector & = default; - - template - auto operator=(const Vector &other) -> Vector & { - *this = other.template cast(); - return *this; - } - - template - auto operator=(const vectorDetail::BinaryVecOp &other) -> Vector & { - vectorDetail::assign(*this, other); - return *this; - } - - template - auto operator=(const vectorDetail::UnaryVecOp &other) -> Vector & { - vectorDetail::assign(*this, other); - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst - operator[](int64_t index) const override { - return m_data[index]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) override { - return m_data[index]; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Vector eval() const { return *this; } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cast() const { - using NewVectorType = Vector; - constexpr size_t minDims = (NewVectorType::dims < dims) ? NewVectorType::dims : dims; - NewVectorType ret; - vectorDetail::vectorStorageAssigner( - std::make_index_sequence(), ret.storage(), m_data); - return ret; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator Vector() const { - return cast(); - } + template + explicit SimdVectorStorage(Args... args) { + constexpr size_t minLength = (sizeof...(Args) < dims) ? 
sizeof...(Args) : dims; + vectorDetail::vectorStorageAssigner( + std::make_index_sequence(), *this, args...); + } + + template + SimdVectorStorage(const std::initializer_list &other) { + LIBRAPID_ASSERT(other.size() <= dims, + "Initializer list for Vector is too long ({} > {})", + other.size(), + dims); + const size_t minDims = (other.size() < dims) ? other.size() : dims; + for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = *(other.begin() + i); } + } + + template + SimdVectorStorage(const std::vector &other) { + LIBRAPID_ASSERT(other.size() <= dims, + "Initializer list for Vector is too long ({} > {})", + other.size(), + dims); + const size_t minDims = (other.size() < dims) ? other.size() : dims; + for (size_t i = 0; i < minDims; ++i) { this->operator[](i) = other[i]; } + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst + operator[](int64_t index) const { + LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + const int64_t packetIndex = index / packetWidth; + const int64_t elementIndex = index % packetWidth; + return data[packetIndex].get(elementIndex); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) { + LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + const int64_t packetIndex = index / packetWidth; + const int64_t elementIndex = index % packetWidth; + return data[packetIndex][elementIndex]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum() const -> Scalar { + Packet sum = Packet(0); + for (size_t i = 0; i < length; ++i) { sum += data[i]; } + return sum.sum(); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum2() const -> Scalar { + Packet sum = Packet(0); + for (size_t i = 0; i < length; ++i) { sum += data[i] * data[i]; } + return sum.sum(); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const Packet &_get(size_t index) const { + 
LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + return data[index]; + } + + LIBRAPID_ALWAYS_INLINE void _set(size_t index, const Packet &value) { + LIBRAPID_ASSERT(index >= 0 && index < dims, + "Index {} out of bounds for Vector of length {}", + index, + length); + data[index] = value; + } + }; + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto + scalarSubscriptHelper(const T &val, size_t index) { + return val; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto + scalarSubscriptHelper(const Vector &val, size_t index) { + return val[index]; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto + scalarSubscriptHelper(const BinaryVecOp &val, size_t index) { + return val[index]; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto + scalarSubscriptHelper(const UnaryVecOp &val, size_t index) { + return val[index]; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr auto scalarGetHelper(const T &val, + size_t index) { + return val; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper( + const Vector &val, size_t index) { + return val._get(index); + } + + template + LIBRAPID_NODISCARD + LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper(Vector &val, + size_t index) { + return val._get(index); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper( + const BinaryVecOp &val, size_t index) { + return val._get(index); + } + + template + LIBRAPID_NODISCARD + LIBRAPID_ALWAYS_INLINE auto constexpr scalarGetHelper(const UnaryVecOp &val, + size_t index) { + return val._get(index); + } + + template + struct VectorScalarStorageExtractor { + using type = std::false_type; + }; + + template + struct VectorScalarStorageExtractor> { + using type = typename typetraits::TypeInfo>::StorageType; + }; + + template + struct 
VectorScalarStorageExtractor> { + using type = typename typetraits::TypeInfo>::StorageType; + }; + + template + struct VectorScalarStorageExtractor> { + using type = typename typetraits::TypeInfo>::StorageType; + }; + + template + struct VectorScalarDimensionExtractor { + static constexpr size_t value = 0; + }; + + template + struct VectorScalarDimensionExtractor> { + static constexpr size_t value = NumDims; + }; + + template + struct VectorScalarDimensionExtractor> { + static constexpr size_t value = BinaryVecOp::dims; + }; + + template + struct VectorScalarDimensionExtractor> { + static constexpr size_t value = UnaryVecOp::dims; + }; + } // namespace vectorDetail + + template + class Vector : public vectorDetail::VectorBase> { + public: + using Scalar = ScalarType; + static constexpr size_t dims = NumDims; + using StorageType = vectorDetail::VectorStorage; + static constexpr size_t length = StorageType::length; + using IndexTypeConst = typename StorageType::IndexTypeConst; + using IndexType = typename StorageType::IndexType; + using GetType = typename StorageType::GetType; + + Vector() = default; + Vector(const Vector &other) = default; + Vector(Vector &&other) noexcept = default; + + template + explicit Vector(Args... 
args) : m_data {args...} {} + + template + Vector(const std::initializer_list &args) : m_data(args) {} + + template + explicit Vector(const std::vector &args) : m_data(args) {} + + template + explicit Vector(const Vector &other) { + *this = other.template cast(); + } + + template + explicit Vector(const vectorDetail::BinaryVecOp &other) { + vectorDetail::assign(*this, other); + } + + template + explicit Vector(const vectorDetail::UnaryVecOp &other) { + vectorDetail::assign(*this, other); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto zero() -> Vector { return Vector(); } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto one() -> Vector { + Vector ret; + for (size_t i = 0; i < dims; ++i) { ret[i] = Scalar(1); } + return ret; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto full(Scalar val) -> Vector { + Vector ret; + for (size_t i = 0; i < dims; ++i) { ret[i] = val; } + return ret; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto random(Scalar lower = 0, + Scalar upper = 1) { + Vector ret; + for (size_t i = 0; i < dims; ++i) { ret[i] = ::librapid::random(lower, upper); } + return ret; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto + random(const Vector &lower = Vector::zero(), const Vector &upper = Vector::one()) { + Vector ret; + for (size_t i = 0; i < dims; ++i) { + ret[i] = ::librapid::random(lower[i], upper[i]); + } + return ret; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto fromPolar(Scalar r, Scalar theta) { + return Vector(::librapid::cos(theta) * r, ::librapid::sin(theta) * r); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE static auto fromPolar(Scalar r, Scalar theta, + Scalar phi) { + return Vector(::librapid::cos(theta) * ::librapid::cos(phi) * r, + ::librapid::sin(theta) * ::librapid::cos(phi) * r, + ::librapid::sin(phi) * r); + } + + auto operator=(const Vector &other) -> Vector & = default; + auto operator=(Vector &&other) noexcept -> Vector & = default; + + template + 
auto operator=(const Vector &other) -> Vector & { + *this = other.template cast(); + return *this; + } + + template + auto operator=(const vectorDetail::BinaryVecOp &other) -> Vector & { + vectorDetail::assign(*this, other); + return *this; + } + + template + auto operator=(const vectorDetail::UnaryVecOp &other) -> Vector & { + vectorDetail::assign(*this, other); + return *this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexTypeConst + operator[](int64_t index) const override { + return m_data[index]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) override { + return m_data[index]; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Vector eval() const { return *this; } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cast() const { + using NewVectorType = Vector; + constexpr size_t minDims = (NewVectorType::dims < dims) ? NewVectorType::dims : dims; + NewVectorType ret; + vectorDetail::vectorStorageAssigner( + std::make_index_sequence(), ret.storage(), m_data); + return ret; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator Vector() const { + return cast(); + } #define LIBRAPID_VECTOR_INPLACE_OP(OP_) \ - template \ - LIBRAPID_ALWAYS_INLINE Vector &operator OP_##=(const Other &other) { \ - return *this = *this OP_ other; \ - } - - LIBRAPID_VECTOR_INPLACE_OP(+) - LIBRAPID_VECTOR_INPLACE_OP(-) - LIBRAPID_VECTOR_INPLACE_OP(*) - LIBRAPID_VECTOR_INPLACE_OP(/) - LIBRAPID_VECTOR_INPLACE_OP(%) - LIBRAPID_VECTOR_INPLACE_OP(&) - LIBRAPID_VECTOR_INPLACE_OP(|) - LIBRAPID_VECTOR_INPLACE_OP(^) - LIBRAPID_VECTOR_INPLACE_OP(<<) - LIBRAPID_VECTOR_INPLACE_OP(>>) - - LIBRAPID_NODISCARD std::string str(const std::string &format) const override { - std::string ret = "("; - for (size_t i = 0; i < dims; ++i) { - ret += fmt::format(format, m_data[i]); - if (i != dims - 1) { ret += ", "; } - } - - return ret + ")"; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const StorageType &storage() const { - return 
m_data; - } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StorageType &storage() { return m_data; } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE GetType _get(size_t index) const override { - return m_data._get(index); - } - - LIBRAPID_ALWAYS_INLINE void _set(size_t index, const GetType &value) { - m_data._set(index, value); - } - - private: - StorageType m_data; - }; - - namespace vectorDetail { - template - struct BinaryVecOp : public VectorBase> { - using Scalar = typename typetraits::TypeInfo::Scalar; - using StorageLHS = typename VectorScalarStorageExtractor::type; - using StorageRHS = typename VectorScalarStorageExtractor::type; - using StorageType = VectorStorageMerger; - static constexpr size_t dims = StorageType::dims; - static constexpr size_t length = StorageType::length; - using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; - using IndexType = typename typetraits::TypeInfo::IndexType; - using GetType = typename typetraits::TypeInfo::GetType; - - LHS left; - RHS right; - Op op; - - BinaryVecOp() = default; - BinaryVecOp(const BinaryVecOp &) = default; - BinaryVecOp(BinaryVecOp &&) noexcept = default; - - BinaryVecOp(const LHS &lhs, const RHS &rhs, const Op &op) : - left(lhs), right(rhs), op(op) {} - - auto operator=(const BinaryVecOp &) -> BinaryVecOp & = default; - auto operator=(BinaryVecOp &&) noexcept -> BinaryVecOp & = default; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Vector eval() const { - Vector result(*this); - return result; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cast() const { - return eval().template cast(); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator Vector() const { - return cast(); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType - operator[](int64_t index) const override { - return op(scalarSubscriptHelper(left, index), scalarSubscriptHelper(right, index)); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) override { - 
return op(scalarSubscriptHelper(left, index), scalarSubscriptHelper(right, index)); - } - - LIBRAPID_NODISCARD std::string str(const std::string &format) const override { - return eval().str(format); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE GetType _get(size_t index) const override { - return op(scalarGetHelper(left, index), scalarGetHelper(right, index)); - } - }; - - template - struct UnaryVecOp : public VectorBase> { - using Scalar = typename typetraits::TypeInfo::Scalar; - using StorageType = typename VectorScalarStorageExtractor::type; - static constexpr size_t dims = StorageType::dims; - static constexpr size_t length = StorageType::length; - using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; - using IndexType = typename typetraits::TypeInfo::IndexType; - using GetType = typename typetraits::TypeInfo::GetType; - - Val val; - Op op; - - UnaryVecOp() = default; - UnaryVecOp(const UnaryVecOp &) = default; - UnaryVecOp(UnaryVecOp &&) noexcept = default; - - UnaryVecOp(const Val &value, const Op &op) : val(value), op(op) {} - - auto operator=(const UnaryVecOp &) -> UnaryVecOp & = default; - auto operator=(UnaryVecOp &&) noexcept -> UnaryVecOp & = default; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Vector eval() const { - Vector result(*this); - return result; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cast() const { - return eval().template cast(); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator Vector() const { - return cast(); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType - operator[](int64_t index) const override { - return op(scalarSubscriptHelper(val, index)); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) override { - return op(scalarSubscriptHelper(val, index)); - } - - LIBRAPID_NODISCARD std::string str(const std::string &format) const override { - return eval().str(format); - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE 
GetType _get(size_t index) const override { - return op(scalarGetHelper(val, index)); - } - }; - - template - LIBRAPID_ALWAYS_INLINE void assignImpl(Vector &dst, - const BinaryVecOp &src, - std::index_sequence) { - ((dst._set( - Indices, - src.op(scalarGetHelper(src.left, Indices), scalarGetHelper(src.right, Indices)))), - ...); - } - - template - LIBRAPID_ALWAYS_INLINE void assignImpl(Vector &dst, - const UnaryVecOp &src, - std::index_sequence) { - ((dst._set(Indices, src.op(scalarGetHelper(src.val, Indices)))), ...); - } - - template - LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, - const BinaryVecOp &src) { - using ScalarDst = typename typetraits::TypeInfo>::Scalar; - using ScalarSrc = typename typetraits::TypeInfo>::Scalar; - if constexpr (std::is_same_v) { - constexpr size_t lengthDst = Vector::length; - constexpr size_t lengthSrc = BinaryVecOp::length; - constexpr size_t minLength = (lengthDst < lengthSrc) ? lengthDst : lengthSrc; - assignImpl(dst, src, std::make_index_sequence()); - } else { - dst = src.template cast(); - } - } - - template - LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, const UnaryVecOp &src) { - using ScalarDst = typename typetraits::TypeInfo>::Scalar; - using ScalarSrc = typename typetraits::TypeInfo>::Scalar; - if constexpr (std::is_same_v) { - constexpr size_t lengthDst = Vector::length; - constexpr size_t lengthSrc = UnaryVecOp::length; - constexpr size_t minLength = (lengthDst < lengthSrc) ? 
lengthDst : lengthSrc; - assignImpl(dst, src, std::make_index_sequence()); - } else { - dst = src.template cast(); - } - } - - template - constexpr auto scalarExtractor(const T &val) { - return val; - } - - // template - // constexpr auto scalarExtractor(const Vc_1::Detail::ElementReference &val) { - // using Scalar = typename Vc_1::Detail::ElementReference::value_type; - // return static_cast(val); - // } - - template - constexpr auto scalarVectorCaster(const T &val) { - return static_cast(val); - } - - template - constexpr auto scalarVectorCaster(const Vector &val) { - return val.template cast(); - } - - template - constexpr auto scalarVectorCaster(const BinaryVecOp &val) { - return val.template cast(); - } - - template - constexpr auto scalarVectorCaster(const UnaryVecOp &val) { - return val.template cast(); - } + template \ + LIBRAPID_ALWAYS_INLINE Vector &operator OP_##=(const Other &other) { \ + return *this = *this OP_ other; \ + } + + LIBRAPID_VECTOR_INPLACE_OP(+) + LIBRAPID_VECTOR_INPLACE_OP(-) + LIBRAPID_VECTOR_INPLACE_OP(*) + LIBRAPID_VECTOR_INPLACE_OP(/) + LIBRAPID_VECTOR_INPLACE_OP(%) + LIBRAPID_VECTOR_INPLACE_OP(&) + LIBRAPID_VECTOR_INPLACE_OP(|) + LIBRAPID_VECTOR_INPLACE_OP(^) + LIBRAPID_VECTOR_INPLACE_OP(<<) + LIBRAPID_VECTOR_INPLACE_OP(>>) + + LIBRAPID_NODISCARD std::string str(const std::string &format) const override { + std::string ret = "("; + for (size_t i = 0; i < dims; ++i) { + ret += fmt::format(format, m_data[i]); + if (i != dims - 1) { ret += ", "; } + } + + return ret + ")"; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const StorageType &storage() const { + return m_data; + } + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE StorageType &storage() { return m_data; } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE GetType _get(size_t index) const override { + return m_data._get(index); + } + + LIBRAPID_ALWAYS_INLINE void _set(size_t index, const GetType &value) { + m_data._set(index, value); + } + + private: + StorageType m_data; + }; + + 
namespace vectorDetail { + template + struct BinaryVecOp : public VectorBase> { + using Scalar = typename typetraits::TypeInfo::Scalar; + using StorageLHS = typename VectorScalarStorageExtractor::type; + using StorageRHS = typename VectorScalarStorageExtractor::type; + using StorageType = VectorStorageMerger; + static constexpr size_t dims = StorageType::dims; + static constexpr size_t length = StorageType::length; + using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; + using IndexType = typename typetraits::TypeInfo::IndexType; + using GetType = typename typetraits::TypeInfo::GetType; + + LHS left; + RHS right; + Op op; + + BinaryVecOp() = default; + BinaryVecOp(const BinaryVecOp &) = default; + BinaryVecOp(BinaryVecOp &&) noexcept = default; + + BinaryVecOp(const LHS &lhs, const RHS &rhs, const Op &op) : + left(lhs), right(rhs), op(op) {} + + auto operator=(const BinaryVecOp &) -> BinaryVecOp & = default; + auto operator=(BinaryVecOp &&) noexcept -> BinaryVecOp & = default; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Vector eval() const { + Vector result(*this); + return result; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cast() const { + return eval().template cast(); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator Vector() const { + return cast(); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType + operator[](int64_t index) const override { + return op(scalarSubscriptHelper(left, index), scalarSubscriptHelper(right, index)); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) override { + return op(scalarSubscriptHelper(left, index), scalarSubscriptHelper(right, index)); + } + + LIBRAPID_NODISCARD std::string str(const std::string &format) const override { + return eval().str(format); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE GetType _get(size_t index) const override { + return op(scalarGetHelper(left, index), scalarGetHelper(right, index)); + } + 
}; + + template + struct UnaryVecOp : public VectorBase> { + using Scalar = typename typetraits::TypeInfo::Scalar; + using StorageType = typename VectorScalarStorageExtractor::type; + static constexpr size_t dims = StorageType::dims; + static constexpr size_t length = StorageType::length; + using IndexTypeConst = typename typetraits::TypeInfo::IndexTypeConst; + using IndexType = typename typetraits::TypeInfo::IndexType; + using GetType = typename typetraits::TypeInfo::GetType; + + Val val; + Op op; + + UnaryVecOp() = default; + UnaryVecOp(const UnaryVecOp &) = default; + UnaryVecOp(UnaryVecOp &&) noexcept = default; + + UnaryVecOp(const Val &value, const Op &op) : val(value), op(op) {} + + auto operator=(const UnaryVecOp &) -> UnaryVecOp & = default; + auto operator=(UnaryVecOp &&) noexcept -> UnaryVecOp & = default; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Vector eval() const { + Vector result(*this); + return result; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cast() const { + return eval().template cast(); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator Vector() const { + return cast(); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType + operator[](int64_t index) const override { + return op(scalarSubscriptHelper(val, index)); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE IndexType operator[](int64_t index) override { + return op(scalarSubscriptHelper(val, index)); + } + + LIBRAPID_NODISCARD std::string str(const std::string &format) const override { + return eval().str(format); + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE GetType _get(size_t index) const override { + return op(scalarGetHelper(val, index)); + } + }; + + template + LIBRAPID_ALWAYS_INLINE void assignImpl(Vector &dst, + const BinaryVecOp &src, + std::index_sequence) { + ((dst._set( + Indices, + src.op(scalarGetHelper(src.left, Indices), scalarGetHelper(src.right, Indices)))), + ...); + } + + template + LIBRAPID_ALWAYS_INLINE void 
assignImpl(Vector &dst, + const UnaryVecOp &src, + std::index_sequence) { + ((dst._set(Indices, src.op(scalarGetHelper(src.val, Indices)))), ...); + } + + template + LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, + const BinaryVecOp &src) { + using ScalarDst = typename typetraits::TypeInfo>::Scalar; + using ScalarSrc = typename typetraits::TypeInfo>::Scalar; + if constexpr (std::is_same_v) { + constexpr size_t lengthDst = Vector::length; + constexpr size_t lengthSrc = BinaryVecOp::length; + constexpr size_t minLength = (lengthDst < lengthSrc) ? lengthDst : lengthSrc; + assignImpl(dst, src, std::make_index_sequence()); + } else { + dst = src.template cast(); + } + } + + template + LIBRAPID_ALWAYS_INLINE void assign(Vector &dst, const UnaryVecOp &src) { + using ScalarDst = typename typetraits::TypeInfo>::Scalar; + using ScalarSrc = typename typetraits::TypeInfo>::Scalar; + if constexpr (std::is_same_v) { + constexpr size_t lengthDst = Vector::length; + constexpr size_t lengthSrc = UnaryVecOp::length; + constexpr size_t minLength = (lengthDst < lengthSrc) ? 
lengthDst : lengthSrc; + assignImpl(dst, src, std::make_index_sequence()); + } else { + dst = src.template cast(); + } + } + + template + constexpr auto scalarExtractor(const T &val) { + return val; + } + + // template + // constexpr auto scalarExtractor(const Vc_1::Detail::ElementReference &val) { + // using Scalar = typename Vc_1::Detail::ElementReference::value_type; + // return static_cast(val); + // } + + template + constexpr auto scalarVectorCaster(const T &val) { + return static_cast(val); + } + + template + constexpr auto scalarVectorCaster(const Vector &val) { + return val.template cast(); + } + + template + constexpr auto scalarVectorCaster(const BinaryVecOp &val) { + return val.template cast(); + } + + template + constexpr auto scalarVectorCaster(const UnaryVecOp &val) { + return val.template cast(); + } #define VECTOR_BINARY_OP(NAME_, OP_) \ - struct NAME_ { \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const A &a, const B &b) const { \ - using namespace ::librapid::vectorDetail; \ - return scalarExtractor(a) OP_ scalarExtractor(b); \ - } \ - }; \ + struct NAME_ { \ + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const A &a, const B &b) const { \ + using namespace ::librapid::vectorDetail; \ + return scalarExtractor(a) OP_ scalarExtractor(b); \ + } \ + }; \ \ - template::value || \ - typetraits::IsVector::value, \ - int> = 0> \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator OP_(const LHS &lhs, const RHS &rhs) { \ - using namespace ::librapid::vectorDetail; \ - using ScalarLeft = typename typetraits::TypeInfo::Scalar; \ - using ScalarRight = typename typetraits::TypeInfo::Scalar; \ - using Op = NAME_; \ + template::value || \ + typetraits::IsVector::value, \ + int> = 0> \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator OP_(const LHS &lhs, const RHS &rhs) { \ + using namespace ::librapid::vectorDetail; \ + using ScalarLeft = typename typetraits::TypeInfo::Scalar; \ + using 
ScalarRight = typename typetraits::TypeInfo::Scalar; \ + using Op = NAME_; \ \ - if constexpr (std::is_same_v) { \ - return BinaryVecOp {lhs, rhs, Op {}}; \ - } else { \ - using Scalar = decltype(std::declval() + std::declval()); \ - constexpr size_t dimsLhs = VectorScalarDimensionExtractor::value; \ - constexpr size_t dimsRhs = VectorScalarDimensionExtractor::value; \ - constexpr size_t maxDims = (dimsLhs > dimsRhs) ? dimsLhs : dimsRhs; \ - return BinaryVecOp {scalarVectorCaster(lhs), \ - scalarVectorCaster(rhs), \ - Op {}}; \ - } \ - } + if constexpr (std::is_same_v) { \ + return BinaryVecOp {lhs, rhs, Op {}}; \ + } else { \ + using Scalar = decltype(std::declval() + std::declval()); \ + constexpr size_t dimsLhs = VectorScalarDimensionExtractor::value; \ + constexpr size_t dimsRhs = VectorScalarDimensionExtractor::value; \ + constexpr size_t maxDims = (dimsLhs > dimsRhs) ? dimsLhs : dimsRhs; \ + return BinaryVecOp {scalarVectorCaster(lhs), \ + scalarVectorCaster(rhs), \ + Op {}}; \ + } \ + } #define VECTOR_UNARY_OP(NAME_, OP_NAME_, OP_) \ - struct NAME_ { \ - template::value, int> = 0> \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const Val &val) const { \ - using namespace ::librapid::vectorDetail; \ - return OP_(scalarExtractor(val)); \ - } \ - }; \ + struct NAME_ { \ + template::value, int> = 0> \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const Val &val) const { \ + using namespace ::librapid::vectorDetail; \ + return OP_(scalarExtractor(val)); \ + } \ + }; \ \ - template::value, int> = 0> \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto OP_NAME_(const Val &val) { \ - using namespace ::librapid::vectorDetail; \ - using Op = NAME_; \ - return ::librapid::vectorDetail::UnaryVecOp {val, Op {}}; \ - } - - VECTOR_BINARY_OP(Add, +); - VECTOR_BINARY_OP(Sub, -); - VECTOR_BINARY_OP(Mul, *); - VECTOR_BINARY_OP(Div, /); - VECTOR_BINARY_OP(Mod, %); - VECTOR_BINARY_OP(BitAnd, &); - VECTOR_BINARY_OP(BitOr, |); - VECTOR_BINARY_OP(BitXor, 
^); - VECTOR_BINARY_OP(LeftShift, <<); - VECTOR_BINARY_OP(RightShift, >>); - VECTOR_BINARY_OP(And, &&); - VECTOR_BINARY_OP(Or, ||); - VECTOR_BINARY_OP(LessThan, <); - VECTOR_BINARY_OP(GreaterThan, >); - VECTOR_BINARY_OP(LessThanEqual, <=); - VECTOR_BINARY_OP(GreaterThanEqual, >=); - VECTOR_BINARY_OP(Equal, ==); - VECTOR_BINARY_OP(NotEqual, !=); - - VECTOR_UNARY_OP(Not, operator!, !); - VECTOR_UNARY_OP(BitNot, operator~, ~); - VECTOR_UNARY_OP(Negate, operator-, -); - VECTOR_UNARY_OP(Plus, operator+, +); + template::value, int> = 0> \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto OP_NAME_(const Val &val) { \ + using namespace ::librapid::vectorDetail; \ + using Op = NAME_; \ + return ::librapid::vectorDetail::UnaryVecOp {val, Op {}}; \ + } + + VECTOR_BINARY_OP(Add, +); + VECTOR_BINARY_OP(Sub, -); + VECTOR_BINARY_OP(Mul, *); + VECTOR_BINARY_OP(Div, /); + VECTOR_BINARY_OP(Mod, %); + VECTOR_BINARY_OP(BitAnd, &); + VECTOR_BINARY_OP(BitOr, |); + VECTOR_BINARY_OP(BitXor, ^); + VECTOR_BINARY_OP(LeftShift, <<); + VECTOR_BINARY_OP(RightShift, >>); + VECTOR_BINARY_OP(And, &&); + VECTOR_BINARY_OP(Or, ||); + VECTOR_BINARY_OP(LessThan, <); + VECTOR_BINARY_OP(GreaterThan, >); + VECTOR_BINARY_OP(LessThanEqual, <=); + VECTOR_BINARY_OP(GreaterThanEqual, >=); + VECTOR_BINARY_OP(Equal, ==); + VECTOR_BINARY_OP(NotEqual, !=); + + VECTOR_UNARY_OP(Not, operator!, !); + VECTOR_UNARY_OP(BitNot, operator~, ~); + VECTOR_UNARY_OP(Negate, operator-, -); + VECTOR_UNARY_OP(Plus, operator+, +); #define VECTOR_FUNC_STRUCT_DEF(NAME_) \ - struct Vector_##NAME_ { \ - template \ - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const Val &val) const { \ - return ::librapid::NAME_(scalarExtractor(val)); \ - } \ - } - - VECTOR_FUNC_STRUCT_DEF(sin); - VECTOR_FUNC_STRUCT_DEF(cos); - VECTOR_FUNC_STRUCT_DEF(tan); - VECTOR_FUNC_STRUCT_DEF(asin); - VECTOR_FUNC_STRUCT_DEF(acos); - VECTOR_FUNC_STRUCT_DEF(atan); - VECTOR_FUNC_STRUCT_DEF(sinh); - VECTOR_FUNC_STRUCT_DEF(cosh); - 
VECTOR_FUNC_STRUCT_DEF(tanh); - VECTOR_FUNC_STRUCT_DEF(asinh); - VECTOR_FUNC_STRUCT_DEF(acosh); - VECTOR_FUNC_STRUCT_DEF(atanh); - VECTOR_FUNC_STRUCT_DEF(exp); - VECTOR_FUNC_STRUCT_DEF(exp2); - VECTOR_FUNC_STRUCT_DEF(exp10); - VECTOR_FUNC_STRUCT_DEF(log); - VECTOR_FUNC_STRUCT_DEF(log2); - VECTOR_FUNC_STRUCT_DEF(log10); - VECTOR_FUNC_STRUCT_DEF(sqrt); - VECTOR_FUNC_STRUCT_DEF(cbrt); - } // namespace vectorDetail + struct Vector_##NAME_ { \ + template \ + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const Val &val) const { \ + return ::librapid::NAME_(scalarExtractor(val)); \ + } \ + } + + VECTOR_FUNC_STRUCT_DEF(sin); + VECTOR_FUNC_STRUCT_DEF(cos); + VECTOR_FUNC_STRUCT_DEF(tan); + VECTOR_FUNC_STRUCT_DEF(asin); + VECTOR_FUNC_STRUCT_DEF(acos); + VECTOR_FUNC_STRUCT_DEF(atan); + VECTOR_FUNC_STRUCT_DEF(sinh); + VECTOR_FUNC_STRUCT_DEF(cosh); + VECTOR_FUNC_STRUCT_DEF(tanh); + VECTOR_FUNC_STRUCT_DEF(asinh); + VECTOR_FUNC_STRUCT_DEF(acosh); + VECTOR_FUNC_STRUCT_DEF(atanh); + VECTOR_FUNC_STRUCT_DEF(exp); + VECTOR_FUNC_STRUCT_DEF(exp2); + VECTOR_FUNC_STRUCT_DEF(exp10); + VECTOR_FUNC_STRUCT_DEF(log); + VECTOR_FUNC_STRUCT_DEF(log2); + VECTOR_FUNC_STRUCT_DEF(log10); + VECTOR_FUNC_STRUCT_DEF(sqrt); + VECTOR_FUNC_STRUCT_DEF(cbrt); + } // namespace vectorDetail #define VECTOR_FUNC_IMPL_DEF(NAME_) \ - template \ - auto NAME_(const Vector &vec) { \ - return vectorDetail::UnaryVecOp {vec, vectorDetail::Vector_##NAME_ {}}; \ - } \ + template \ + auto NAME_(const Vector &vec) { \ + return vectorDetail::UnaryVecOp {vec, vectorDetail::Vector_##NAME_ {}}; \ + } \ \ - template \ - auto NAME_(const vectorDetail::BinaryVecOp &vec) { \ - return vectorDetail::UnaryVecOp {vec, vectorDetail::Vector_##NAME_ {}}; \ - } \ + template \ + auto NAME_(const vectorDetail::BinaryVecOp &vec) { \ + return vectorDetail::UnaryVecOp {vec, vectorDetail::Vector_##NAME_ {}}; \ + } \ \ - template \ - auto NAME_(const vectorDetail::UnaryVecOp &vec) { \ - return vectorDetail::UnaryVecOp {vec, 
vectorDetail::Vector_##NAME_ {}}; \ - } - - VECTOR_FUNC_IMPL_DEF(sin) - VECTOR_FUNC_IMPL_DEF(cos) - VECTOR_FUNC_IMPL_DEF(tan) - VECTOR_FUNC_IMPL_DEF(asin) - VECTOR_FUNC_IMPL_DEF(acos) - VECTOR_FUNC_IMPL_DEF(atan) - VECTOR_FUNC_IMPL_DEF(sinh) - VECTOR_FUNC_IMPL_DEF(cosh) - VECTOR_FUNC_IMPL_DEF(tanh) - VECTOR_FUNC_IMPL_DEF(asinh) - VECTOR_FUNC_IMPL_DEF(acosh) - VECTOR_FUNC_IMPL_DEF(atanh) - VECTOR_FUNC_IMPL_DEF(exp) - VECTOR_FUNC_IMPL_DEF(exp2) - VECTOR_FUNC_IMPL_DEF(exp10) - VECTOR_FUNC_IMPL_DEF(log) - VECTOR_FUNC_IMPL_DEF(log2) - VECTOR_FUNC_IMPL_DEF(log10) - VECTOR_FUNC_IMPL_DEF(sqrt) - VECTOR_FUNC_IMPL_DEF(cbrt) - - template::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto mag2(const T &val) { - return val.eval().storage().sum2(); - } - - template::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto mag(const T &val) { - return ::librapid::sqrt(mag2(val)); - } - - template::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum(const T &val) { - return val.eval().storage().sum(); - } - - template::value && typetraits::IsVector::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dot(const First &first, const Second &second) { - return (first * second).eval().storage().sum(); - } - - template::value && typetraits::IsVector::value, int> = 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cross(const First &first, const Second &second) { - LIBRAPID_ASSERT(typetraits::TypeInfo::dims == 3 && - typetraits::TypeInfo::dims == 3, - "Cross product is only defined for 3D vectors"); - using ScalarFirst = typename typetraits::TypeInfo::Scalar; - using ScalarSecond = typename typetraits::TypeInfo::Scalar; - using Scalar = decltype(std::declval() * std::declval()); - - Scalar x1 = first[0]; - Scalar y1 = first[1]; - Scalar z1 = first[2]; - Scalar x2 = second[0]; - Scalar y2 = second[1]; - Scalar z2 = second[2]; - - return Vector {y1 * z2 - z1 * y2, z1 * x2 - x1 * z2, x1 * y2 - y1 * x2}; - } - - template::value, int> 
= 0> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto norm(const T &val) { - return val / mag(val); - } + template \ + auto NAME_(const vectorDetail::UnaryVecOp &vec) { \ + return vectorDetail::UnaryVecOp {vec, vectorDetail::Vector_##NAME_ {}}; \ + } + + VECTOR_FUNC_IMPL_DEF(sin) + VECTOR_FUNC_IMPL_DEF(cos) + VECTOR_FUNC_IMPL_DEF(tan) + VECTOR_FUNC_IMPL_DEF(asin) + VECTOR_FUNC_IMPL_DEF(acos) + VECTOR_FUNC_IMPL_DEF(atan) + VECTOR_FUNC_IMPL_DEF(sinh) + VECTOR_FUNC_IMPL_DEF(cosh) + VECTOR_FUNC_IMPL_DEF(tanh) + VECTOR_FUNC_IMPL_DEF(asinh) + VECTOR_FUNC_IMPL_DEF(acosh) + VECTOR_FUNC_IMPL_DEF(atanh) + VECTOR_FUNC_IMPL_DEF(exp) + VECTOR_FUNC_IMPL_DEF(exp2) + VECTOR_FUNC_IMPL_DEF(exp10) + VECTOR_FUNC_IMPL_DEF(log) + VECTOR_FUNC_IMPL_DEF(log2) + VECTOR_FUNC_IMPL_DEF(log10) + VECTOR_FUNC_IMPL_DEF(sqrt) + VECTOR_FUNC_IMPL_DEF(cbrt) + + template::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto mag2(const T &val) { + return val.eval().storage().sum2(); + } + + template::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto mag(const T &val) { + return ::librapid::sqrt(mag2(val)); + } + + template::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sum(const T &val) { + return val.eval().storage().sum(); + } + + template::value && typetraits::IsVector::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto dot(const First &first, const Second &second) { + return (first * second).eval().storage().sum(); + } + + template::value && typetraits::IsVector::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cross(const First &first, const Second &second) { + LIBRAPID_ASSERT(typetraits::TypeInfo::dims == 3 && + typetraits::TypeInfo::dims == 3, + "Cross product is only defined for 3D vectors"); + using ScalarFirst = typename typetraits::TypeInfo::Scalar; + using ScalarSecond = typename typetraits::TypeInfo::Scalar; + using Scalar = decltype(std::declval() * std::declval()); + + Scalar x1 = first[0]; + Scalar y1 = 
first[1]; + Scalar z1 = first[2]; + Scalar x2 = second[0]; + Scalar y2 = second[1]; + Scalar z2 = second[2]; + + return Vector {y1 * z2 - z1 * y2, z1 * x2 - x1 * z2, x1 * y2 - y1 * x2}; + } + + template::value, int> = 0> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto norm(const T &val) { + return val / mag(val); + } } // namespace librapid LIBRAPID_SIMPLE_IO_IMPL(typename Derived, librapid::vectorDetail::VectorBase); LIBRAPID_SIMPLE_IO_IMPL(typename T COMMA size_t N, librapid::Vector); LIBRAPID_SIMPLE_IO_IMPL(typename LHS COMMA typename RHS COMMA typename Op, - librapid::vectorDetail::BinaryVecOp); + librapid::vectorDetail::BinaryVecOp); LIBRAPID_SIMPLE_IO_IMPL(typename Val COMMA typename Op, - librapid::vectorDetail::UnaryVecOp); + librapid::vectorDetail::UnaryVecOp); #endif // LIBRAPID_MATH_VECTOR_HPP diff --git a/librapid/include/librapid/ml/activations.hpp b/librapid/include/librapid/ml/activations.hpp index ed8f19a1..d4793321 100644 --- a/librapid/include/librapid/ml/activations.hpp +++ b/librapid/include/librapid/ml/activations.hpp @@ -2,200 +2,198 @@ #define LIBRAPID_ML_ACTIVATIONS namespace librapid::ml { - // 1. [X] Sigmoid . . . . . . f(x) = 1 / (1 + e^-x) - // f'(x) = x(1 - x) - // 2. [ ] Tanh . . . . . . . f(x) = tanh(x) - // f'(x) = 1 - x^2 - // 3. [ ] ReLU . . . . . . . f(x) = max(0, x) - // f'(x) = 1 if x > 0 else 0 - // 4. [ ] LeakyReLU . . . . . f(x) = max(0.01x, x) - // f'(x) = 1 if x > 0 else 0.01 - // 5. [ ] Softmax . . . . . . https://github.com/tiny-dnn/ - // tiny-dnn/blob/master/tiny_dnn/activations/softmax_layer.h - // 6. [ ] Softplus . . . . . f(x) = ln(1 + e^x) - // f'(x) = 1 / (1 + e^-x) - // 7. [ ] ELU . . . . . . . . f(x) = x if x > 0 else a(e^x - 1) - // f'(x) = 1 if x > 0 else a(e^x) - // 8. [ ] SELU . . . . . . . f(x) = lambda * a * (e^x - 1) if x <= 0 else lambda * x - // f'(x) = lambda * a * e^x if x <= 0 else lambda - // α ≈ 1.67326 and λ ≈ 1.0507 - // 9. [ ] Swish . . . . . . . 
f(x) = x / (1 + e^-x) - // f'(x) = x(1 + e^-x + xe^-x) / (1 + e^-x)^2 - // 10. [ ] Mish . . . . . . . f(x) = x * tanh(ln(1 + e^x)) - // f'(x) = (e^x * (4 * x + 4 + 4 * e^x + e^(2 * x))) / (2 * e^x + - // e^(2 - //* x) - //+ 2)^2 - // 11. [ ] HardSigmoid . . . . f(x) = max(0, min(1, x * 0.2 + 0.5)) - // f'(x) = 0.2 if 0 < x < 1 else 0 - // 12. [ ] LogSigmoid . . . . f(x) = ln(1 / (1 + e^-x)) - // f'(x) = 1 / (1 + e^x) - // 13. [ ] Softsign . . . . . f(x) = x / (1 + |x|) - // f'(x) = 1 / (1 + |x|)^2 - // 14. [ ] Exponential . . . . f(x) = e^x - // f'(x) = e^x - // 15. [ ] GELU . . . . . . . f(x) = x * (1 + erf(x / sqrt(2))) / 2 - // f'(x) = (erf(x / sqrt(2)) + x * e^(-x^2 / 2) / sqrt(2 * pi)) / 2 - // 18. [ ] Softmin . . . . . . f(x) = e^x / sum(e^x) - // f'(x) = f(x)(1 - f(x)) - - /// \brief Sigmoid activation function - /// - /// A class that implements the Sigmoid activation function. - /// - /// \f$\sigma(x) = \frac{1}{1 + e^{-x}}\f$ - /// - /// \f$\sigma'(x) = x(1 - x)\f$ - /// - class Sigmoid { - public: - Sigmoid() = default; - - /// Applies the Sigmoid activation function to the input array and returns the result. - /// - /// @tparam ShapeType The type of the shape of the input array. - /// @tparam StorageType The type of the storage of the input array. - /// @param src The input array to apply the activation function to. - /// @return A new array with the result of applying the Sigmoid activation function to the - /// input array. 
- template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto forward(const T &src) const { - auto ret = emptyLike(src); - forward(ret, src); - return ret; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto backward(const T &src) const { - auto ret = emptyLike(src); - backward(ret, src); - return ret; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &src) const { - return forward(src); - } - - /// Applies the Sigmoid activation function to the input array and stores the result in the - /// output array. - /// - /// @tparam ShapeType The type of the shape of the input and output arrays. - /// @tparam StorageScalar The type of the scalar values stored in the input and output - /// arrays. - /// @param dst The output array to store the result of applying the Sigmoid activation - /// function to the input array. - /// @param src The input array to apply the activation function to. - template - LIBRAPID_ALWAYS_INLINE void - forward(array::ArrayContainer> &dst, - const array::ArrayContainer> - &src) const { - dst = StorageScalar(1) / (StorageScalar(1) + exp(-src)); - } - - template - LIBRAPID_ALWAYS_INLINE void - backward(array::ArrayContainer> &dst, - const array::ArrayContainer> - &src) const { - dst = src * (StorageScalar(1) - src); - } - - template - LIBRAPID_ALWAYS_INLINE void forward( - array::ArrayContainer> &dst, - const array::ArrayContainer> &src) const { - dst = StorageScalar(1) / (StorageScalar(1) + exp(-src)); - } - - template - LIBRAPID_ALWAYS_INLINE void backward( - array::ArrayContainer> &dst, - const array::ArrayContainer> &src) const { - dst = src * (StorageScalar(1) - src); - } - - template - LIBRAPID_ALWAYS_INLINE void - forward(array::ArrayContainer> &dst, - const detail::Function &src) const { - dst = StorageScalar(1) / (StorageScalar(1) + exp(-src)); - } - - template - LIBRAPID_ALWAYS_INLINE void - backward(array::ArrayContainer> &dst, - const detail::Function &src) const { - dst = src * 
(StorageScalar(1) - src); - } + // 1. [X] Sigmoid . . . . . . f(x) = 1 / (1 + e^-x) + // f'(x) = x(1 - x) + // 2. [ ] Tanh . . . . . . . f(x) = tanh(x) + // f'(x) = 1 - x^2 + // 3. [ ] ReLU . . . . . . . f(x) = max(0, x) + // f'(x) = 1 if x > 0 else 0 + // 4. [ ] LeakyReLU . . . . . f(x) = max(0.01x, x) + // f'(x) = 1 if x > 0 else 0.01 + // 5. [ ] Softmax . . . . . . https://github.com/tiny-dnn/ + // tiny-dnn/blob/master/tiny_dnn/activations/softmax_layer.h + // 6. [ ] Softplus . . . . . f(x) = ln(1 + e^x) + // f'(x) = 1 / (1 + e^-x) + // 7. [ ] ELU . . . . . . . . f(x) = x if x > 0 else a(e^x - 1) + // f'(x) = 1 if x > 0 else a(e^x) + // 8. [ ] SELU . . . . . . . f(x) = lambda * a * (e^x - 1) if x <= 0 else lambda * x + // f'(x) = lambda * a * e^x if x <= 0 else lambda + // α ≈ 1.67326 and λ ≈ 1.0507 + // 9. [ ] Swish . . . . . . . f(x) = x / (1 + e^-x) + // f'(x) = x(1 + e^-x + xe^-x) / (1 + e^-x)^2 + // 10. [ ] Mish . . . . . . . f(x) = x * tanh(ln(1 + e^x)) + // f'(x) = (e^x * (4 * x + 4 + 4 * e^x + e^(2 * x))) / (2 * e^x + + // e^(2 + //* x) + //+ 2)^2 + // 11. [ ] HardSigmoid . . . . f(x) = max(0, min(1, x * 0.2 + 0.5)) + // f'(x) = 0.2 if 0 < x < 1 else 0 + // 12. [ ] LogSigmoid . . . . f(x) = ln(1 / (1 + e^-x)) + // f'(x) = 1 / (1 + e^x) + // 13. [ ] Softsign . . . . . f(x) = x / (1 + |x|) + // f'(x) = 1 / (1 + |x|)^2 + // 14. [ ] Exponential . . . . f(x) = e^x + // f'(x) = e^x + // 15. [ ] GELU . . . . . . . f(x) = x * (1 + erf(x / sqrt(2))) / 2 + // f'(x) = (erf(x / sqrt(2)) + x * e^(-x^2 / 2) / sqrt(2 * pi)) / 2 + // 18. [ ] Softmin . . . . . . f(x) = e^x / sum(e^x) + // f'(x) = f(x)(1 - f(x)) + + /// \brief Sigmoid activation function + /// + /// A class that implements the Sigmoid activation function. + /// + /// \f$\sigma(x) = \frac{1}{1 + e^{-x}}\f$ + /// + /// \f$\sigma'(x) = x(1 - x)\f$ + /// + class Sigmoid { + public: + Sigmoid() = default; + + /// Applies the Sigmoid activation function to the input array and returns the result. 
+ /// + /// @tparam ShapeType The type of the shape of the input array. + /// @tparam StorageType The type of the storage of the input array. + /// @param src The input array to apply the activation function to. + /// @return A new array with the result of applying the Sigmoid activation function to the + /// input array. + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto forward(const T &src) const { + auto ret = emptyLike(src); + forward(ret, src); + return ret; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto backward(const T &src) const { + auto ret = emptyLike(src); + backward(ret, src); + return ret; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator()(const T &src) const { + return forward(src); + } + + /// Applies the Sigmoid activation function to the input array and stores the result in the + /// output array. + /// + /// @tparam ShapeType The type of the shape of the input and output arrays. + /// @tparam StorageScalar The type of the scalar values stored in the input and output + /// arrays. + /// @param dst The output array to store the result of applying the Sigmoid activation + /// function to the input array. + /// @param src The input array to apply the activation function to. 
+ template + LIBRAPID_ALWAYS_INLINE void + forward(array::ArrayContainer> &dst, + const array::ArrayContainer> &src) const { + dst = StorageScalar(1) / (StorageScalar(1) + exp(-src)); + } + + template + LIBRAPID_ALWAYS_INLINE void + backward(array::ArrayContainer> &dst, + const array::ArrayContainer> &src) const { + dst = src * (StorageScalar(1) - src); + } + + template + LIBRAPID_ALWAYS_INLINE void forward( + array::ArrayContainer> &dst, + const array::ArrayContainer> &src) const { + dst = StorageScalar(1) / (StorageScalar(1) + exp(-src)); + } + + template + LIBRAPID_ALWAYS_INLINE void backward( + array::ArrayContainer> &dst, + const array::ArrayContainer> &src) const { + dst = src * (StorageScalar(1) - src); + } + + template + LIBRAPID_ALWAYS_INLINE void + forward(array::ArrayContainer> &dst, + const detail::Function &src) const { + dst = StorageScalar(1) / (StorageScalar(1) + exp(-src)); + } + + template + LIBRAPID_ALWAYS_INLINE void + backward(array::ArrayContainer> &dst, + const detail::Function &src) const { + dst = src * (StorageScalar(1) - src); + } #if defined(LIBRAPID_HAS_OPENCL) - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::OpenCL>, int> = 0> - LIBRAPID_ALWAYS_INLINE void - forward(array::ArrayContainer> &dst, - const Src &src) const { - auto temp = evaluated(src); - opencl::runLinearKernel("sigmoidActivationForward", - src.shape().size(), - dst.storage().data(), - src.storage().data()); - } - - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::OpenCL>, int> = 0> - LIBRAPID_ALWAYS_INLINE void - backward(array::ArrayContainer> &dst, - const Src &src) const { - auto temp = evaluated(src); - opencl::runLinearKernel("sigmoidActivationBackward", - temp.shape().size(), - dst.storage().data(), - temp.storage().data()); - } + template< + typename ShapeType, typename StorageScalar, typename 
Src, + typename std::enable_if_t< + std::is_same_v::Backend, backend::OpenCL>, int> = 0> + LIBRAPID_ALWAYS_INLINE void + forward(array::ArrayContainer> &dst, + const Src &src) const { + auto temp = evaluated(src); + opencl::runLinearKernel("sigmoidActivationForward", + src.shape().size(), + dst.storage().data(), + src.storage().data()); + } + + template< + typename ShapeType, typename StorageScalar, typename Src, + typename std::enable_if_t< + std::is_same_v::Backend, backend::OpenCL>, int> = 0> + LIBRAPID_ALWAYS_INLINE void + backward(array::ArrayContainer> &dst, + const Src &src) const { + auto temp = evaluated(src); + opencl::runLinearKernel("sigmoidActivationBackward", + temp.shape().size(), + dst.storage().data(), + temp.storage().data()); + } #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::CUDA>, int> = 0> - LIBRAPID_ALWAYS_INLINE void - forward(array::ArrayContainer> &dst, - const Src &src) const { - auto temp = evaluated(src); - cuda::runKernel("activations", - "sigmoidActivationForward", - dst.shape().size(), - temp.shape().size(), - dst.storage().begin().get(), - temp.storage().begin().get()); - } - - template< - typename ShapeType, typename StorageScalar, typename Src, - typename std::enable_if_t< - std::is_same_v::Backend, backend::CUDA>, int> = 0> - LIBRAPID_ALWAYS_INLINE void - backward(array::ArrayContainer> &dst, - const Src &src) const { - auto temp = evaluated(src); - cuda::runKernel("activations", - "sigmoidActivationBackward", - dst.shape().size(), - temp.shape().size(), - dst.storage().begin().get(), - temp.storage().begin().get()); - } + template< + typename ShapeType, typename StorageScalar, typename Src, + typename std::enable_if_t< + std::is_same_v::Backend, backend::CUDA>, int> = 0> + LIBRAPID_ALWAYS_INLINE void + forward(array::ArrayContainer> &dst, + const Src &src) const { + auto temp = 
evaluated(src); + cuda::runKernel("activations", + "sigmoidActivationForward", + dst.shape().size(), + temp.shape().size(), + dst.storage().begin().get(), + temp.storage().begin().get()); + } + + template< + typename ShapeType, typename StorageScalar, typename Src, + typename std::enable_if_t< + std::is_same_v::Backend, backend::CUDA>, int> = 0> + LIBRAPID_ALWAYS_INLINE void + backward(array::ArrayContainer> &dst, + const Src &src) const { + auto temp = evaluated(src); + cuda::runKernel("activations", + "sigmoidActivationBackward", + dst.shape().size(), + temp.shape().size(), + dst.storage().begin().get(), + temp.storage().begin().get()); + } #endif // LIBRAPID_HAS_CUDA - }; + }; } // namespace librapid::ml #endif // LIBRAPID_ML_ACTIVATIONS \ No newline at end of file diff --git a/librapid/include/librapid/opencl/kernels/abs.cl b/librapid/include/librapid/opencl/kernels/abs.cl index 0e9b9f71..de5e236d 100644 --- a/librapid/include/librapid/opencl/kernels/abs.cl +++ b/librapid/include/librapid/opencl/kernels/abs.cl @@ -1,20 +1,20 @@ #define ABS_KERNEL(DTYPE) \ - __kernel void absArrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = (data[gid] >= 0) ? data[gid] : -data[gid]; \ - } + __kernel void absArrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = (data[gid] >= 0) ? 
data[gid] : -data[gid]; \ + } #define ABS_IMPL \ - ABS_KERNEL(int8_t) \ - ABS_KERNEL(uint8_t) \ - ABS_KERNEL(int16_t) \ - ABS_KERNEL(uint16_t) \ - ABS_KERNEL(int32_t) \ - ABS_KERNEL(uint32_t) \ - ABS_KERNEL(int64_t) \ - ABS_KERNEL(uint64_t) \ - ABS_KERNEL(float) \ - ABS_KERNEL(double) + ABS_KERNEL(int8_t) \ + ABS_KERNEL(uint8_t) \ + ABS_KERNEL(int16_t) \ + ABS_KERNEL(uint16_t) \ + ABS_KERNEL(int32_t) \ + ABS_KERNEL(uint32_t) \ + ABS_KERNEL(int64_t) \ + ABS_KERNEL(uint64_t) \ + ABS_KERNEL(float) \ + ABS_KERNEL(double) ABS_IMPL diff --git a/librapid/include/librapid/opencl/kernels/activations.cl b/librapid/include/librapid/opencl/kernels/activations.cl index 9f2116dc..578c7d5a 100644 --- a/librapid/include/librapid/opencl/kernels/activations.cl +++ b/librapid/include/librapid/opencl/kernels/activations.cl @@ -1,43 +1,43 @@ #define SIGMOID_KERNEL(DTYPE) \ - __kernel void sigmoidActivationForward_##DTYPE(__global DTYPE *dst, \ - __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = 1 / (1 + exp(-data[gid])); \ - } \ + __kernel void sigmoidActivationForward_##DTYPE(__global DTYPE *dst, \ + __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = 1 / (1 + exp(-data[gid])); \ + } \ \ - __kernel void sigmoidActivationBackward_##DTYPE(__global DTYPE *dst, \ - __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = data[gid] * (1 - data[gid]); \ - } \ + __kernel void sigmoidActivationBackward_##DTYPE(__global DTYPE *dst, \ + __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = data[gid] * (1 - data[gid]); \ + } \ \ - struct _sigmoid_semicolon_forcer_##DTYPE {} + struct _sigmoid_semicolon_forcer_##DTYPE {} #define SIGMOID_KERNEL_CAST(DTYPE) \ - __kernel void sigmoidActivationForward_##DTYPE(__global DTYPE *dst, \ - __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = 1 / (1 + exp((float)-data[gid])); \ - } \ + __kernel void sigmoidActivationForward_##DTYPE(__global 
DTYPE *dst, \ + __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = 1 / (1 + exp((float)-data[gid])); \ + } \ \ - __kernel void sigmoidActivationBackward_##DTYPE(__global DTYPE *dst, \ - __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = data[gid] * (1 - data[gid]); \ - } \ + __kernel void sigmoidActivationBackward_##DTYPE(__global DTYPE *dst, \ + __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = data[gid] * (1 - data[gid]); \ + } \ \ - struct _sigmoid_semicolon_forcer_##DTYPE {} + struct _sigmoid_semicolon_forcer_##DTYPE {} #define KERNEL_IMPL(NAME) \ - NAME##_KERNEL_CAST(int8_t); \ - NAME##_KERNEL_CAST(int16_t); \ - NAME##_KERNEL_CAST(int32_t); \ - NAME##_KERNEL_CAST(int64_t); \ - NAME##_KERNEL_CAST(uint8_t); \ - NAME##_KERNEL_CAST(uint16_t); \ - NAME##_KERNEL_CAST(uint32_t); \ - NAME##_KERNEL_CAST(uint64_t); \ - NAME##_KERNEL(float); \ - NAME##_KERNEL(double); + NAME##_KERNEL_CAST(int8_t); \ + NAME##_KERNEL_CAST(int16_t); \ + NAME##_KERNEL_CAST(int32_t); \ + NAME##_KERNEL_CAST(int64_t); \ + NAME##_KERNEL_CAST(uint8_t); \ + NAME##_KERNEL_CAST(uint16_t); \ + NAME##_KERNEL_CAST(uint32_t); \ + NAME##_KERNEL_CAST(uint64_t); \ + NAME##_KERNEL(float); \ + NAME##_KERNEL(double); KERNEL_IMPL(SIGMOID) diff --git a/librapid/include/librapid/opencl/kernels/arithmetic.cl b/librapid/include/librapid/opencl/kernels/arithmetic.cl index 426671f2..88ca3021 100644 --- a/librapid/include/librapid/opencl/kernels/arithmetic.cl +++ b/librapid/include/librapid/opencl/kernels/arithmetic.cl @@ -1,33 +1,33 @@ #define ARITHMETIC_KERNEL(NAME, OP, DTYPE) \ - __kernel void NAME##Arrays_##DTYPE( \ - __global DTYPE *dst, __global const DTYPE *lhs, __global const DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid] = lhs[gid] OP rhs[gid]; \ - } \ + __kernel void NAME##Arrays_##DTYPE( \ + __global DTYPE *dst, __global const DTYPE *lhs, __global const DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid] = 
lhs[gid] OP rhs[gid]; \ + } \ \ - __kernel void NAME##ArraysScalarRhs_##DTYPE( \ - __global DTYPE *dst, __global const DTYPE *lhs, DTYPE rhs) { \ - int gid = get_global_id(0); \ - dst[gid] = lhs[gid] OP rhs; \ - } \ + __kernel void NAME##ArraysScalarRhs_##DTYPE( \ + __global DTYPE *dst, __global const DTYPE *lhs, DTYPE rhs) { \ + int gid = get_global_id(0); \ + dst[gid] = lhs[gid] OP rhs; \ + } \ \ - __kernel void NAME##ArraysScalarLhs_##DTYPE( \ - __global DTYPE *dst, DTYPE lhs, __global const DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid] = lhs OP rhs[gid]; \ - } + __kernel void NAME##ArraysScalarLhs_##DTYPE( \ + __global DTYPE *dst, DTYPE lhs, __global const DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid] = lhs OP rhs[gid]; \ + } #define ARITHMETIC_OP_IMPL(NAME, OP) \ - ARITHMETIC_KERNEL(NAME, OP, int8_t) \ - ARITHMETIC_KERNEL(NAME, OP, int16_t) \ - ARITHMETIC_KERNEL(NAME, OP, int32_t) \ - ARITHMETIC_KERNEL(NAME, OP, int64_t) \ - ARITHMETIC_KERNEL(NAME, OP, uint8_t) \ - ARITHMETIC_KERNEL(NAME, OP, uint16_t) \ - ARITHMETIC_KERNEL(NAME, OP, uint32_t) \ - ARITHMETIC_KERNEL(NAME, OP, uint64_t) \ - ARITHMETIC_KERNEL(NAME, OP, float) \ - ARITHMETIC_KERNEL(NAME, OP, double) + ARITHMETIC_KERNEL(NAME, OP, int8_t) \ + ARITHMETIC_KERNEL(NAME, OP, int16_t) \ + ARITHMETIC_KERNEL(NAME, OP, int32_t) \ + ARITHMETIC_KERNEL(NAME, OP, int64_t) \ + ARITHMETIC_KERNEL(NAME, OP, uint8_t) \ + ARITHMETIC_KERNEL(NAME, OP, uint16_t) \ + ARITHMETIC_KERNEL(NAME, OP, uint32_t) \ + ARITHMETIC_KERNEL(NAME, OP, uint64_t) \ + ARITHMETIC_KERNEL(NAME, OP, float) \ + ARITHMETIC_KERNEL(NAME, OP, double) ARITHMETIC_OP_IMPL(add, +) ARITHMETIC_OP_IMPL(sub, -) @@ -42,108 +42,108 @@ ARITHMETIC_OP_IMPL(elementWiseEqual, ==) ARITHMETIC_OP_IMPL(elementWiseNotEqual, !=) #define DUAL_ARITHMETIC_OP(DTYPE) \ - __kernel void addArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = 
get_global_id(0); \ - dst[gid].value = lhs[gid].value + rhs[gid].value; \ - dst[gid].derivative = lhs[gid].derivative + rhs[gid].derivative; \ - } \ + __kernel void addArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value + rhs[gid].value; \ + dst[gid].derivative = lhs[gid].derivative + rhs[gid].derivative; \ + } \ \ - __kernel void addArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - struct Dual_##DTYPE rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs[gid].value + rhs.value; \ - dst[gid].derivative = lhs[gid].derivative + rhs.derivative; \ - } \ + __kernel void addArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + struct Dual_##DTYPE rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value + rhs.value; \ + dst[gid].derivative = lhs[gid].derivative + rhs.derivative; \ + } \ \ - __kernel void addArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - struct Dual_##DTYPE lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs.value + rhs[gid].value; \ - dst[gid].derivative = lhs.derivative + rhs[gid].derivative; \ - } \ + __kernel void addArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + struct Dual_##DTYPE lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs.value + rhs[gid].value; \ + dst[gid].derivative = lhs.derivative + rhs[gid].derivative; \ + } \ \ - __kernel void subArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs[gid].value - rhs[gid].value; \ - 
dst[gid].derivative = lhs[gid].derivative - rhs[gid].derivative; \ - } \ + __kernel void subArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value - rhs[gid].value; \ + dst[gid].derivative = lhs[gid].derivative - rhs[gid].derivative; \ + } \ \ - __kernel void subArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - struct Dual_##DTYPE rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs[gid].value - rhs.value; \ - dst[gid].derivative = lhs[gid].derivative - rhs.derivative; \ - } \ + __kernel void subArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + struct Dual_##DTYPE rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value - rhs.value; \ + dst[gid].derivative = lhs[gid].derivative - rhs.derivative; \ + } \ \ - __kernel void subArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - struct Dual_##DTYPE lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs.value - rhs[gid].value; \ - dst[gid].derivative = lhs.derivative - rhs[gid].derivative; \ - } \ + __kernel void subArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + struct Dual_##DTYPE lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs.value - rhs[gid].value; \ + dst[gid].derivative = lhs.derivative - rhs[gid].derivative; \ + } \ \ - __kernel void mulArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs[gid].value * rhs[gid].value; \ - dst[gid].derivative = \ - lhs[gid].derivative * rhs[gid].value + lhs[gid].value * 
rhs[gid].derivative; \ - } \ + __kernel void mulArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value * rhs[gid].value; \ + dst[gid].derivative = \ + lhs[gid].derivative * rhs[gid].value + lhs[gid].value * rhs[gid].derivative; \ + } \ \ - __kernel void mulArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - struct Dual_##DTYPE rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs[gid].value * rhs.value; \ - dst[gid].derivative = lhs[gid].derivative * rhs.value + lhs[gid].value * rhs.derivative; \ - } \ + __kernel void mulArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + struct Dual_##DTYPE rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value * rhs.value; \ + dst[gid].derivative = lhs[gid].derivative * rhs.value + lhs[gid].value * rhs.derivative; \ + } \ \ - __kernel void mulArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - struct Dual_##DTYPE lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs.value * rhs[gid].value; \ - dst[gid].derivative = lhs.derivative * rhs[gid].value + lhs.value * rhs[gid].derivative; \ - } \ + __kernel void mulArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + struct Dual_##DTYPE lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs.value * rhs[gid].value; \ + dst[gid].derivative = lhs.derivative * rhs[gid].value + lhs.value * rhs[gid].derivative; \ + } \ \ - __kernel void divArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs[gid].value 
/ rhs[gid].value; \ - dst[gid].derivative = \ - (lhs[gid].derivative * rhs[gid].value - lhs[gid].value * rhs[gid].derivative) / \ - (rhs[gid].value * rhs[gid].value); \ - } \ + __kernel void divArrays_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value / rhs[gid].value; \ + dst[gid].derivative = \ + (lhs[gid].derivative * rhs[gid].value - lhs[gid].value * rhs[gid].derivative) / \ + (rhs[gid].value * rhs[gid].value); \ + } \ \ - __kernel void divArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - __global const struct Dual_##DTYPE *lhs, \ - struct Dual_##DTYPE rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs[gid].value / rhs.value; \ - dst[gid].derivative = \ - (lhs[gid].derivative * rhs.value - lhs[gid].value * rhs.derivative) / \ - (rhs.value * rhs.value); \ - } \ + __kernel void divArraysScalarRhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + __global const struct Dual_##DTYPE *lhs, \ + struct Dual_##DTYPE rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs[gid].value / rhs.value; \ + dst[gid].derivative = \ + (lhs[gid].derivative * rhs.value - lhs[gid].value * rhs.derivative) / \ + (rhs.value * rhs.value); \ + } \ \ - __kernel void divArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ - struct Dual_##DTYPE lhs, \ - __global const struct Dual_##DTYPE *rhs) { \ - int gid = get_global_id(0); \ - dst[gid].value = lhs.value / rhs[gid].value; \ - dst[gid].derivative = \ - (lhs.derivative * rhs[gid].value - lhs.value * rhs[gid].derivative) / \ - (rhs[gid].value * rhs[gid].value); \ - } + __kernel void divArraysScalarLhs_Dual_##DTYPE(__global struct Dual_##DTYPE *dst, \ + struct Dual_##DTYPE lhs, \ + __global const struct Dual_##DTYPE *rhs) { \ + int gid = get_global_id(0); \ + dst[gid].value = lhs.value / rhs[gid].value; \ + dst[gid].derivative = \ + 
(lhs.derivative * rhs[gid].value - lhs.value * rhs[gid].derivative) / \ + (rhs[gid].value * rhs[gid].value); \ + } DUAL_ARITHMETIC_OP(int8_t) DUAL_ARITHMETIC_OP(int16_t) diff --git a/librapid/include/librapid/opencl/kernels/dual.cl b/librapid/include/librapid/opencl/kernels/dual.cl index 0a46daa1..c0b12b09 100644 --- a/librapid/include/librapid/opencl/kernels/dual.cl +++ b/librapid/include/librapid/opencl/kernels/dual.cl @@ -2,10 +2,10 @@ #define LIBRAPID_OPENCL_DUAL #define DUAL_DEF(TYPE) \ - struct Dual_##TYPE { \ - TYPE value; \ - TYPE derivative; \ - }; + struct Dual_##TYPE { \ + TYPE value; \ + TYPE derivative; \ + }; DUAL_DEF(int8_t); DUAL_DEF(int16_t); diff --git a/librapid/include/librapid/opencl/kernels/expLogPow.cl b/librapid/include/librapid/opencl/kernels/expLogPow.cl index a93ec7b5..3bf38997 100644 --- a/librapid/include/librapid/opencl/kernels/expLogPow.cl +++ b/librapid/include/librapid/opencl/kernels/expLogPow.cl @@ -1,32 +1,32 @@ #define EXPLOGPOW_KERNEL(NAME, DTYPE) \ - __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = NAME(data[gid]); \ - } + __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = NAME(data[gid]); \ + } #define ABS_KERNEL(DTYPE) \ - __kernel void absArrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = (data[gid] >= 0) ? data[gid] : -data[gid]; \ - } + __kernel void absArrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = (data[gid] >= 0) ? 
data[gid] : -data[gid]; \ + } #define EXPLOGPOW_KERNEL_CAST(NAME, DTYPE) \ - __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = (DTYPE)NAME((double)data[gid]); \ - } + __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = (DTYPE)NAME((double)data[gid]); \ + } #define EXPLOGPOW_IMPL(NAME) \ - EXPLOGPOW_KERNEL_CAST(NAME, int8_t) \ - EXPLOGPOW_KERNEL_CAST(NAME, uint8_t) \ - EXPLOGPOW_KERNEL_CAST(NAME, int16_t) \ - EXPLOGPOW_KERNEL_CAST(NAME, uint16_t) \ - EXPLOGPOW_KERNEL_CAST(NAME, int32_t) \ - EXPLOGPOW_KERNEL_CAST(NAME, uint32_t) \ - EXPLOGPOW_KERNEL_CAST(NAME, int64_t) \ - EXPLOGPOW_KERNEL_CAST(NAME, uint64_t) \ - EXPLOGPOW_KERNEL(NAME, float) \ - EXPLOGPOW_KERNEL(NAME, double) + EXPLOGPOW_KERNEL_CAST(NAME, int8_t) \ + EXPLOGPOW_KERNEL_CAST(NAME, uint8_t) \ + EXPLOGPOW_KERNEL_CAST(NAME, int16_t) \ + EXPLOGPOW_KERNEL_CAST(NAME, uint16_t) \ + EXPLOGPOW_KERNEL_CAST(NAME, int32_t) \ + EXPLOGPOW_KERNEL_CAST(NAME, uint32_t) \ + EXPLOGPOW_KERNEL_CAST(NAME, int64_t) \ + EXPLOGPOW_KERNEL_CAST(NAME, uint64_t) \ + EXPLOGPOW_KERNEL(NAME, float) \ + EXPLOGPOW_KERNEL(NAME, double) EXPLOGPOW_IMPL(exp) EXPLOGPOW_IMPL(log) diff --git a/librapid/include/librapid/opencl/kernels/fill.cl b/librapid/include/librapid/opencl/kernels/fill.cl index 85a9178a..2a13e436 100644 --- a/librapid/include/librapid/opencl/kernels/fill.cl +++ b/librapid/include/librapid/opencl/kernels/fill.cl @@ -8,16 +8,16 @@ Implements Mersenne twister generator. M. Matsumoto, T. Nishimura, Mersenne twister: a 623-dimensionally equidistributed uniform pseudo-random number generator, ACM Transactions on Modeling and Computer Simulation (TOMACS) 8 (1) (1998) 3–30. 
- */ + */ #define RNG32 -#define MT19937_FLOAT_MULTI 2.3283064365386962890625e-10f +#define MT19937_FLOAT_MULTI 2.3283064365386962890625e-10f #define MT19937_DOUBLE2_MULTI 2.3283064365386962890625e-10 #define MT19937_DOUBLE_MULTI 5.4210108624275221700372640e-20 -#define MT19937_N 624 -#define MT19937_M 397 +#define MT19937_N 624 +#define MT19937_M 397 #define MT19937_MATRIX_A 0x9908b0df /* constant vector a */ #define MT19937_UPPER_MASK 0x80000000 /* most significant w-r bits */ #define MT19937_LOWER_MASK 0x7fffffff /* least significant r bits */ @@ -26,8 +26,8 @@ pseudo-random number generator, ACM Transactions on Modeling and Computer Simula State of MT19937 RNG. */ typedef struct { - uint mt[MT19937_N]; /* the array for the state vector */ - int mti; + uint mt[MT19937_N]; /* the array for the state vector */ + int mti; } mt19937_state; /** @@ -37,33 +37,33 @@ Generates a random 32-bit unsigned integer using MT19937 RNG. */ #define mt19937_uint(state) _mt19937_uint(&state) uint _mt19937_uint(mt19937_state *state) { - uint y; - uint mag01[2] = {0x0, MT19937_MATRIX_A}; - /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ - - if (state->mti < MT19937_N - MT19937_M) { - y = (state->mt[state->mti] & MT19937_UPPER_MASK) | - (state->mt[state->mti + 1] & MT19937_LOWER_MASK); - state->mt[state->mti] = state->mt[state->mti + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; - } else if (state->mti < MT19937_N - 1) { - y = (state->mt[state->mti] & MT19937_UPPER_MASK) | - (state->mt[state->mti + 1] & MT19937_LOWER_MASK); - state->mt[state->mti] = - state->mt[state->mti + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; - } else { - y = (state->mt[MT19937_N - 1] & MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); - state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; - state->mti = 0; - } - y = state->mt[state->mti++]; - - /* Tempering */ - y ^= (y >> 11); - y ^= (y << 7) & 0x9d2c5680; - y ^= (y << 15) & 0xefc60000; - y ^= (y >> 18); - - return y; 
+ uint y; + uint mag01[2] = {0x0, MT19937_MATRIX_A}; + /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ + + if (state->mti < MT19937_N - MT19937_M) { + y = (state->mt[state->mti] & MT19937_UPPER_MASK) | + (state->mt[state->mti + 1] & MT19937_LOWER_MASK); + state->mt[state->mti] = state->mt[state->mti + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; + } else if (state->mti < MT19937_N - 1) { + y = (state->mt[state->mti] & MT19937_UPPER_MASK) | + (state->mt[state->mti + 1] & MT19937_LOWER_MASK); + state->mt[state->mti] = + state->mt[state->mti + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; + } else { + y = (state->mt[MT19937_N - 1] & MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); + state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; + state->mti = 0; + } + y = state->mt[state->mti++]; + + /* Tempering */ + y ^= (y >> 11); + y ^= (y << 7) & 0x9d2c5680; + y ^= (y << 15) & 0xefc60000; + y ^= (y >> 18); + + return y; } /** Generates a random 32-bit unsigned integer using MT19937 RNG. 
@@ -74,36 +74,36 @@ This is alternative implementation of MT19937 RNG, that generates 32 values in s */ #define mt19937_loop_uint(state) _mt19937_loop_uint(&state) uint _mt19937_loop_uint(mt19937_state *state) { - uint y; - uint mag01[2] = {0x0, MT19937_MATRIX_A}; - /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ - - if (state->mti >= MT19937_N) { - int kk; - - for (kk = 0; kk < MT19937_N - MT19937_M; kk++) { - y = (state->mt[kk] & MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); - state->mt[kk] = state->mt[kk + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; - } - for (; kk < MT19937_N - 1; kk++) { - y = (state->mt[kk] & MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); - state->mt[kk] = state->mt[kk + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; - } - y = (state->mt[MT19937_N - 1] & MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); - state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; - - state->mti = 0; - } - - y = state->mt[state->mti++]; - - /* Tempering */ - y ^= (y >> 11); - y ^= (y << 7) & 0x9d2c5680; - y ^= (y << 15) & 0xefc60000; - y ^= (y >> 18); - - return y; + uint y; + uint mag01[2] = {0x0, MT19937_MATRIX_A}; + /* mag01[x] = x * MT19937_MATRIX_A for x=0,1 */ + + if (state->mti >= MT19937_N) { + int kk; + + for (kk = 0; kk < MT19937_N - MT19937_M; kk++) { + y = (state->mt[kk] & MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); + state->mt[kk] = state->mt[kk + MT19937_M] ^ (y >> 1) ^ mag01[y & 0x1]; + } + for (; kk < MT19937_N - 1; kk++) { + y = (state->mt[kk] & MT19937_UPPER_MASK) | (state->mt[kk + 1] & MT19937_LOWER_MASK); + state->mt[kk] = state->mt[kk + (MT19937_M - MT19937_N)] ^ (y >> 1) ^ mag01[y & 0x1]; + } + y = (state->mt[MT19937_N - 1] & MT19937_UPPER_MASK) | (state->mt[0] & MT19937_LOWER_MASK); + state->mt[MT19937_N - 1] = state->mt[MT19937_M - 1] ^ (y >> 1) ^ mag01[y & 0x1]; + + state->mti = 0; + } + + y = state->mt[state->mti++]; + + /* Tempering */ + y ^= (y 
>> 11); + y ^= (y << 7) & 0x9d2c5680; + y ^= (y << 15) & 0xefc60000; + y ^= (y >> 18); + + return y; } /** @@ -114,17 +114,17 @@ Seeds MT19937 RNG. (thread). */ void mt19937_seed(mt19937_state *state, uint s) { - state->mt[0] = s; - uint mti; - for (mti = 1; mti < MT19937_N; mti++) { - state->mt[mti] = 1812433253 * (state->mt[mti - 1] ^ (state->mt[mti - 1] >> 30)) + mti; - - /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ - /* In the previous versions, MSBs of the seed affect */ - /* only MSBs of the array mt19937[]. */ - /* 2002/01/09 modified by Makoto Matsumoto */ - } - state->mti = mti; + state->mt[0] = s; + uint mti; + for (mti = 1; mti < MT19937_N; mti++) { + state->mt[mti] = 1812433253 * (state->mt[mti - 1] ^ (state->mt[mti - 1] >> 30)) + mti; + + /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ + /* In the previous versions, MSBs of the seed affect */ + /* only MSBs of the array mt19937[]. */ + /* 2002/01/09 modified by Makoto Matsumoto */ + } + state->mti = mti; } /** @@ -156,24 +156,24 @@ Generates a random double using MT19937 RNG. 
Generated using only 32 random bits #define mt19937_double2(state) (mt19937_uint(state) * MT19937_DOUBLE2_MULTI) #define RANDOM_IMPL(TYPE) \ - __kernel void fillRandom_##TYPE(__global TYPE *data, \ - int64_t elements, \ - TYPE lower, \ - TYPE upper, \ - __global int64_t *seeds, \ - int64_t numSeeds) { \ - int64_t gid = get_global_id(0); \ - int64_t seedIndex = gid % numSeeds; \ - mt19937_state state; \ - mt19937_seed(&state, seeds[seedIndex]); \ + __kernel void fillRandom_##TYPE(__global TYPE *data, \ + int64_t elements, \ + TYPE lower, \ + TYPE upper, \ + __global int64_t *seeds, \ + int64_t numSeeds) { \ + int64_t gid = get_global_id(0); \ + int64_t seedIndex = gid % numSeeds; \ + mt19937_state state; \ + mt19937_seed(&state, seeds[seedIndex]); \ \ - for (int64_t i = gid; i < elements; i += get_global_size(0)) { \ - data[i] = (TYPE)(mt19937_double(state) * (upper - lower) + lower); \ - } \ + for (int64_t i = gid; i < elements; i += get_global_size(0)) { \ + data[i] = (TYPE)(mt19937_double(state) * (upper - lower) + lower); \ + } \ \ - /* Change the seed for the next iteration */ \ - seeds[seedIndex] = mt19937_ulong(state); \ - } + /* Change the seed for the next iteration */ \ + seeds[seedIndex] = mt19937_ulong(state); \ + } RANDOM_IMPL(int8_t) RANDOM_IMPL(uint8_t) diff --git a/librapid/include/librapid/opencl/kernels/floorCeilRound.cl b/librapid/include/librapid/opencl/kernels/floorCeilRound.cl index f8282bce..aa996b48 100644 --- a/librapid/include/librapid/opencl/kernels/floorCeilRound.cl +++ b/librapid/include/librapid/opencl/kernels/floorCeilRound.cl @@ -1,26 +1,26 @@ #define FLOORCEILROUND_KERNEL(NAME, DTYPE) \ - __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = NAME(data[gid]); \ - } + __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = NAME(data[gid]); \ + } #define 
FLOORCEILROUND_KERNEL_CAST(NAME, DTYPE) \ - __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = data[gid]; \ - } + __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = data[gid]; \ + } #define FLOORCEILROUND_IMPL(NAME) \ - FLOORCEILROUND_KERNEL_CAST(NAME, int8_t) \ - FLOORCEILROUND_KERNEL_CAST(NAME, uint8_t) \ - FLOORCEILROUND_KERNEL_CAST(NAME, int16_t) \ - FLOORCEILROUND_KERNEL_CAST(NAME, uint16_t) \ - FLOORCEILROUND_KERNEL_CAST(NAME, int32_t) \ - FLOORCEILROUND_KERNEL_CAST(NAME, uint32_t) \ - FLOORCEILROUND_KERNEL_CAST(NAME, int64_t) \ - FLOORCEILROUND_KERNEL_CAST(NAME, uint64_t) \ - FLOORCEILROUND_KERNEL(NAME, float) \ - FLOORCEILROUND_KERNEL(NAME, double) + FLOORCEILROUND_KERNEL_CAST(NAME, int8_t) \ + FLOORCEILROUND_KERNEL_CAST(NAME, uint8_t) \ + FLOORCEILROUND_KERNEL_CAST(NAME, int16_t) \ + FLOORCEILROUND_KERNEL_CAST(NAME, uint16_t) \ + FLOORCEILROUND_KERNEL_CAST(NAME, int32_t) \ + FLOORCEILROUND_KERNEL_CAST(NAME, uint32_t) \ + FLOORCEILROUND_KERNEL_CAST(NAME, int64_t) \ + FLOORCEILROUND_KERNEL_CAST(NAME, uint64_t) \ + FLOORCEILROUND_KERNEL(NAME, float) \ + FLOORCEILROUND_KERNEL(NAME, double) FLOORCEILROUND_IMPL(floor) FLOORCEILROUND_IMPL(ceil) diff --git a/librapid/include/librapid/opencl/kernels/negate.cl b/librapid/include/librapid/opencl/kernels/negate.cl index d7542ffe..5629c856 100644 --- a/librapid/include/librapid/opencl/kernels/negate.cl +++ b/librapid/include/librapid/opencl/kernels/negate.cl @@ -1,8 +1,8 @@ #define NEGATE_KERNEL(DTYPE) \ - __kernel void negateArrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = -data[gid]; \ - } + __kernel void negateArrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = -data[gid]; \ + } NEGATE_KERNEL(int8_t) NEGATE_KERNEL(uint8_t) diff 
--git a/librapid/include/librapid/opencl/kernels/transpose.cl b/librapid/include/librapid/opencl/kernels/transpose.cl index a228e85f..0fca66c5 100644 --- a/librapid/include/librapid/opencl/kernels/transpose.cl +++ b/librapid/include/librapid/opencl/kernels/transpose.cl @@ -2,26 +2,26 @@ #define TRANSPOSEY 3 #define TRANSPOSE_KERNEL_IMPL(DTYPE, TILE_DIM) \ - __kernel void transpose_##DTYPE(__global DTYPE *out, \ - __global const DTYPE *in, \ - const int rows, \ - const int cols, \ - DTYPE alpha) { \ - __local DTYPE tile[TILE_DIM][TILE_DIM + 1]; \ + __kernel void transpose_##DTYPE(__global DTYPE *out, \ + __global const DTYPE *in, \ + const int rows, \ + const int cols, \ + DTYPE alpha) { \ + __local DTYPE tile[TILE_DIM][TILE_DIM + 1]; \ \ - int x = get_group_id(0) * TILE_DIM + get_local_id(0); \ - int y = get_group_id(1) * TILE_DIM + get_local_id(1); \ + int x = get_group_id(0) * TILE_DIM + get_local_id(0); \ + int y = get_group_id(1) * TILE_DIM + get_local_id(1); \ \ - if (x < cols && y < rows) { tile[get_local_id(1)][get_local_id(0)] = in[y * cols + x]; } \ - barrier(CLK_LOCAL_MEM_FENCE); \ + if (x < cols && y < rows) { tile[get_local_id(1)][get_local_id(0)] = in[y * cols + x]; } \ + barrier(CLK_LOCAL_MEM_FENCE); \ \ - x = get_group_id(1) * TILE_DIM + get_local_id(0); \ - y = get_group_id(0) * TILE_DIM + get_local_id(1); \ + x = get_group_id(1) * TILE_DIM + get_local_id(0); \ + y = get_group_id(0) * TILE_DIM + get_local_id(1); \ \ - if (x < rows && y < cols) { \ - out[y * rows + x] = tile[get_local_id(0)][get_local_id(1)] * alpha; \ - } \ - } + if (x < rows && y < cols) { \ + out[y * rows + x] = tile[get_local_id(0)][get_local_id(1)] * alpha; \ + } \ + } TRANSPOSE_KERNEL_IMPL(int8_t, 16) TRANSPOSE_KERNEL_IMPL(int16_t, 16) diff --git a/librapid/include/librapid/opencl/kernels/trigonometry.cl b/librapid/include/librapid/opencl/kernels/trigonometry.cl index d59978b3..907daf9c 100644 --- a/librapid/include/librapid/opencl/kernels/trigonometry.cl +++ 
b/librapid/include/librapid/opencl/kernels/trigonometry.cl @@ -1,27 +1,27 @@ #define TRIGONOMETRY_KERNEL(NAME, DTYPE) \ - __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = NAME(data[gid]); \ - } + __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = NAME(data[gid]); \ + } #define TRIGONOMETRY_KERNEL_CAST(NAME, DTYPE) \ - __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ - int gid = get_global_id(0); \ - dst[gid] = (DTYPE)NAME((double)data[gid]); \ - } + __kernel void NAME##Arrays_##DTYPE(__global DTYPE *dst, __global const DTYPE *data) { \ + int gid = get_global_id(0); \ + dst[gid] = (DTYPE)NAME((double)data[gid]); \ + } #define TRIG_IMPL(NAME) \ - TRIGONOMETRY_KERNEL_CAST(NAME, int8_t) \ - TRIGONOMETRY_KERNEL_CAST(NAME, uint8_t) \ - TRIGONOMETRY_KERNEL_CAST(NAME, int16_t) \ - TRIGONOMETRY_KERNEL_CAST(NAME, uint16_t) \ - TRIGONOMETRY_KERNEL_CAST(NAME, int32_t) \ - TRIGONOMETRY_KERNEL_CAST(NAME, uint32_t) \ - TRIGONOMETRY_KERNEL_CAST(NAME, int64_t) \ - TRIGONOMETRY_KERNEL_CAST(NAME, uint64_t) \ - TRIGONOMETRY_KERNEL(NAME, float) \ - TRIGONOMETRY_KERNEL(NAME, double) + TRIGONOMETRY_KERNEL_CAST(NAME, int8_t) \ + TRIGONOMETRY_KERNEL_CAST(NAME, uint8_t) \ + TRIGONOMETRY_KERNEL_CAST(NAME, int16_t) \ + TRIGONOMETRY_KERNEL_CAST(NAME, uint16_t) \ + TRIGONOMETRY_KERNEL_CAST(NAME, int32_t) \ + TRIGONOMETRY_KERNEL_CAST(NAME, uint32_t) \ + TRIGONOMETRY_KERNEL_CAST(NAME, int64_t) \ + TRIGONOMETRY_KERNEL_CAST(NAME, uint64_t) \ + TRIGONOMETRY_KERNEL(NAME, float) \ + TRIGONOMETRY_KERNEL(NAME, double) TRIG_IMPL(sin) TRIG_IMPL(cos) diff --git a/librapid/include/librapid/opencl/openclConfigure.hpp b/librapid/include/librapid/opencl/openclConfigure.hpp index 6475049c..b7262cc9 100644 --- a/librapid/include/librapid/opencl/openclConfigure.hpp +++ 
b/librapid/include/librapid/opencl/openclConfigure.hpp @@ -3,13 +3,13 @@ namespace librapid { #if defined(LIBRAPID_HAS_OPENCL) - int64_t openclDeviceCompute(const cl::Device &device); - void updateOpenCLDevices(bool verbose = false); - cl::Device findFastestDevice(const std::vector &devices); - void addOpenCLKernelSource(const std::string &source); - void addOpenCLKernelFile(const std::string &filename); - void compileOpenCLKernels(bool verbose = false); - void configureOpenCL(bool verbose = false, bool ask = false); + int64_t openclDeviceCompute(const cl::Device &device); + void updateOpenCLDevices(bool verbose = false); + cl::Device findFastestDevice(const std::vector &devices); + void addOpenCLKernelSource(const std::string &source); + void addOpenCLKernelFile(const std::string &filename); + void compileOpenCLKernels(bool verbose = false); + void configureOpenCL(bool verbose = false, bool ask = false); #endif // LIBRAPID_HAS_OPENCL } // namespace librapid diff --git a/librapid/include/librapid/opencl/openclErrorIdentifier.hpp b/librapid/include/librapid/opencl/openclErrorIdentifier.hpp index e672e27e..47562dec 100644 --- a/librapid/include/librapid/opencl/openclErrorIdentifier.hpp +++ b/librapid/include/librapid/opencl/openclErrorIdentifier.hpp @@ -3,8 +3,8 @@ namespace librapid::opencl { #if defined(LIBRAPID_HAS_OPENCL) - std::string getOpenCLErrorString(int64_t err); - std::string getCLBlastErrorString(clblast::StatusCode err); + std::string getOpenCLErrorString(int64_t err); + std::string getCLBlastErrorString(clblast::StatusCode err); #endif // LIBRAPID_HAS_OPENCL } // namespace librapid::opencl diff --git a/librapid/include/librapid/opencl/openclKernelProcessor.hpp b/librapid/include/librapid/opencl/openclKernelProcessor.hpp index 2c77b08c..93cf9447 100644 --- a/librapid/include/librapid/opencl/openclKernelProcessor.hpp +++ b/librapid/include/librapid/opencl/openclKernelProcessor.hpp @@ -4,43 +4,43 @@ #if defined(LIBRAPID_HAS_OPENCL) namespace 
librapid::opencl { - template - void setKernelArgs(cl::Kernel &kernel, const std::tuple &args, - std::index_sequence) { - constexpr auto caster = [](auto &&x) { - using T = std::decay_t; - if constexpr (noCast || std::is_same_v) { - return x; - } else if constexpr (typetraits::TypeInfo::type == detail::LibRapidType::Scalar) { - return static_cast(x); - } else { - return x; - } - }; - - ((kernel.setArg(I, caster(std::get(args)))), ...); - } - - template - void runLinearKernel(const std::string &kernelName, int64_t numElements, Args... args) { - static_assert(sizeof(Scalar) > 2, - "Scalar type must be larger than 2 bytes. Please create an issue on GitHub " - "if you need support for smaller types."); - - std::string kernelNameFull = kernelName + "_" + typetraits::TypeInfo::name; - cl::Kernel kernel(global::openCLProgram, kernelNameFull.c_str()); - setKernelArgs( - kernel, std::make_tuple(args...), std::make_index_sequence()); - - cl::NDRange range(numElements); - auto err = - global::openCLQueue.enqueueNDRangeKernel(kernel, cl::NullRange, range, cl::NullRange); - - LIBRAPID_ASSERT(err == CL_SUCCESS, - "OpenCL kernel execution failed with error code {}: {}", - err, - getOpenCLErrorString(err)); - } + template + void setKernelArgs(cl::Kernel &kernel, const std::tuple &args, + std::index_sequence) { + constexpr auto caster = [](auto &&x) { + using T = std::decay_t; + if constexpr (noCast || std::is_same_v) { + return x; + } else if constexpr (typetraits::TypeInfo::type == detail::LibRapidType::Scalar) { + return static_cast(x); + } else { + return x; + } + }; + + ((kernel.setArg(I, caster(std::get(args)))), ...); + } + + template + void runLinearKernel(const std::string &kernelName, int64_t numElements, Args... args) { + static_assert(sizeof(Scalar) > 2, + "Scalar type must be larger than 2 bytes. 
Please create an issue on GitHub " + "if you need support for smaller types."); + + std::string kernelNameFull = kernelName + "_" + typetraits::TypeInfo::name; + cl::Kernel kernel(global::openCLProgram, kernelNameFull.c_str()); + setKernelArgs( + kernel, std::make_tuple(args...), std::make_index_sequence()); + + cl::NDRange range(numElements); + auto err = + global::openCLQueue.enqueueNDRangeKernel(kernel, cl::NullRange, range, cl::NullRange); + + LIBRAPID_ASSERT(err == CL_SUCCESS, + "OpenCL kernel execution failed with error code {}: {}", + err, + getOpenCLErrorString(err)); + } } // namespace librapid::opencl #endif // LIBRAPID_HAS_OPENCL diff --git a/librapid/include/librapid/opencl/openclStorage.hpp b/librapid/include/librapid/opencl/openclStorage.hpp index fd5d909c..7305766e 100644 --- a/librapid/include/librapid/opencl/openclStorage.hpp +++ b/librapid/include/librapid/opencl/openclStorage.hpp @@ -10,419 +10,419 @@ #if defined(LIBRAPID_HAS_OPENCL) -# define LIBRAPID_CHECK_OPENCL \ - LIBRAPID_ASSERT(global::openCLConfigured, \ - "OpenCL has not been configured. Please call configureOpenCL() before " \ - "creating any Arrays with the OpenCL backend.") +# define LIBRAPID_CHECK_OPENCL \ + LIBRAPID_ASSERT(global::openCLConfigured, \ + "OpenCL has not been configured. 
Please call configureOpenCL() before " \ + "creating any Arrays with the OpenCL backend.") namespace librapid { - namespace typetraits { - template - struct TypeInfo> { - static constexpr bool isLibRapidType = true; - using Scalar = Scalar_; - using Backend = backend::OpenCL; - }; - - template - struct IsOpenCLStorage : std::false_type {}; - - template - struct IsOpenCLStorage> : std::true_type {}; - - LIBRAPID_DEFINE_AS_TYPE(typename Scalar_, OpenCLStorage); - } // namespace typetraits - - namespace detail { -# define OPENCL_REF_OPERATOR(OP) \ - template \ - auto operator OP(const OpenCLRef &lhs, const RHS &rhs) { \ - return lhs.get() OP rhs; \ - } \ + namespace typetraits { + template + struct TypeInfo> { + static constexpr bool isLibRapidType = true; + using Scalar = Scalar_; + using Backend = backend::OpenCL; + }; + + template + struct IsOpenCLStorage : std::false_type {}; + + template + struct IsOpenCLStorage> : std::true_type {}; + + LIBRAPID_DEFINE_AS_TYPE(typename Scalar_, OpenCLStorage); + } // namespace typetraits + + namespace detail { +# define OPENCL_REF_OPERATOR(OP) \ + template \ + auto operator OP(const OpenCLRef &lhs, const RHS &rhs) { \ + return lhs.get() OP rhs; \ + } \ \ - template \ - auto operator OP(const LHS &lhs, const OpenCLRef &rhs) { \ - return lhs OP rhs.get(); \ - } \ + template \ + auto operator OP(const LHS &lhs, const OpenCLRef &rhs) { \ + return lhs OP rhs.get(); \ + } \ \ - template \ - auto operator OP(const OpenCLRef &lhs, const OpenCLRef &rhs) { \ - return lhs.get() OP rhs.get(); \ - } \ + template \ + auto operator OP(const OpenCLRef &lhs, const OpenCLRef &rhs) { \ + return lhs.get() OP rhs.get(); \ + } \ \ - template \ - auto operator OP##=(OpenCLRef &lhs, const RHS &rhs) { \ - lhs = lhs.get() OP rhs; \ - } \ + template \ + auto operator OP##=(OpenCLRef &lhs, const RHS &rhs) { \ + lhs = lhs.get() OP rhs; \ + } \ \ - template \ - auto operator OP##=(OpenCLRef &lhs, const OpenCLRef &rhs) { \ - lhs = lhs.get() OP rhs.get(); \ - 
} - -# define OPENCL_REF_OPERATOR_NO_ASSIGN(OP) \ - template \ - auto operator OP(const OpenCLRef &lhs, const RHS &rhs) { \ - return lhs.get() OP rhs; \ - } \ + template \ + auto operator OP##=(OpenCLRef &lhs, const OpenCLRef &rhs) { \ + lhs = lhs.get() OP rhs.get(); \ + } + +# define OPENCL_REF_OPERATOR_NO_ASSIGN(OP) \ + template \ + auto operator OP(const OpenCLRef &lhs, const RHS &rhs) { \ + return lhs.get() OP rhs; \ + } \ \ - template \ - auto operator OP(const LHS &lhs, const OpenCLRef &rhs) { \ - return lhs OP rhs.get(); \ - } \ + template \ + auto operator OP(const LHS &lhs, const OpenCLRef &rhs) { \ + return lhs OP rhs.get(); \ + } \ \ - template \ - auto operator OP(const OpenCLRef &lhs, const OpenCLRef &rhs) { \ - return lhs.get() OP rhs.get(); \ - } - - template - class OpenCLRef { - public: - OpenCLRef(const cl::Buffer &buffer, size_t offset) : - m_buffer(buffer), m_offset(offset) {} - - LIBRAPID_ALWAYS_INLINE OpenCLRef &operator=(const T &val) { - global::openCLQueue.enqueueWriteBuffer( - m_buffer, CL_TRUE, m_offset * sizeof(T), sizeof(T), &val); - return *this; - } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T get() const { - T tmp; - global::openCLQueue.enqueueReadBuffer( - m_buffer, CL_TRUE, m_offset * sizeof(T), sizeof(T), &tmp); - global::openCLQueue.finish(); - return tmp; - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator CAST() const { - return static_cast(get()); - } - - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const { - return fmt::format(format, get()); - } - - private: - cl::Buffer m_buffer; - size_t m_offset; - }; - - OPENCL_REF_OPERATOR(+) - OPENCL_REF_OPERATOR(-) - OPENCL_REF_OPERATOR(*) - OPENCL_REF_OPERATOR(/) - OPENCL_REF_OPERATOR(%) - OPENCL_REF_OPERATOR(^) - OPENCL_REF_OPERATOR(&) - OPENCL_REF_OPERATOR(|) - OPENCL_REF_OPERATOR(<<) - OPENCL_REF_OPERATOR(>>) - OPENCL_REF_OPERATOR_NO_ASSIGN(==) - OPENCL_REF_OPERATOR_NO_ASSIGN(!=) - OPENCL_REF_OPERATOR_NO_ASSIGN(<) - 
OPENCL_REF_OPERATOR_NO_ASSIGN(>) - OPENCL_REF_OPERATOR_NO_ASSIGN(<=) - OPENCL_REF_OPERATOR_NO_ASSIGN(>=) - } // namespace detail - - template - class OpenCLStorage { - public: - using Scalar = Scalar_; - using Pointer = Scalar *; - using ConstPointer = const Scalar *; - using Reference = Scalar &; - using ConstReference = const Scalar &; - using SizeType = size_t; - using DifferenceType = ptrdiff_t; - static constexpr cl_int bufferFlags = CL_MEM_READ_WRITE; - - /// \brief Default constructor - OpenCLStorage() = default; - - /// \brief Construct an OpenCLStorage with the given size. The data is not initialised. - /// \param size The size of the OpenCLStorage - LIBRAPID_ALWAYS_INLINE explicit OpenCLStorage(SizeType size); - - /// \brief Construct an OpenCLStorage with the given size and initialise it with the given - /// value. - /// \param size The size of the OpenCLStorage - /// \param value The value to initialise the OpenCLStorage with - LIBRAPID_ALWAYS_INLINE OpenCLStorage(SizeType size, Scalar value); - - LIBRAPID_ALWAYS_INLINE OpenCLStorage(const cl::Buffer &buffer, SizeType size, - bool ownsData); - - /// \brief Construct an OpenCLStorage from another instance - /// \param other The other instance - LIBRAPID_ALWAYS_INLINE OpenCLStorage(const OpenCLStorage &other); - - /// \brief Move-construct an OpenCLStorage from another instance - /// \param other The other instance - LIBRAPID_ALWAYS_INLINE OpenCLStorage(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT; - - /// \brief Initialize an OpenCLStorage instance from an initializer-list - /// \param list Values to populate with - LIBRAPID_ALWAYS_INLINE OpenCLStorage(std::initializer_list list); - - /// \brief Initialize an OpenCLStorage instance from a vector - /// \param vec Values to populate with - LIBRAPID_ALWAYS_INLINE explicit OpenCLStorage(const std::vector &vec); - - LIBRAPID_ALWAYS_INLINE OpenCLStorage &operator=(const OpenCLStorage &other); - - LIBRAPID_ALWAYS_INLINE OpenCLStorage & - 
operator=(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT; - - void set(const OpenCLStorage &other); - - OpenCLStorage copy() const; - - template - static ShapeType defaultShape(); - - template - static OpenCLStorage fromData(const std::initializer_list &list); - - template - static OpenCLStorage fromData(const std::vector &vec); - - ~OpenCLStorage(); - - /// Resize a CudaStorage object to \p size elements. Existing elements are preserved where - /// possible. - /// \param size Number of elements - /// \see resize(SizeType, int) - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); - - /// Resize a CudaStorage object to \p size elements. Existing elements are not preserved. - /// This method of resizing is faster and more efficient than the version which preserves - /// the original data, but of course, this has the drawback that data will be lost. - /// \param size Number of elements - LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, SizeType value); - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::OpenCLRef - operator[](SizeType index); - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const detail::OpenCLRef - operator[](SizeType index) const; - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const cl::Buffer &data() const; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE cl::Buffer &data(); - - private: - LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); - LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize, int); - - SizeType m_size; - cl::Buffer m_buffer; - bool m_ownsData = true; - }; - - template - OpenCLStorage::OpenCLStorage(SizeType size) : - m_size(size), m_buffer(global::openCLContext, bufferFlags, size * sizeof(Scalar)), - m_ownsData(true) { - LIBRAPID_CHECK_OPENCL; - } - - template - OpenCLStorage::OpenCLStorage(SizeType size, Scalar value) : - m_size(size), m_buffer(global::openCLContext, bufferFlags, size * sizeof(Scalar)) { - LIBRAPID_CHECK_OPENCL; - 
global::openCLQueue.enqueueFillBuffer(m_buffer, value, 0, size * sizeof(Scalar)); - } - - template - OpenCLStorage::OpenCLStorage(const cl::Buffer &buffer, SizeType size, bool ownsData) : - m_size(size), m_buffer(buffer), m_ownsData(ownsData) { - LIBRAPID_CHECK_OPENCL; - } - - template - OpenCLStorage::OpenCLStorage(const OpenCLStorage &other) : - m_size(other.m_size), m_buffer(other.m_buffer), m_ownsData(true) { - LIBRAPID_CHECK_OPENCL; - } - - template - OpenCLStorage::OpenCLStorage(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT - : m_size(std::move(other.m_size)), - m_buffer(std::move(other.m_buffer)), - m_ownsData(other.m_ownsData) { - LIBRAPID_CHECK_OPENCL; - other.m_size = 0; - other.m_ownsData = false; - } - - template - OpenCLStorage::OpenCLStorage(std::initializer_list list) : - m_size(list.size()), - m_buffer(global::openCLContext, bufferFlags, list.size() * sizeof(Scalar)), - m_ownsData(true) { - LIBRAPID_CHECK_OPENCL; - global::openCLQueue.enqueueWriteBuffer( - m_buffer, CL_TRUE, 0, m_size * sizeof(Scalar), list.begin()); - } - - template - OpenCLStorage::OpenCLStorage(const std::vector &vec) : - m_size(vec.size()), - m_buffer(global::openCLContext, bufferFlags, m_size * sizeof(Scalar)), - m_ownsData(true) { - LIBRAPID_CHECK_OPENCL; - global::openCLQueue.enqueueWriteBuffer( - m_buffer, CL_TRUE, 0, m_size * sizeof(Scalar), vec.data()); - } - - template - OpenCLStorage &OpenCLStorage::operator=(const OpenCLStorage &other) { - LIBRAPID_CHECK_OPENCL; - if (this != &other) { - if (m_ownsData) { - m_buffer = other.m_buffer; - m_size = other.m_size; - } else { - LIBRAPID_ASSERT(m_size == other.m_size, - "Cannot copy storage with {} elements to dependent storage with " - "{} elements", - other.m_size, - m_size); - - global::openCLQueue.enqueueCopyBuffer( - other.m_buffer, m_buffer, 0, 0, m_size * sizeof(Scalar)); - } - } - return *this; - } - - template - OpenCLStorage & - OpenCLStorage::operator=(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT { - 
LIBRAPID_CHECK_OPENCL; - if (this != &other) { - if (m_ownsData) { - std::swap(m_buffer, other.m_buffer); - std::swap(m_size, other.m_size); - std::swap(m_ownsData, other.m_ownsData); - } else { - LIBRAPID_ASSERT(m_size == other.m_size, - "Cannot move into dependent OpenCLStorage " - "with different size"); - global::openCLQueue.enqueueCopyBuffer( - other.m_buffer, m_buffer, 0, 0, m_size * sizeof(Scalar)); - } - } - return *this; - } - - template - void OpenCLStorage::set(const OpenCLStorage &other) { - LIBRAPID_CHECK_OPENCL; - m_buffer = other.m_buffer; - m_size = other.m_size; - m_ownsData = other.m_ownsData; - } - - template - auto OpenCLStorage::copy() const -> OpenCLStorage { - LIBRAPID_CHECK_OPENCL; - OpenCLStorage result(m_size); - global::openCLQueue.enqueueCopyBuffer( - m_buffer, result.m_buffer, 0, 0, m_size * sizeof(Scalar)); - return result; - } - - template - template - ShapeType OpenCLStorage::defaultShape() { - return ShapeType({0}); - } - - template - template - OpenCLStorage OpenCLStorage::fromData(const std::initializer_list &list) { - return OpenCLStorage(list); - } - - template - template - OpenCLStorage OpenCLStorage::fromData(const std::vector &vec) { - return OpenCLStorage(vec); - } - - template - OpenCLStorage::~OpenCLStorage() { - // cl::Buffer is reference counted, so we do not need to worry about whether the array - // owns the buffer or not. If it does, the buffer will be deleted when the array is - // destroyed. If it does not, the buffer will be deleted when the last array referencing - // it is destroyed. 
- } - - template - void OpenCLStorage::resize(SizeType newSize) { - resizeImpl(newSize); - } - - template - void OpenCLStorage::resize(SizeType newSize, SizeType value) { - resizeImpl(newSize, 0); - } - - template - void OpenCLStorage::resizeImpl(SizeType newSize) { - if (newSize == m_size) return; - m_size = newSize; - cl::Buffer newBuffer(global::openCLContext, bufferFlags, newSize * sizeof(Scalar)); - global::openCLQueue.enqueueCopyBuffer(m_buffer, newBuffer, 0, 0, m_size * sizeof(Scalar)); - m_buffer = std::move(newBuffer); - } - - template - void OpenCLStorage::resizeImpl(SizeType newSize, int) { - if (newSize == m_size) return; - m_size = newSize; - m_buffer = cl::Buffer(global::openCLContext, bufferFlags, newSize * sizeof(Scalar)); - } - - template - auto OpenCLStorage::size() const -> SizeType { - return m_size; - } - - template - auto OpenCLStorage::operator[](SizeType index) const - -> const detail::OpenCLRef { - LIBRAPID_ASSERT(index >= 0 && index < m_size, - "Index {} is out of range for OpenCLStorage with {} elements", - index, - m_size); - return detail::OpenCLRef(m_buffer, index); - } - - template - auto OpenCLStorage::operator[](SizeType index) -> detail::OpenCLRef { - LIBRAPID_ASSERT(index >= 0 && index < m_size, - "Index {} is out of range for OpenCLStorage with {} elements", - index, - m_size); - return detail::OpenCLRef(m_buffer, index); - } - - template - auto OpenCLStorage::data() const -> const cl::Buffer & { - return m_buffer; - } - - template - auto OpenCLStorage::data() -> cl::Buffer & { - return m_buffer; - } + template \ + auto operator OP(const OpenCLRef &lhs, const OpenCLRef &rhs) { \ + return lhs.get() OP rhs.get(); \ + } + + template + class OpenCLRef { + public: + OpenCLRef(const cl::Buffer &buffer, size_t offset) : + m_buffer(buffer), m_offset(offset) {} + + LIBRAPID_ALWAYS_INLINE OpenCLRef &operator=(const T &val) { + global::openCLQueue.enqueueWriteBuffer( + m_buffer, CL_TRUE, m_offset * sizeof(T), sizeof(T), &val); + return 
*this; + } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T get() const { + T tmp; + global::openCLQueue.enqueueReadBuffer( + m_buffer, CL_TRUE, m_offset * sizeof(T), sizeof(T), &tmp); + global::openCLQueue.finish(); + return tmp; + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE operator CAST() const { + return static_cast(get()); + } + + LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const { + return fmt::format(format, get()); + } + + private: + cl::Buffer m_buffer; + size_t m_offset; + }; + + OPENCL_REF_OPERATOR(+) + OPENCL_REF_OPERATOR(-) + OPENCL_REF_OPERATOR(*) + OPENCL_REF_OPERATOR(/) + OPENCL_REF_OPERATOR(%) + OPENCL_REF_OPERATOR(^) + OPENCL_REF_OPERATOR(&) + OPENCL_REF_OPERATOR(|) + OPENCL_REF_OPERATOR(<<) + OPENCL_REF_OPERATOR(>>) + OPENCL_REF_OPERATOR_NO_ASSIGN(==) + OPENCL_REF_OPERATOR_NO_ASSIGN(!=) + OPENCL_REF_OPERATOR_NO_ASSIGN(<) + OPENCL_REF_OPERATOR_NO_ASSIGN(>) + OPENCL_REF_OPERATOR_NO_ASSIGN(<=) + OPENCL_REF_OPERATOR_NO_ASSIGN(>=) + } // namespace detail + + template + class OpenCLStorage { + public: + using Scalar = Scalar_; + using Pointer = Scalar *; + using ConstPointer = const Scalar *; + using Reference = Scalar &; + using ConstReference = const Scalar &; + using SizeType = size_t; + using DifferenceType = ptrdiff_t; + static constexpr cl_int bufferFlags = CL_MEM_READ_WRITE; + + /// \brief Default constructor + OpenCLStorage() = default; + + /// \brief Construct an OpenCLStorage with the given size. The data is not initialised. + /// \param size The size of the OpenCLStorage + LIBRAPID_ALWAYS_INLINE explicit OpenCLStorage(SizeType size); + + /// \brief Construct an OpenCLStorage with the given size and initialise it with the given + /// value. 
+ /// \param size The size of the OpenCLStorage + /// \param value The value to initialise the OpenCLStorage with + LIBRAPID_ALWAYS_INLINE OpenCLStorage(SizeType size, Scalar value); + + LIBRAPID_ALWAYS_INLINE OpenCLStorage(const cl::Buffer &buffer, SizeType size, + bool ownsData); + + /// \brief Construct an OpenCLStorage from another instance + /// \param other The other instance + LIBRAPID_ALWAYS_INLINE OpenCLStorage(const OpenCLStorage &other); + + /// \brief Move-construct an OpenCLStorage from another instance + /// \param other The other instance + LIBRAPID_ALWAYS_INLINE OpenCLStorage(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT; + + /// \brief Initialize an OpenCLStorage instance from an initializer-list + /// \param list Values to populate with + LIBRAPID_ALWAYS_INLINE OpenCLStorage(std::initializer_list list); + + /// \brief Initialize an OpenCLStorage instance from a vector + /// \param vec Values to populate with + LIBRAPID_ALWAYS_INLINE explicit OpenCLStorage(const std::vector &vec); + + LIBRAPID_ALWAYS_INLINE OpenCLStorage &operator=(const OpenCLStorage &other); + + LIBRAPID_ALWAYS_INLINE OpenCLStorage & + operator=(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT; + + void set(const OpenCLStorage &other); + + OpenCLStorage copy() const; + + template + static ShapeType defaultShape(); + + template + static OpenCLStorage fromData(const std::initializer_list &list); + + template + static OpenCLStorage fromData(const std::vector &vec); + + ~OpenCLStorage(); + + /// Resize a CudaStorage object to \p size elements. Existing elements are preserved where + /// possible. + /// \param size Number of elements + /// \see resize(SizeType, int) + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize); + + /// Resize a CudaStorage object to \p size elements. Existing elements are not preserved. 
+ /// This method of resizing is faster and more efficient than the version which preserves + /// the original data, but of course, this has the drawback that data will be lost. + /// \param size Number of elements + LIBRAPID_ALWAYS_INLINE void resize(SizeType newSize, SizeType value); + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE SizeType size() const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::OpenCLRef + operator[](SizeType index); + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const detail::OpenCLRef + operator[](SizeType index) const; + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE const cl::Buffer &data() const; + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE cl::Buffer &data(); + + private: + LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize); + LIBRAPID_ALWAYS_INLINE void resizeImpl(SizeType newSize, int); + + SizeType m_size; + cl::Buffer m_buffer; + bool m_ownsData = true; + }; + + template + OpenCLStorage::OpenCLStorage(SizeType size) : + m_size(size), m_buffer(global::openCLContext, bufferFlags, size * sizeof(Scalar)), + m_ownsData(true) { + LIBRAPID_CHECK_OPENCL; + } + + template + OpenCLStorage::OpenCLStorage(SizeType size, Scalar value) : + m_size(size), m_buffer(global::openCLContext, bufferFlags, size * sizeof(Scalar)) { + LIBRAPID_CHECK_OPENCL; + global::openCLQueue.enqueueFillBuffer(m_buffer, value, 0, size * sizeof(Scalar)); + } + + template + OpenCLStorage::OpenCLStorage(const cl::Buffer &buffer, SizeType size, bool ownsData) : + m_size(size), m_buffer(buffer), m_ownsData(ownsData) { + LIBRAPID_CHECK_OPENCL; + } + + template + OpenCLStorage::OpenCLStorage(const OpenCLStorage &other) : + m_size(other.m_size), m_buffer(other.m_buffer), m_ownsData(true) { + LIBRAPID_CHECK_OPENCL; + } + + template + OpenCLStorage::OpenCLStorage(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT + : m_size(std::move(other.m_size)), + m_buffer(std::move(other.m_buffer)), + m_ownsData(other.m_ownsData) { + LIBRAPID_CHECK_OPENCL; + other.m_size = 0; + 
other.m_ownsData = false; + } + + template + OpenCLStorage::OpenCLStorage(std::initializer_list list) : + m_size(list.size()), + m_buffer(global::openCLContext, bufferFlags, list.size() * sizeof(Scalar)), + m_ownsData(true) { + LIBRAPID_CHECK_OPENCL; + global::openCLQueue.enqueueWriteBuffer( + m_buffer, CL_TRUE, 0, m_size * sizeof(Scalar), list.begin()); + } + + template + OpenCLStorage::OpenCLStorage(const std::vector &vec) : + m_size(vec.size()), + m_buffer(global::openCLContext, bufferFlags, m_size * sizeof(Scalar)), + m_ownsData(true) { + LIBRAPID_CHECK_OPENCL; + global::openCLQueue.enqueueWriteBuffer( + m_buffer, CL_TRUE, 0, m_size * sizeof(Scalar), vec.data()); + } + + template + OpenCLStorage &OpenCLStorage::operator=(const OpenCLStorage &other) { + LIBRAPID_CHECK_OPENCL; + if (this != &other) { + if (m_ownsData) { + m_buffer = other.m_buffer; + m_size = other.m_size; + } else { + LIBRAPID_ASSERT(m_size == other.m_size, + "Cannot copy storage with {} elements to dependent storage with " + "{} elements", + other.m_size, + m_size); + + global::openCLQueue.enqueueCopyBuffer( + other.m_buffer, m_buffer, 0, 0, m_size * sizeof(Scalar)); + } + } + return *this; + } + + template + OpenCLStorage & + OpenCLStorage::operator=(OpenCLStorage &&other) LIBRAPID_RELEASE_NOEXCEPT { + LIBRAPID_CHECK_OPENCL; + if (this != &other) { + if (m_ownsData) { + std::swap(m_buffer, other.m_buffer); + std::swap(m_size, other.m_size); + std::swap(m_ownsData, other.m_ownsData); + } else { + LIBRAPID_ASSERT(m_size == other.m_size, + "Cannot move into dependent OpenCLStorage " + "with different size"); + global::openCLQueue.enqueueCopyBuffer( + other.m_buffer, m_buffer, 0, 0, m_size * sizeof(Scalar)); + } + } + return *this; + } + + template + void OpenCLStorage::set(const OpenCLStorage &other) { + LIBRAPID_CHECK_OPENCL; + m_buffer = other.m_buffer; + m_size = other.m_size; + m_ownsData = other.m_ownsData; + } + + template + auto OpenCLStorage::copy() const -> OpenCLStorage { + 
LIBRAPID_CHECK_OPENCL; + OpenCLStorage result(m_size); + global::openCLQueue.enqueueCopyBuffer( + m_buffer, result.m_buffer, 0, 0, m_size * sizeof(Scalar)); + return result; + } + + template + template + ShapeType OpenCLStorage::defaultShape() { + return ShapeType({0}); + } + + template + template + OpenCLStorage OpenCLStorage::fromData(const std::initializer_list &list) { + return OpenCLStorage(list); + } + + template + template + OpenCLStorage OpenCLStorage::fromData(const std::vector &vec) { + return OpenCLStorage(vec); + } + + template + OpenCLStorage::~OpenCLStorage() { + // cl::Buffer is reference counted, so we do not need to worry about whether the array + // owns the buffer or not. If it does, the buffer will be deleted when the array is + // destroyed. If it does not, the buffer will be deleted when the last array referencing + // it is destroyed. + } + + template + void OpenCLStorage::resize(SizeType newSize) { + resizeImpl(newSize); + } + + template + void OpenCLStorage::resize(SizeType newSize, SizeType value) { + resizeImpl(newSize, 0); + } + + template + void OpenCLStorage::resizeImpl(SizeType newSize) { + if (newSize == m_size) return; + m_size = newSize; + cl::Buffer newBuffer(global::openCLContext, bufferFlags, newSize * sizeof(Scalar)); + global::openCLQueue.enqueueCopyBuffer(m_buffer, newBuffer, 0, 0, m_size * sizeof(Scalar)); + m_buffer = std::move(newBuffer); + } + + template + void OpenCLStorage::resizeImpl(SizeType newSize, int) { + if (newSize == m_size) return; + m_size = newSize; + m_buffer = cl::Buffer(global::openCLContext, bufferFlags, newSize * sizeof(Scalar)); + } + + template + auto OpenCLStorage::size() const -> SizeType { + return m_size; + } + + template + auto OpenCLStorage::operator[](SizeType index) const + -> const detail::OpenCLRef { + LIBRAPID_ASSERT(index >= 0 && index < m_size, + "Index {} is out of range for OpenCLStorage with {} elements", + index, + m_size); + return detail::OpenCLRef(m_buffer, index); + } + + template 
+ auto OpenCLStorage::operator[](SizeType index) -> detail::OpenCLRef { + LIBRAPID_ASSERT(index >= 0 && index < m_size, + "Index {} is out of range for OpenCLStorage with {} elements", + index, + m_size); + return detail::OpenCLRef(m_buffer, index); + } + + template + auto OpenCLStorage::data() const -> const cl::Buffer & { + return m_buffer; + } + + template + auto OpenCLStorage::data() -> cl::Buffer & { + return m_buffer; + } } // namespace librapid #endif // LIBRAPID_HAS_OPENCL diff --git a/librapid/include/librapid/simd/vecOps.hpp b/librapid/include/librapid/simd/vecOps.hpp index a4f1f6ad..a0c8109b 100644 --- a/librapid/include/librapid/simd/vecOps.hpp +++ b/librapid/include/librapid/simd/vecOps.hpp @@ -2,109 +2,109 @@ #define LIBRAPID_SIMD_TRIGONOMETRY namespace librapid { - namespace typetraits { - template - struct IsSIMD : std::false_type {}; + namespace typetraits { + template + struct IsSIMD : std::false_type {}; - template - struct IsSIMD> : std::true_type {}; + template + struct IsSIMD> : std::true_type {}; - template - struct IsSIMD> : std::true_type {}; - } // namespace typetraits + template + struct IsSIMD> : std::true_type {}; + } // namespace typetraits #define REQUIRE_SIMD(TYPE) typename std::enable_if_t::value, int> = 0 #define IF_FLOATING(TYPE) if constexpr (std::is_floating_point_v) - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sin(const T &x) { - return xsimd::sin(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cos(const T &x) { - return xsimd::cos(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tan(const T &x) { - return xsimd::tan(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto asin(const T &x) { - return xsimd::asin(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto acos(const T &x) { - return xsimd::acos(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto atan(const T &x) { - return xsimd::atan(x); - } - - template - 
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sinh(const T &x) { - return xsimd::sinh(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cosh(const T &x) { - return xsimd::cosh(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tanh(const T &x) { - return xsimd::tanh(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto exp(const T &x) { - return xsimd::exp(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log(const T &x) { - return xsimd::log(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log2(const T &x) { - return xsimd::log2(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log10(const T &x) { - return xsimd::log10(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrt(const T &x) { - return xsimd::sqrt(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cbrt(const T &x) { - return xsimd::cbrt(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto abs(const T &x) { - return xsimd::abs(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto floor(const T &x) { - return xsimd::floor(x); - } - - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ceil(const T &x) { - return xsimd::ceil(x); - } + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sin(const T &x) { + return xsimd::sin(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cos(const T &x) { + return xsimd::cos(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tan(const T &x) { + return xsimd::tan(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto asin(const T &x) { + return xsimd::asin(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto acos(const T &x) { + return xsimd::acos(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto atan(const T &x) { + return xsimd::atan(x); + } + + template + LIBRAPID_NODISCARD 
LIBRAPID_ALWAYS_INLINE auto sinh(const T &x) { + return xsimd::sinh(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cosh(const T &x) { + return xsimd::cosh(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto tanh(const T &x) { + return xsimd::tanh(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto exp(const T &x) { + return xsimd::exp(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log(const T &x) { + return xsimd::log(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log2(const T &x) { + return xsimd::log2(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto log10(const T &x) { + return xsimd::log10(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto sqrt(const T &x) { + return xsimd::sqrt(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto cbrt(const T &x) { + return xsimd::cbrt(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto abs(const T &x) { + return xsimd::abs(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto floor(const T &x) { + return xsimd::floor(x); + } + + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ceil(const T &x) { + return xsimd::ceil(x); + } } // namespace librapid #endif // LIBRAPID_SIMD_TRIGONOMETRY \ No newline at end of file diff --git a/librapid/include/librapid/utils/cacheLineSize.hpp b/librapid/include/librapid/utils/cacheLineSize.hpp index 5b2525ed..5e3ccfa4 100644 --- a/librapid/include/librapid/utils/cacheLineSize.hpp +++ b/librapid/include/librapid/utils/cacheLineSize.hpp @@ -2,10 +2,10 @@ #define LIBRAPID_UTILS_CACHE_LINE_SIZE_HPP namespace librapid { - /// Returns the cache line size of the processor, in bytes. If the cache size cannot be - /// determined, the return value is 64. - /// \return Cache line size in bytes - size_t cacheLineSize(); -} + /// Returns the cache line size of the processor, in bytes. 
If the cache size cannot be + /// determined, the return value is 64. + /// \return Cache line size in bytes + size_t cacheLineSize(); +} // namespace librapid #endif // LIBRAPID_UTILS_CACHE_LINE_SIZE_HPP \ No newline at end of file diff --git a/librapid/include/librapid/utils/map.hpp b/librapid/include/librapid/utils/map.hpp index 8167da9f..fa80465d 100644 --- a/librapid/include/librapid/utils/map.hpp +++ b/librapid/include/librapid/utils/map.hpp @@ -2,131 +2,131 @@ #define LIBRAPID_UTILS_MAP_HPP namespace librapid { - template - class Map : public std::map { - public: - /// \brief Check if a key exists in the map - /// \param key Key to search for - /// \return Boolean - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key) const { - return this->find(key) != this->end(); - } + template + class Map : public std::map { + public: + /// \brief Check if a key exists in the map + /// \param key Key to search for + /// \return Boolean + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key) const { + return this->find(key) != this->end(); + } - /// \brief Check if a key exists in the map and, if it does, set \p value to the value of - /// the key. The function returns true if the key exists, false otherwise. (If the function - /// returns false, \p value will not be modified/initialized, so make sure you check the - /// return value!) - /// \param key Key to search for - /// \param value Value of the key, if it exists (output) - /// \return True if the key exists, false otherwise - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key, - Value &value) const { - auto it = this->find(key); - if (it != this->end()) { - value = it->second; - return true; - } - return false; - } + /// \brief Check if a key exists in the map and, if it does, set \p value to the value of + /// the key. The function returns true if the key exists, false otherwise. 
(If the function + /// returns false, \p value will not be modified/initialized, so make sure you check the + /// return value!) + /// \param key Key to search for + /// \param value Value of the key, if it exists (output) + /// \return True if the key exists, false otherwise + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key, + Value &value) const { + auto it = this->find(key); + if (it != this->end()) { + value = it->second; + return true; + } + return false; + } - /// \brief Get the value of a key - /// \param key Key to search for - /// \return Value of the key - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key) const { - return (*this)[key]; - } + /// \brief Get the value of a key + /// \param key Key to search for + /// \return Value of the key + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key) const { + return (*this)[key]; + } - /// \brief Get the value of a key, or a default value if the key does not exist - /// \param key Key to search for - /// \param defaultValue Default value to return if the key does not exist - /// \return Value of the key, or \p defaultValue if the key does not exist - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key, - const Value &defaultValue) const { - auto it = this->find(key); - if (it != this->end()) { return it->second; } - return defaultValue; - } + /// \brief Get the value of a key, or a default value if the key does not exist + /// \param key Key to search for + /// \param defaultValue Default value to return if the key does not exist + /// \return Value of the key, or \p defaultValue if the key does not exist + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key, + const Value &defaultValue) const { + auto it = this->find(key); + if (it != this->end()) { return it->second; } + return defaultValue; + } - LIBRAPID_NODISCARD std::string str(const std::string &keyFormat = "{}", - const std::string &valueFormat = "{}") const { - 
std::string str = "[\n"; - for (const auto &pair : *this) { - str += " " + fmt::format(keyFormat, pair.first); - str += ": "; - str += fmt::format(valueFormat, pair.second); - str += "\n"; - } - str += "]"; + LIBRAPID_NODISCARD std::string str(const std::string &keyFormat = "{}", + const std::string &valueFormat = "{}") const { + std::string str = "[\n"; + for (const auto &pair : *this) { + str += " " + fmt::format(keyFormat, pair.first); + str += ": "; + str += fmt::format(valueFormat, pair.second); + str += "\n"; + } + str += "]"; - return str; - } - }; + return str; + } + }; - template - class UnorderedMap : public std::unordered_map { - public: - /// \brief Check if a key exists in the map - /// \param key Key to search for - /// \return Boolean - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key) const { - return this->find(key) != this->end(); - } + template + class UnorderedMap : public std::unordered_map { + public: + /// \brief Check if a key exists in the map + /// \param key Key to search for + /// \return Boolean + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key) const { + return this->find(key) != this->end(); + } - /// \brief Check if a key exists in the map and, if it does, set \p value to the value of - /// the key. The function returns true if the key exists, false otherwise. (If the function - /// returns false, \p value will not be modified/initialized, so make sure you check the - /// return value!) - /// \param key Key to search for - /// \param value Value of the key, if it exists (output) - /// \return True if the key exists, false otherwise - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key, - Value &value) const { - auto it = this->find(key); - if (it != this->end()) { - value = it->second; - return true; - } - return false; - } + /// \brief Check if a key exists in the map and, if it does, set \p value to the value of + /// the key. 
The function returns true if the key exists, false otherwise. (If the function + /// returns false, \p value will not be modified/initialized, so make sure you check the + /// return value!) + /// \param key Key to search for + /// \param value Value of the key, if it exists (output) + /// \return True if the key exists, false otherwise + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool contains(const Key &key, + Value &value) const { + auto it = this->find(key); + if (it != this->end()) { + value = it->second; + return true; + } + return false; + } - /// \brief Get the value of a key - /// \param key Key to search for - /// \return Value of the key - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key) const { - return (*this)[key]; - } + /// \brief Get the value of a key + /// \param key Key to search for + /// \return Value of the key + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key) const { + return (*this)[key]; + } - /// \brief Get the value of a key, or a default value if the key does not exist - /// \param key Key to search for - /// \param defaultValue Default value to return if the key does not exist - /// \return Value of the key, or \p defaultValue if the key does not exist - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key, - const Value &defaultValue) const { - auto it = this->find(key); - if (it != this->end()) { return it->second; } - return defaultValue; - } + /// \brief Get the value of a key, or a default value if the key does not exist + /// \param key Key to search for + /// \param defaultValue Default value to return if the key does not exist + /// \return Value of the key, or \p defaultValue if the key does not exist + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto get(const Key &key, + const Value &defaultValue) const { + auto it = this->find(key); + if (it != this->end()) { return it->second; } + return defaultValue; + } - LIBRAPID_NODISCARD std::string str(const std::string &keyFormat = "{}", 
- const std::string &valueFormat = "{}") const { - std::string str = "[\n"; - for (const auto &pair : *this) { - str += " " + fmt::format(keyFormat, pair.first); - str += ": "; - str += fmt::format(valueFormat, pair.second); - str += "\n"; - } - str += "]"; + LIBRAPID_NODISCARD std::string str(const std::string &keyFormat = "{}", + const std::string &valueFormat = "{}") const { + std::string str = "[\n"; + for (const auto &pair : *this) { + str += " " + fmt::format(keyFormat, pair.first); + str += ": "; + str += fmt::format(valueFormat, pair.second); + str += "\n"; + } + str += "]"; - return str; - } - }; + return str; + } + }; } // namespace librapid LIBRAPID_SIMPLE_IO_IMPL(typename Key COMMA typename Value, librapid::Map) LIBRAPID_SIMPLE_IO_NORANGE(typename Key COMMA typename Value, librapid::Map) LIBRAPID_SIMPLE_IO_IMPL(typename Key COMMA typename Value, librapid::UnorderedMap) LIBRAPID_SIMPLE_IO_NORANGE(typename Key COMMA typename Value, - librapid::UnorderedMap) + librapid::UnorderedMap) #endif // LIBRAPID_UTILS_MAP_HPP \ No newline at end of file diff --git a/librapid/include/librapid/utils/memUtils.hpp b/librapid/include/librapid/utils/memUtils.hpp index 511ec9bb..c99f9899 100644 --- a/librapid/include/librapid/utils/memUtils.hpp +++ b/librapid/include/librapid/utils/memUtils.hpp @@ -2,126 +2,126 @@ #define LIBRAPID_UTILS_MEMUTILS_HPP namespace librapid { - /// Cast the bits of one value directly into another type -- no conversion is performed - /// \tparam To The type to cast to - /// \tparam From The type to cast from - /// \param val The value to cast - /// \return The value bitwise mapped to the new type - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr To bitCast(const From &val) noexcept { - static_assert( - sizeof(To) == sizeof(From), - "Types have different sizes, and cannot be cast bit-for-bit between each other"); + /// Cast the bits of one value directly into another type -- no conversion is performed + /// \tparam To The type to 
cast to + /// \tparam From The type to cast from + /// \param val The value to cast + /// \return The value bitwise mapped to the new type + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE constexpr To bitCast(const From &val) noexcept { + static_assert( + sizeof(To) == sizeof(From), + "Types have different sizes, and cannot be cast bit-for-bit between each other"); #if defined(__CUDACC__) - To toOjb; // assumes default-init - ::std::memcpy(::std::memcpy::addressof(toOjb), ::std::memcpy::addressof(val), sizeof(To)); - return _To_obj; + To toOjb; // assumes default-init + ::std::memcpy(::std::memcpy::addressof(toOjb), ::std::memcpy::addressof(val), sizeof(To)); + return _To_obj; #elif defined(LIBRAPID_MSVC) - // MSVC doesn't support std::bit_cast until C++20 - return *(To *)(&val); + // MSVC doesn't support std::bit_cast until C++20 + return *(To *)(&val); #elif defined(LIBRAPID_GCC) || defined(LIBRAPID_CLANG) -# if __cplusplus > 201703l && __has_builtin(__builtin_bit_cast) - return __builtin_bit_cast(To, val); -# else - // Fallback option - return *(To *)(&val); -# endif // __has_builtin(__builtin_bit_cast) +# if __cplusplus > 201703l && __has_builtin(__builtin_bit_cast) + return __builtin_bit_cast(To, val); +# else + // Fallback option + return *(To *)(&val); +# endif // __has_builtin(__builtin_bit_cast) #else - // Further fallback option - return *(To *)(&val); + // Further fallback option + return *(To *)(&val); #endif - } + } - /// Returns true if the input value is NaN - /// \tparam T The type of the value - /// \param val The value to check - /// \return True if the value is NaN, false otherwise - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isNaN(const T &val) noexcept { - return std::isnan(val); - } + /// Returns true if the input value is NaN + /// \tparam T The type of the value + /// \param val The value to check + /// \return True if the value is NaN, false otherwise + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool 
isNaN(const T &val) noexcept { + return std::isnan(val); + } - /// Returns true if the input value is finite - /// \tparam T The type of the value - /// \param val The value to check - /// \return True if the value is finite, - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isFinite(const T &val) noexcept { - return std::isfinite(val); - } + /// Returns true if the input value is finite + /// \tparam T The type of the value + /// \param val The value to check + /// \return True if the value is finite, + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isFinite(const T &val) noexcept { + return std::isfinite(val); + } - /// Returns true if the input value is infinite - /// \tparam T The type of the value - /// \param val The value to check - /// \return True if the value is infinite, - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isInf(const T &val) noexcept { - return std::isinf(val); - } + /// Returns true if the input value is infinite + /// \tparam T The type of the value + /// \param val The value to check + /// \return True if the value is infinite, + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool isInf(const T &val) noexcept { + return std::isinf(val); + } - /// Create a new number with a given magnitude and sign - /// \tparam T The type of the magnitude - /// \tparam M The type of the sign - /// \param mag The magnitude of the number - /// \param sign The value from which to copy the sign - /// \return - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T copySign(const T &mag, const M &sign) noexcept { + /// Create a new number with a given magnitude and sign + /// \tparam T The type of the magnitude + /// \tparam M The type of the sign + /// \param mag The magnitude of the number + /// \param sign The value from which to copy the sign + /// \return + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T copySign(const T &mag, const M &sign) noexcept { #if defined(LIBRAPID_MSVC) - return std::copysign(mag, 
static_cast(sign)); + return std::copysign(mag, static_cast(sign)); #else - if constexpr (std::is_fundamental_v && std::is_fundamental_v) { - return std::copysign(mag, static_cast(sign)); - } else { - if (sign < 0) return -mag; - return mag; - } + if constexpr (std::is_fundamental_v && std::is_fundamental_v) { + return std::copysign(mag, static_cast(sign)); + } else { + if (sign < 0) return -mag; + return mag; + } #endif - } + } - /// Extract the sign bit from a value - /// \tparam T The type of the value - /// \param val The value to extract the sign bit from - /// \return The sign bit of the value - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const T &val) noexcept { - return signBit((double)val); - } + /// Extract the sign bit from a value + /// \tparam T The type of the value + /// \param val The value to extract the sign bit from + /// \return The sign bit of the value + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const T &val) noexcept { + return signBit((double)val); + } - /// Extract the sign bit from a value - /// \param val The value to extract the sign bit from - /// \return The sign bit of the value - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const long double &val) noexcept { - return std::signbit(val); - } + /// Extract the sign bit from a value + /// \param val The value to extract the sign bit from + /// \return The sign bit of the value + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const long double &val) noexcept { + return std::signbit(val); + } - /// Extract the sign bit from a value - /// \param val The value to extract the sign bit from - /// \return The sign bit of the value - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const double &val) noexcept { - return std::signbit(val); - } + /// Extract the sign bit from a value + /// \param val The value to extract the sign bit from + /// \return The sign bit of the value + template<> 
+ LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const double &val) noexcept { + return std::signbit(val); + } - /// Extract the sign bit from a value - /// \param val The value to extract the sign bit from - /// \return The sign bit of the value - template<> - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const float &val) noexcept { - return std::signbit(val); - } + /// Extract the sign bit from a value + /// \param val The value to extract the sign bit from + /// \return The sign bit of the value + template<> + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE bool signBit(const float &val) noexcept { + return std::signbit(val); + } - /// Return a value multiplied by 2 raised to the power of an exponent - /// \tparam T The type of the value - /// \param x The value to multiply - /// \param exp The exponent to raise 2 to - /// \return x * 2^exp - template - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T ldexp(const T &x, const int64_t exp) noexcept { - return std::ldexp(x, (int)exp); - } + /// Return a value multiplied by 2 raised to the power of an exponent + /// \tparam T The type of the value + /// \param x The value to multiply + /// \param exp The exponent to raise 2 to + /// \return x * 2^exp + template + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T ldexp(const T &x, const int64_t exp) noexcept { + return std::ldexp(x, (int)exp); + } } // namespace librapid #endif // LIBRAPID_UTILS_MEMUTILS_HPP \ No newline at end of file diff --git a/librapid/include/librapid/utils/time.hpp b/librapid/include/librapid/utils/time.hpp index b5de8124..240cb5a7 100644 --- a/librapid/include/librapid/utils/time.hpp +++ b/librapid/include/librapid/utils/time.hpp @@ -2,163 +2,161 @@ #define LIBRAPID_UTILS_TIME_HPP namespace librapid { - namespace time { - constexpr int64_t nanosecond = int64_t(1); - constexpr int64_t microsecond = nanosecond * 1000; - constexpr int64_t millisecond = microsecond * 1000; - constexpr int64_t second = millisecond * 1000; - constexpr int64_t 
minute = second * 60; - constexpr int64_t hour = minute * 60; - constexpr int64_t day = hour * 24; - } // namespace time - - template - LIBRAPID_NODISCARD double now() { - using namespace std::chrono; + namespace time { + constexpr int64_t nanosecond = int64_t(1); + constexpr int64_t microsecond = nanosecond * 1000; + constexpr int64_t millisecond = microsecond * 1000; + constexpr int64_t second = millisecond * 1000; + constexpr int64_t minute = second * 60; + constexpr int64_t hour = minute * 60; + constexpr int64_t day = hour * 24; + } // namespace time + + template + LIBRAPID_NODISCARD double now() { + using namespace std::chrono; #if defined(LIBRAPID_OS_WINDOWS) - using rep = int64_t; - using period = std::nano; - using duration = std::chrono::duration; - - static const int64_t clockFreq = []() -> int64_t { - LARGE_INTEGER frequency; - QueryPerformanceFrequency(&frequency); - return frequency.QuadPart; - }(); - - LARGE_INTEGER count; - QueryPerformanceCounter(&count); - return duration(count.QuadPart * static_cast(std::nano::den) / clockFreq).count() / - (double)scale; + using rep = int64_t; + using period = std::nano; + using duration = std::chrono::duration; + + static const int64_t clockFreq = []() -> int64_t { + LARGE_INTEGER frequency; + QueryPerformanceFrequency(&frequency); + return frequency.QuadPart; + }(); + + LARGE_INTEGER count; + QueryPerformanceCounter(&count); + return duration(count.QuadPart * static_cast(std::nano::den) / clockFreq).count() / + (double)scale; #else - return (double)high_resolution_clock::now().time_since_epoch().count() / (double)scale; + return (double)high_resolution_clock::now().time_since_epoch().count() / (double)scale; #endif - } + } - constexpr static double sleepOffset = 0; + constexpr static double sleepOffset = 0; - template - LIBRAPID_ALWAYS_INLINE void sleep(double time) { - using namespace std::chrono; - time *= scale; - auto start = now(); - while (now() - start < time - sleepOffset) {} - } + template + 
LIBRAPID_ALWAYS_INLINE void sleep(double time) { + using namespace std::chrono; + time *= scale; + auto start = now(); + while (now() - start < time - sleepOffset) {} + } - template - std::string formatTime(double time, const std::string &format = "{:.3f}") { - double ns = time * scale; - int numUnits = 8; + template + std::string formatTime(double time, const std::string &format = "{:.3f}") { + double ns = time * scale; + int numUnits = 8; - static std::string prefix[] = { - "ns", + static std::string prefix[] = { + "ns", #if defined(LIBRAPID_OS_WINDOWS) && defined(LIBRAPID_NO_WINDOWS_H) - "µs", + "µs", #else - "us", + "us", #endif - "ms", - "s", - "m", - "h", - "d", - "y" - }; - - static double divisor[] = {1000, 1000, 1000, 60, 60, 24, 365, 1e300}; - for (int i = 0; i < numUnits; ++i) { - // if (ns < divisor[i]) return std::operator+(fmt::format(format, ns), prefix[i]); - if (ns < divisor[i]) return fmt::vformat(format, fmt::make_format_args(ns)) + prefix[i]; - ns /= divisor[i]; - } - return fmt::format("{}ns", time * ns); - } - - /// A timer class that can be used to measure a multitude of things. - /// The timer can be started, stopped and reset, and can, optionally, output - /// the time between construction and destruction to the console. 
- class Timer { - public: - /// Create a new timer with a given name - /// \param name The name of the timer - /// \param printOnDestruct Whether to print the time between construction and destruction - explicit Timer(std::string name = "") : - m_name(std::move(name)), m_start(now()), m_end(-1) {} - - Timer(const Timer &) = default; - Timer(Timer &&) = default; - Timer &operator=(const Timer &) = default; - Timer &operator=(Timer &&) = default; - - /// Timer destructor - ~Timer() { - m_end = now(); - } - - template - Timer &setTargetTime(double time) { - m_iters = 0; - m_targetTime = time * (double)scale; - m_start = now(); - return *this; - } - - /// Start the timer - void start() { - m_start = now(); - m_end = -1; - } - - /// Stop the timer - void stop() { m_end = now(); } - - /// Reset the timer - void reset() { - m_start = now(); - m_end = -1; - } - - /// Get the elapsed time in a given unit - /// \tparam scale The unit to return the time in - /// \return The elapsed time in the given unit - template - LIBRAPID_NODISCARD double elapsed() const { - if (m_end == -1) return (now() - m_start) / (double)scale; - return (m_end - m_start) / (double)scale; - } - - /// Get the average time in a given unit - /// \tparam scale The unit to return the time in - /// \return The average time in the given unit - template - LIBRAPID_NODISCARD double average() const { - return elapsed() / (double)m_iters; - } - - bool isRunning() { - ++m_iters; - return now() - m_start < m_targetTime; - } - - /// Print the current elapsed time of the timer - LIBRAPID_NODISCARD std::string str(const std::string &format = "{:.3f}") const { - double tmpEnd = m_end; - if (tmpEnd < 0) tmpEnd = now(); - return fmt::format( - "{}Elapsed: {} | Average: {}", - (m_name.empty() ? 
"" : m_name + ": "), - formatTime(tmpEnd - m_start, format), - formatTime((tmpEnd - m_start) / (double)m_iters, format)); - } - - private: - std::string m_name = "Timer"; - double m_start = 0; - double m_end = 0; - - size_t m_iters = 0; - double m_targetTime = 0; - }; + "ms", + "s", + "m", + "h", + "d", + "y" + }; + + static double divisor[] = {1000, 1000, 1000, 60, 60, 24, 365, 1e300}; + for (int i = 0; i < numUnits; ++i) { + // if (ns < divisor[i]) return std::operator+(fmt::format(format, ns), prefix[i]); + if (ns < divisor[i]) return fmt::vformat(format, fmt::make_format_args(ns)) + prefix[i]; + ns /= divisor[i]; + } + return fmt::format("{}ns", time * ns); + } + + /// A timer class that can be used to measure a multitude of things. + /// The timer can be started, stopped and reset, and can, optionally, output + /// the time between construction and destruction to the console. + class Timer { + public: + /// Create a new timer with a given name + /// \param name The name of the timer + /// \param printOnDestruct Whether to print the time between construction and destruction + explicit Timer(std::string name = "") : + m_name(std::move(name)), m_start(now()), m_end(-1) {} + + Timer(const Timer &) = default; + Timer(Timer &&) = default; + Timer &operator=(const Timer &) = default; + Timer &operator=(Timer &&) = default; + + /// Timer destructor + ~Timer() { m_end = now(); } + + template + Timer &setTargetTime(double time) { + m_iters = 0; + m_targetTime = time * (double)scale; + m_start = now(); + return *this; + } + + /// Start the timer + void start() { + m_start = now(); + m_end = -1; + } + + /// Stop the timer + void stop() { m_end = now(); } + + /// Reset the timer + void reset() { + m_start = now(); + m_end = -1; + } + + /// Get the elapsed time in a given unit + /// \tparam scale The unit to return the time in + /// \return The elapsed time in the given unit + template + LIBRAPID_NODISCARD double elapsed() const { + if (m_end == -1) return (now() - m_start) 
/ (double)scale; + return (m_end - m_start) / (double)scale; + } + + /// Get the average time in a given unit + /// \tparam scale The unit to return the time in + /// \return The average time in the given unit + template + LIBRAPID_NODISCARD double average() const { + return elapsed() / (double)m_iters; + } + + bool isRunning() { + ++m_iters; + return now() - m_start < m_targetTime; + } + + /// Print the current elapsed time of the timer + LIBRAPID_NODISCARD std::string str(const std::string &format = "{:.3f}") const { + double tmpEnd = m_end; + if (tmpEnd < 0) tmpEnd = now(); + return fmt::format( + "{}Elapsed: {} | Average: {}", + (m_name.empty() ? "" : m_name + ": "), + formatTime(tmpEnd - m_start, format), + formatTime((tmpEnd - m_start) / (double)m_iters, format)); + } + + private: + std::string m_name = "Timer"; + double m_start = 0; + double m_end = 0; + + size_t m_iters = 0; + double m_targetTime = 0; + }; } // namespace librapid LIBRAPID_SIMPLE_IO_IMPL_NO_TEMPLATE(librapid::Timer); diff --git a/librapid/src/cacheLineSize.cpp b/librapid/src/cacheLineSize.cpp index 5b1b2757..19502d51 100644 --- a/librapid/src/cacheLineSize.cpp +++ b/librapid/src/cacheLineSize.cpp @@ -14,64 +14,64 @@ #if defined(LIBRAPID_APPLE) -# include +# include namespace librapid { - size_t cacheLineSize() { - size_t lineSize = 64; - size_t sizeOfLineSize = sizeof(lineSize); - sysctlbyname("hw.cachelinesize", &lineSize, &sizeOfLineSize, 0, 0); - return lineSize; - } + size_t cacheLineSize() { + size_t lineSize = 64; + size_t sizeOfLineSize = sizeof(lineSize); + sysctlbyname("hw.cachelinesize", &lineSize, &sizeOfLineSize, 0, 0); + return lineSize; + } } // namespace librapid #elif defined(LIBRAPID_WINDOWS) && !defined(LIBRAPID_NO_WINDOWS_H) namespace librapid { - size_t cacheLineSize() { - size_t lineSize = 64; - DWORD bufferSize = 0; - DWORD i = 0; - SYSTEM_LOGICAL_PROCESSOR_INFORMATION *buffer = 0; + size_t cacheLineSize() { + size_t lineSize = 64; + DWORD bufferSize = 0; + DWORD i = 0; 
+ SYSTEM_LOGICAL_PROCESSOR_INFORMATION *buffer = 0; - GetLogicalProcessorInformation(0, &bufferSize); - buffer = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION *)malloc(bufferSize); - GetLogicalProcessorInformation(&buffer[0], &bufferSize); + GetLogicalProcessorInformation(0, &bufferSize); + buffer = (SYSTEM_LOGICAL_PROCESSOR_INFORMATION *)malloc(bufferSize); + GetLogicalProcessorInformation(&buffer[0], &bufferSize); - for (i = 0; i != bufferSize / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION); ++i) { - if (buffer[i].Relationship == RelationCache && buffer[i].Cache.Level == 1) { - lineSize = buffer[i].Cache.LineSize; - break; - } - } + for (i = 0; i != bufferSize / sizeof(SYSTEM_LOGICAL_PROCESSOR_INFORMATION); ++i) { + if (buffer[i].Relationship == RelationCache && buffer[i].Cache.Level == 1) { + lineSize = buffer[i].Cache.LineSize; + break; + } + } - free(buffer); - return lineSize; - } + free(buffer); + return lineSize; + } } // namespace librapid #elif defined(LIBRAPID_LINUX) namespace librapid { - size_t cacheLineSize() { - FILE *p = 0; - p = fopen("/sys/devices/system/cpu/cpu0/cache/index0/coherency_line_size", "r"); - unsigned int lineSize = 64; - if (p) { - fscanf(p, "%d", &lineSize); - fclose(p); - } - return lineSize; - } + size_t cacheLineSize() { + FILE *p = 0; + p = fopen("/sys/devices/system/cpu/cpu0/cache/index0/coherency_line_size", "r"); + unsigned int lineSize = 64; + if (p) { + fscanf(p, "%d", &lineSize); + fclose(p); + } + return lineSize; + } } // namespace librapid #else namespace librapid { - size_t cacheLineSize() { - // On unknown platforms, return 64 - return 64; - } + size_t cacheLineSize() { + // On unknown platforms, return 64 + return 64; + } } // namespace librapid #endif \ No newline at end of file diff --git a/librapid/src/compat.cpp b/librapid/src/compat.cpp index 9ff00c4a..62bb0936 100644 --- a/librapid/src/compat.cpp +++ b/librapid/src/compat.cpp @@ -3,97 +3,97 @@ #if defined(LIBRAPID_HAS_OPENCL) namespace clblast { - // template<> - // 
StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, - // const Transpose b_transpose, const size_t m, const size_t n, - // const size_t k, const librapid::Complex alpha, - // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, - // const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, - // const librapid::Complex beta, cl_mem c_buffer, - // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, - // cl_event *event, cl_mem temp_buffer) { - // return Gemm(layout, - // a_transpose, - // b_transpose, - // m, - // n, - // k, - // {alpha.real(), alpha.imag()}, - // a_buffer, - // a_offset, - // a_ld, - // b_buffer, - // b_offset, - // b_ld, - // {beta.real(), beta.imag()}, - // c_buffer, - // c_offset, - // c_ld, - // queue, - // event, - // temp_buffer); - // } + // template<> + // StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, + // const Transpose b_transpose, const size_t m, const size_t n, + // const size_t k, const librapid::Complex alpha, + // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + // const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + // const librapid::Complex beta, cl_mem c_buffer, + // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, + // cl_event *event, cl_mem temp_buffer) { + // return Gemm(layout, + // a_transpose, + // b_transpose, + // m, + // n, + // k, + // {alpha.real(), alpha.imag()}, + // a_buffer, + // a_offset, + // a_ld, + // b_buffer, + // b_offset, + // b_ld, + // {beta.real(), beta.imag()}, + // c_buffer, + // c_offset, + // c_ld, + // queue, + // event, + // temp_buffer); + // } - // template<> - // StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, - // const Transpose b_transpose, const size_t m, const size_t n, - // const size_t k, const librapid::Complex alpha, - // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, - // const cl_mem b_buffer, const 
size_t b_offset, const size_t b_ld, - // const librapid::Complex beta, cl_mem c_buffer, - // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, - // cl_event *event, cl_mem temp_buffer) { - // return Gemm(layout, - // a_transpose, - // b_transpose, - // m, - // n, - // k, - // {alpha.real(), alpha.imag()}, - // a_buffer, - // a_offset, - // a_ld, - // b_buffer, - // b_offset, - // b_ld, - // {beta.real(), beta.imag()}, - // c_buffer, - // c_offset, - // c_ld, - // queue, - // event, - // temp_buffer); - // } + // template<> + // StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, + // const Transpose b_transpose, const size_t m, const size_t n, + // const size_t k, const librapid::Complex alpha, + // const cl_mem a_buffer, const size_t a_offset, const size_t a_ld, + // const cl_mem b_buffer, const size_t b_offset, const size_t b_ld, + // const librapid::Complex beta, cl_mem c_buffer, + // const size_t c_offset, const size_t c_ld, cl_command_queue *queue, + // cl_event *event, cl_mem temp_buffer) { + // return Gemm(layout, + // a_transpose, + // b_transpose, + // m, + // n, + // k, + // {alpha.real(), alpha.imag()}, + // a_buffer, + // a_offset, + // a_ld, + // b_buffer, + // b_offset, + // b_ld, + // {beta.real(), beta.imag()}, + // c_buffer, + // c_offset, + // c_ld, + // queue, + // event, + // temp_buffer); + // } - template<> - StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, - const Transpose b_transpose, const size_t m, const size_t n, - const size_t k, const librapid::half alpha, const cl_mem a_buffer, - const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, - const size_t b_offset, const size_t b_ld, const librapid::half beta, - cl_mem c_buffer, const size_t c_offset, const size_t c_ld, - cl_command_queue *queue, cl_event *event, cl_mem temp_buffer) { - return Gemm(layout, - a_transpose, - b_transpose, - m, - n, - k, - alpha.data().m_bits, - a_buffer, - a_offset, - a_ld, - 
b_buffer, - b_offset, - b_ld, - beta.data().m_bits, - c_buffer, - c_offset, - c_ld, - queue, - event, - temp_buffer); - } + template<> + StatusCode PUBLIC_API Gemm(const Layout layout, const Transpose a_transpose, + const Transpose b_transpose, const size_t m, const size_t n, + const size_t k, const librapid::half alpha, const cl_mem a_buffer, + const size_t a_offset, const size_t a_ld, const cl_mem b_buffer, + const size_t b_offset, const size_t b_ld, const librapid::half beta, + cl_mem c_buffer, const size_t c_offset, const size_t c_ld, + cl_command_queue *queue, cl_event *event, cl_mem temp_buffer) { + return Gemm(layout, + a_transpose, + b_transpose, + m, + n, + k, + alpha.data().m_bits, + a_buffer, + a_offset, + a_ld, + b_buffer, + b_offset, + b_ld, + beta.data().m_bits, + c_buffer, + c_offset, + c_ld, + queue, + event, + temp_buffer); + } } // namespace clblast #endif // LIBRAPID_HAS_OPENCL diff --git a/librapid/src/cudaKernelProcessor.cpp b/librapid/src/cudaKernelProcessor.cpp index 69800acf..2ddf9825 100644 --- a/librapid/src/cudaKernelProcessor.cpp +++ b/librapid/src/cudaKernelProcessor.cpp @@ -1,44 +1,44 @@ #if defined(LIBRAPID_HAS_CUDA) -# include +# include namespace librapid::cuda { - const std::string &loadKernel(const std::string &path, bool relative) { - static std::map mapping; - - if (mapping.find(path) != mapping.end()) { return mapping[path]; } - - auto basePath = fmt::format("{}/include/librapid/cuda/kernels/", LIBRAPID_SOURCE); - - std::string helperPath = fmt::format("{}/kernelHelper.cuh", basePath); - std::string vectorOpsPath = fmt::format("{}/vectorOps.cuh", basePath); - std::string dualPath = - fmt::format("{}/include/librapid/autodiff/dual.hpp", LIBRAPID_SOURCE); - std::string kernelPath = fmt::format("{}{}.cu", relative ? 
(basePath + "/") : "", path); - std::fstream helper(helperPath); - std::fstream vectorOps(vectorOpsPath); - std::fstream dual(dualPath); - std::fstream kernel(kernelPath); - LIBRAPID_ASSERT(helper.is_open(), "Failed to load CUDA helper functions"); - LIBRAPID_ASSERT(vectorOps.is_open(), "Failed to load CUDA vectorOps helper functions"); - LIBRAPID_ASSERT(dual.is_open(), "Failed to load dual number library"); - LIBRAPID_ASSERT(kernel.is_open(), "Failed to load CUDA kernel '{}.cu'", path); - std::stringstream buffer; - buffer << helper.rdbuf(); - buffer << "\n\n"; - buffer << vectorOps.rdbuf(); - buffer << "\n\n"; - buffer << dual.rdbuf(); - buffer << "\n\n"; - buffer << kernel.rdbuf(); - - mapping[path] = path + "\n" + buffer.str(); - return mapping[path]; - } - - jitify::Program generateCudaProgram(const std::string &kernel) { - return global::jitCache.program(kernel, {}, {fmt::format("-I{}", CUDA_INCLUDE_DIRS)}); - } + const std::string &loadKernel(const std::string &path, bool relative) { + static std::map mapping; + + if (mapping.find(path) != mapping.end()) { return mapping[path]; } + + auto basePath = fmt::format("{}/include/librapid/cuda/kernels/", LIBRAPID_SOURCE); + + std::string helperPath = fmt::format("{}/kernelHelper.cuh", basePath); + std::string vectorOpsPath = fmt::format("{}/vectorOps.cuh", basePath); + std::string dualPath = + fmt::format("{}/include/librapid/autodiff/dual.hpp", LIBRAPID_SOURCE); + std::string kernelPath = fmt::format("{}{}.cu", relative ? 
(basePath + "/") : "", path); + std::fstream helper(helperPath); + std::fstream vectorOps(vectorOpsPath); + std::fstream dual(dualPath); + std::fstream kernel(kernelPath); + LIBRAPID_ASSERT(helper.is_open(), "Failed to load CUDA helper functions"); + LIBRAPID_ASSERT(vectorOps.is_open(), "Failed to load CUDA vectorOps helper functions"); + LIBRAPID_ASSERT(dual.is_open(), "Failed to load dual number library"); + LIBRAPID_ASSERT(kernel.is_open(), "Failed to load CUDA kernel '{}.cu'", path); + std::stringstream buffer; + buffer << helper.rdbuf(); + buffer << "\n\n"; + buffer << vectorOps.rdbuf(); + buffer << "\n\n"; + buffer << dual.rdbuf(); + buffer << "\n\n"; + buffer << kernel.rdbuf(); + + mapping[path] = path + "\n" + buffer.str(); + return mapping[path]; + } + + jitify::Program generateCudaProgram(const std::string &kernel) { + return global::jitCache.program(kernel, {}, {fmt::format("-I{}", CUDA_INCLUDE_DIRS)}); + } } // namespace librapid::cuda #endif // LIBRAPID_HAS_CUDA diff --git a/librapid/src/fastMath.cpp b/librapid/src/fastMath.cpp index 000b48ba..c9cfaac6 100644 --- a/librapid/src/fastMath.cpp +++ b/librapid/src/fastMath.cpp @@ -1,32 +1,32 @@ #include namespace librapid::fastmath { - double pow10(int64_t exponent) { - const static double pows[] = {0.0000001, - 0.000001, - 0.00001, - 0.0001, - 0.001, - 0.01, - 0.1, - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 1000000}; + double pow10(int64_t exponent) { + const static double pows[] = {0.0000001, + 0.000001, + 0.00001, + 0.0001, + 0.001, + 0.01, + 0.1, + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 1000000}; - if (exponent >= -7 && exponent <= 7) return pows[exponent + 7]; + if (exponent >= -7 && exponent <= 7) return pows[exponent + 7]; - double res = 1; + double res = 1; - if (exponent > 0) - for (int64_t i = 0; i < exponent; ++i) res *= 10.; - else - for (int64_t i = 0; i > exponent; --i) res *= 0.1; + if (exponent > 0) + for (int64_t i = 0; i < exponent; ++i) res *= 10.; + else + 
for (int64_t i = 0; i > exponent; --i) res *= 0.1; - return res; - } -} \ No newline at end of file + return res; + } +} // namespace librapid::fastmath \ No newline at end of file diff --git a/librapid/src/global.cpp b/librapid/src/global.cpp index 558acafd..417a0925 100644 --- a/librapid/src/global.cpp +++ b/librapid/src/global.cpp @@ -3,77 +3,77 @@ #include // setenv namespace librapid { - namespace global { - bool throwOnAssert = false; - size_t multithreadThreshold = 5000; - size_t gemmMultithreadThreshold = 100; - size_t gemvMultithreadThreshold = 100; - size_t numThreads = 8; - size_t randomSeed = 0; // Set in PreMain - bool reseed = false; - size_t cacheLineSize = 64; - size_t memoryAlignment = LIBRAPID_DEFAULT_MEM_ALIGN; + namespace global { + bool throwOnAssert = false; + size_t multithreadThreshold = 5000; + size_t gemmMultithreadThreshold = 100; + size_t gemvMultithreadThreshold = 100; + size_t numThreads = 8; + size_t randomSeed = 0; // Set in PreMain + bool reseed = false; + size_t cacheLineSize = 64; + size_t memoryAlignment = LIBRAPID_DEFAULT_MEM_ALIGN; #if defined(LIBRAPID_HAS_OPENCL) - std::vector openclDevices; - cl::Context openCLContext; - cl::Device openCLDevice; - cl::CommandQueue openCLQueue; - cl::Program::Sources openCLSources; - cl::Program openCLProgram; - bool openCLConfigured = false; + std::vector openclDevices; + cl::Context openCLContext; + cl::Device openCLDevice; + cl::CommandQueue openCLQueue; + cl::Program::Sources openCLSources; + cl::Program openCLProgram; + bool openCLConfigured = false; #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - cudaStream_t cudaStream; - cublasHandle_t cublasHandle; - cublasLtHandle_t cublasLtHandle; - uint64_t cublasLtWorkspaceSize = 1024 * 1024 * 4; - void *cublasLtWorkspace; - jitify::JitCache jitCache; + cudaStream_t cudaStream; + cublasHandle_t cublasHandle; + cublasLtHandle_t cublasLtHandle; + uint64_t cublasLtWorkspaceSize = 1024 * 1024 * 4; + void *cublasLtWorkspace; + 
jitify::JitCache jitCache; #endif // LIBRAPID_HAS_CUDA - } // namespace global + } // namespace global #if defined(_WIN32) -# define SETENV(name, value) _putenv_s(name, value) +# define SETENV(name, value) _putenv_s(name, value) #else -# define SETENV(name, value) setenv(name, value, 1) +# define SETENV(name, value) setenv(name, value, 1) #endif - void setOpenBLASThreadsEnv(int num_threads) { - char num_threads_str[20]; - sprintf(num_threads_str, "%d", num_threads); + void setOpenBLASThreadsEnv(int num_threads) { + char num_threads_str[20]; + sprintf(num_threads_str, "%d", num_threads); - SETENV("OPENBLAS_NUM_THREADS", num_threads_str); - SETENV("GOTO_NUM_THREADS", num_threads_str); - SETENV("OMP_NUM_THREADS", num_threads_str); - } + SETENV("OPENBLAS_NUM_THREADS", num_threads_str); + SETENV("GOTO_NUM_THREADS", num_threads_str); + SETENV("OMP_NUM_THREADS", num_threads_str); + } - void setNumThreads(size_t numThreads) { - global::numThreads = numThreads; + void setNumThreads(size_t numThreads) { + global::numThreads = numThreads; - // OpenBLAS threading + // OpenBLAS threading #if defined(LIBRAPID_BLAS_OPENBLAS) - openblas_set_num_threads((int)numThreads); - omp_set_num_threads((int)numThreads); - goto_set_num_threads((int)numThreads); + openblas_set_num_threads((int)numThreads); + omp_set_num_threads((int)numThreads); + goto_set_num_threads((int)numThreads); - setOpenBLASThreadsEnv((int)numThreads); + setOpenBLASThreadsEnv((int)numThreads); #endif // LIBRAPID_BLAS_OPENBLAS - // MKL threading + // MKL threading #if defined(LIBRAPID_BLAS_MKL) - mkl_set_num_threads((int)numThreads); + mkl_set_num_threads((int)numThreads); #endif // LIBRAPID_BLAS_MKL - } + } - size_t getNumThreads() { return global::numThreads; } + size_t getNumThreads() { return global::numThreads; } - void setSeed(size_t seed) { - global::randomSeed = seed; - global::reseed = true; - } + void setSeed(size_t seed) { + global::randomSeed = seed; + global::reseed = true; + } - size_t getSeed() { return 
global::randomSeed; } + size_t getSeed() { return global::randomSeed; } } // namespace librapid diff --git a/librapid/src/helper_cuda.cpp b/librapid/src/helper_cuda.cpp index cf6dc5f1..600c05d8 100644 --- a/librapid/src/helper_cuda.cpp +++ b/librapid/src/helper_cuda.cpp @@ -7,275 +7,275 @@ const char *_cudaGetErrorEnum(cudaError_t error) { return cudaGetErrorName(error #ifdef CUDA_DRIVER_API // CUDA Driver API errors const char *_cudaGetErrorEnum(CUresult error) { - static char unknown[] = ""; - const char *ret = NULL; - cuGetErrorName(error, &ret); - return ret ? ret : unknown; + static char unknown[] = ""; + const char *ret = NULL; + cuGetErrorName(error, &ret); + return ret ? ret : unknown; } #endif #ifdef __DRIVER_TYPES_H__ const char *getCublasErrorEnum_(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; - } + switch (error) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return 
"CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + } - return "UNKNOWN ERROR"; + return "UNKNOWN ERROR"; } #endif #ifdef CUDA_DRIVER_API const char *_cudaGetErrorEnum(CUresult error) { - static char unknown[] = ""; - const char *ret = NULL; - cuGetErrorName(error, &ret); - return ret ? ret : unknown; + static char unknown[] = ""; + const char *ret = NULL; + cuGetErrorName(error, &ret); + return ret ? ret : unknown; } #endif #ifdef CUBLAS_API_H_ const char *_cudaGetErrorEnum(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; - } + switch (error) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return 
"CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + } - return ""; + return ""; } #endif #ifdef _CUFFT_H_ const char *_cudaGetErrorEnum(cufftResult error) { - switch (error) { - case CUFFT_SUCCESS: return "CUFFT_SUCCESS"; - case CUFFT_INVALID_PLAN: return "CUFFT_INVALID_PLAN"; - case CUFFT_ALLOC_FAILED: return "CUFFT_ALLOC_FAILED"; - case CUFFT_INVALID_TYPE: return "CUFFT_INVALID_TYPE"; - case CUFFT_INVALID_VALUE: return "CUFFT_INVALID_VALUE"; - case CUFFT_INTERNAL_ERROR: return "CUFFT_INTERNAL_ERROR"; - case CUFFT_EXEC_FAILED: return "CUFFT_EXEC_FAILED"; - case CUFFT_SETUP_FAILED: return "CUFFT_SETUP_FAILED"; - case CUFFT_INVALID_SIZE: return "CUFFT_INVALID_SIZE"; - case CUFFT_UNALIGNED_DATA: return "CUFFT_UNALIGNED_DATA"; - case CUFFT_INCOMPLETE_PARAMETER_LIST: return "CUFFT_INCOMPLETE_PARAMETER_LIST"; - case CUFFT_INVALID_DEVICE: return "CUFFT_INVALID_DEVICE"; - case CUFFT_PARSE_ERROR: return "CUFFT_PARSE_ERROR"; - case CUFFT_NO_WORKSPACE: return "CUFFT_NO_WORKSPACE"; - case CUFFT_NOT_IMPLEMENTED: return "CUFFT_NOT_IMPLEMENTED"; - case CUFFT_LICENSE_ERROR: return "CUFFT_LICENSE_ERROR"; - case CUFFT_NOT_SUPPORTED: return "CUFFT_NOT_SUPPORTED"; - } + switch (error) { + case CUFFT_SUCCESS: return "CUFFT_SUCCESS"; + case CUFFT_INVALID_PLAN: return "CUFFT_INVALID_PLAN"; + case CUFFT_ALLOC_FAILED: return "CUFFT_ALLOC_FAILED"; + case CUFFT_INVALID_TYPE: return "CUFFT_INVALID_TYPE"; + case CUFFT_INVALID_VALUE: return "CUFFT_INVALID_VALUE"; + case CUFFT_INTERNAL_ERROR: return "CUFFT_INTERNAL_ERROR"; + case CUFFT_EXEC_FAILED: return "CUFFT_EXEC_FAILED"; + case CUFFT_SETUP_FAILED: return "CUFFT_SETUP_FAILED"; + case CUFFT_INVALID_SIZE: return "CUFFT_INVALID_SIZE"; + case CUFFT_UNALIGNED_DATA: 
return "CUFFT_UNALIGNED_DATA"; + case CUFFT_INCOMPLETE_PARAMETER_LIST: return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + case CUFFT_INVALID_DEVICE: return "CUFFT_INVALID_DEVICE"; + case CUFFT_PARSE_ERROR: return "CUFFT_PARSE_ERROR"; + case CUFFT_NO_WORKSPACE: return "CUFFT_NO_WORKSPACE"; + case CUFFT_NOT_IMPLEMENTED: return "CUFFT_NOT_IMPLEMENTED"; + case CUFFT_LICENSE_ERROR: return "CUFFT_LICENSE_ERROR"; + case CUFFT_NOT_SUPPORTED: return "CUFFT_NOT_SUPPORTED"; + } - return ""; + return ""; } #endif #ifdef CUSPARSEAPI // cuSPARSE API errors const char *_cudaGetErrorEnum(cusparseStatus_t error) { - switch (error) { - case CUSPARSE_STATUS_SUCCESS: return "CUSPARSE_STATUS_SUCCESS"; - case CUSPARSE_STATUS_NOT_INITIALIZED: return "CUSPARSE_STATUS_NOT_INITIALIZED"; - case CUSPARSE_STATUS_ALLOC_FAILED: return "CUSPARSE_STATUS_ALLOC_FAILED"; - case CUSPARSE_STATUS_INVALID_VALUE: return "CUSPARSE_STATUS_INVALID_VALUE"; - case CUSPARSE_STATUS_ARCH_MISMATCH: return "CUSPARSE_STATUS_ARCH_MISMATCH"; - case CUSPARSE_STATUS_MAPPING_ERROR: return "CUSPARSE_STATUS_MAPPING_ERROR"; - case CUSPARSE_STATUS_EXECUTION_FAILED: return "CUSPARSE_STATUS_EXECUTION_FAILED"; - case CUSPARSE_STATUS_INTERNAL_ERROR: return "CUSPARSE_STATUS_INTERNAL_ERROR"; - case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED: - return "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; - } + switch (error) { + case CUSPARSE_STATUS_SUCCESS: return "CUSPARSE_STATUS_SUCCESS"; + case CUSPARSE_STATUS_NOT_INITIALIZED: return "CUSPARSE_STATUS_NOT_INITIALIZED"; + case CUSPARSE_STATUS_ALLOC_FAILED: return "CUSPARSE_STATUS_ALLOC_FAILED"; + case CUSPARSE_STATUS_INVALID_VALUE: return "CUSPARSE_STATUS_INVALID_VALUE"; + case CUSPARSE_STATUS_ARCH_MISMATCH: return "CUSPARSE_STATUS_ARCH_MISMATCH"; + case CUSPARSE_STATUS_MAPPING_ERROR: return "CUSPARSE_STATUS_MAPPING_ERROR"; + case CUSPARSE_STATUS_EXECUTION_FAILED: return "CUSPARSE_STATUS_EXECUTION_FAILED"; + case CUSPARSE_STATUS_INTERNAL_ERROR: return "CUSPARSE_STATUS_INTERNAL_ERROR"; + case 
CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + } - return ""; + return ""; } #endif #ifdef CUSOLVER_COMMON_H_ // cuSOLVER API errors const char *_cudaGetErrorEnum(cusolverStatus_t error) { - switch (error) { - case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCESS"; - case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED"; - case CUSOLVER_STATUS_ALLOC_FAILED: return "CUSOLVER_STATUS_ALLOC_FAILED"; - case CUSOLVER_STATUS_INVALID_VALUE: return "CUSOLVER_STATUS_INVALID_VALUE"; - case CUSOLVER_STATUS_ARCH_MISMATCH: return "CUSOLVER_STATUS_ARCH_MISMATCH"; - case CUSOLVER_STATUS_MAPPING_ERROR: return "CUSOLVER_STATUS_MAPPING_ERROR"; - case CUSOLVER_STATUS_EXECUTION_FAILED: return "CUSOLVER_STATUS_EXECUTION_FAILED"; - case CUSOLVER_STATUS_INTERNAL_ERROR: return "CUSOLVER_STATUS_INTERNAL_ERROR"; - case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: - return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; - case CUSOLVER_STATUS_NOT_SUPPORTED: return "CUSOLVER_STATUS_NOT_SUPPORTED "; - case CUSOLVER_STATUS_ZERO_PIVOT: return "CUSOLVER_STATUS_ZERO_PIVOT"; - case CUSOLVER_STATUS_INVALID_LICENSE: return "CUSOLVER_STATUS_INVALID_LICENSE"; - } + switch (error) { + case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCESS"; + case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED"; + case CUSOLVER_STATUS_ALLOC_FAILED: return "CUSOLVER_STATUS_ALLOC_FAILED"; + case CUSOLVER_STATUS_INVALID_VALUE: return "CUSOLVER_STATUS_INVALID_VALUE"; + case CUSOLVER_STATUS_ARCH_MISMATCH: return "CUSOLVER_STATUS_ARCH_MISMATCH"; + case CUSOLVER_STATUS_MAPPING_ERROR: return "CUSOLVER_STATUS_MAPPING_ERROR"; + case CUSOLVER_STATUS_EXECUTION_FAILED: return "CUSOLVER_STATUS_EXECUTION_FAILED"; + case CUSOLVER_STATUS_INTERNAL_ERROR: return "CUSOLVER_STATUS_INTERNAL_ERROR"; + case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + case 
CUSOLVER_STATUS_NOT_SUPPORTED: return "CUSOLVER_STATUS_NOT_SUPPORTED "; + case CUSOLVER_STATUS_ZERO_PIVOT: return "CUSOLVER_STATUS_ZERO_PIVOT"; + case CUSOLVER_STATUS_INVALID_LICENSE: return "CUSOLVER_STATUS_INVALID_LICENSE"; + } - return ""; + return ""; } #endif #ifdef CURAND_H_ // cuRAND API errors const char *_cudaGetErrorEnum(curandStatus_t error) { - switch (error) { - case CURAND_STATUS_SUCCESS: return "CURAND_STATUS_SUCCESS"; - case CURAND_STATUS_VERSION_MISMATCH: return "CURAND_STATUS_VERSION_MISMATCH"; - case CURAND_STATUS_NOT_INITIALIZED: return "CURAND_STATUS_NOT_INITIALIZED"; - case CURAND_STATUS_ALLOCATION_FAILED: return "CURAND_STATUS_ALLOCATION_FAILED"; - case CURAND_STATUS_TYPE_ERROR: return "CURAND_STATUS_TYPE_ERROR"; - case CURAND_STATUS_OUT_OF_RANGE: return "CURAND_STATUS_OUT_OF_RANGE"; - case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; - case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: - return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; - case CURAND_STATUS_LAUNCH_FAILURE: return "CURAND_STATUS_LAUNCH_FAILURE"; - case CURAND_STATUS_PREEXISTING_FAILURE: return "CURAND_STATUS_PREEXISTING_FAILURE"; - case CURAND_STATUS_INITIALIZATION_FAILED: return "CURAND_STATUS_INITIALIZATION_FAILED"; - case CURAND_STATUS_ARCH_MISMATCH: return "CURAND_STATUS_ARCH_MISMATCH"; - case CURAND_STATUS_INTERNAL_ERROR: return "CURAND_STATUS_INTERNAL_ERROR"; - } + switch (error) { + case CURAND_STATUS_SUCCESS: return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case 
CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: return "CURAND_STATUS_INTERNAL_ERROR"; + } - return ""; + return ""; } #endif #ifdef NVJPEGAPI // nvJPEG API errors const char *_cudaGetErrorEnum(nvjpegStatus_t error) { - switch (error) { - case NVJPEG_STATUS_SUCCESS: return "NVJPEG_STATUS_SUCCESS"; - case NVJPEG_STATUS_NOT_INITIALIZED: return "NVJPEG_STATUS_NOT_INITIALIZED"; - case NVJPEG_STATUS_INVALID_PARAMETER: return "NVJPEG_STATUS_INVALID_PARAMETER"; - case NVJPEG_STATUS_BAD_JPEG: return "NVJPEG_STATUS_BAD_JPEG"; - case NVJPEG_STATUS_JPEG_NOT_SUPPORTED: return "NVJPEG_STATUS_JPEG_NOT_SUPPORTED"; - case NVJPEG_STATUS_ALLOCATOR_FAILURE: return "NVJPEG_STATUS_ALLOCATOR_FAILURE"; - case NVJPEG_STATUS_EXECUTION_FAILED: return "NVJPEG_STATUS_EXECUTION_FAILED"; - case NVJPEG_STATUS_ARCH_MISMATCH: return "NVJPEG_STATUS_ARCH_MISMATCH"; - case NVJPEG_STATUS_INTERNAL_ERROR: return "NVJPEG_STATUS_INTERNAL_ERROR"; - } + switch (error) { + case NVJPEG_STATUS_SUCCESS: return "NVJPEG_STATUS_SUCCESS"; + case NVJPEG_STATUS_NOT_INITIALIZED: return "NVJPEG_STATUS_NOT_INITIALIZED"; + case NVJPEG_STATUS_INVALID_PARAMETER: return "NVJPEG_STATUS_INVALID_PARAMETER"; + case NVJPEG_STATUS_BAD_JPEG: return "NVJPEG_STATUS_BAD_JPEG"; + case NVJPEG_STATUS_JPEG_NOT_SUPPORTED: return "NVJPEG_STATUS_JPEG_NOT_SUPPORTED"; + case NVJPEG_STATUS_ALLOCATOR_FAILURE: return "NVJPEG_STATUS_ALLOCATOR_FAILURE"; + case NVJPEG_STATUS_EXECUTION_FAILED: return "NVJPEG_STATUS_EXECUTION_FAILED"; + case NVJPEG_STATUS_ARCH_MISMATCH: return "NVJPEG_STATUS_ARCH_MISMATCH"; + case NVJPEG_STATUS_INTERNAL_ERROR: return 
"NVJPEG_STATUS_INTERNAL_ERROR"; + } - return ""; + return ""; } #endif #ifdef NV_NPPIDEFS_H // NPP API errors const char *_cudaGetErrorEnum(NppStatus error) { - switch (error) { - case NPP_NOT_SUPPORTED_MODE_ERROR: return "NPP_NOT_SUPPORTED_MODE_ERROR"; - case NPP_ROUND_MODE_NOT_SUPPORTED_ERROR: return "NPP_ROUND_MODE_NOT_SUPPORTED_ERROR"; - case NPP_RESIZE_NO_OPERATION_ERROR: return "NPP_RESIZE_NO_OPERATION_ERROR"; - case NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY: return "NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY"; + switch (error) { + case NPP_NOT_SUPPORTED_MODE_ERROR: return "NPP_NOT_SUPPORTED_MODE_ERROR"; + case NPP_ROUND_MODE_NOT_SUPPORTED_ERROR: return "NPP_ROUND_MODE_NOT_SUPPORTED_ERROR"; + case NPP_RESIZE_NO_OPERATION_ERROR: return "NPP_RESIZE_NO_OPERATION_ERROR"; + case NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY: return "NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY"; -# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000 - case NPP_BAD_ARG_ERROR: return "NPP_BAD_ARGUMENT_ERROR"; - case NPP_COEFF_ERROR: return "NPP_COEFFICIENT_ERROR"; - case NPP_RECT_ERROR: return "NPP_RECTANGLE_ERROR"; - case NPP_QUAD_ERROR: return "NPP_QUADRANGLE_ERROR"; - case NPP_MEM_ALLOC_ERR: return "NPP_MEMORY_ALLOCATION_ERROR"; - case NPP_HISTO_NUMBER_OF_LEVELS_ERROR: return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR"; - case NPP_INVALID_INPUT: return "NPP_INVALID_INPUT"; - case NPP_POINTER_ERROR: return "NPP_POINTER_ERROR"; - case NPP_WARNING: return "NPP_WARNING"; - case NPP_ODD_ROI_WARNING: return "NPP_ODD_ROI_WARNING"; -# else +# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000 + case NPP_BAD_ARG_ERROR: return "NPP_BAD_ARGUMENT_ERROR"; + case NPP_COEFF_ERROR: return "NPP_COEFFICIENT_ERROR"; + case NPP_RECT_ERROR: return "NPP_RECTANGLE_ERROR"; + case NPP_QUAD_ERROR: return "NPP_QUADRANGLE_ERROR"; + case NPP_MEM_ALLOC_ERR: return "NPP_MEMORY_ALLOCATION_ERROR"; + case NPP_HISTO_NUMBER_OF_LEVELS_ERROR: return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR"; + case 
NPP_INVALID_INPUT: return "NPP_INVALID_INPUT"; + case NPP_POINTER_ERROR: return "NPP_POINTER_ERROR"; + case NPP_WARNING: return "NPP_WARNING"; + case NPP_ODD_ROI_WARNING: return "NPP_ODD_ROI_WARNING"; +# else - // These are for CUDA 5.5 or higher - case NPP_BAD_ARGUMENT_ERROR: return "NPP_BAD_ARGUMENT_ERROR"; - case NPP_COEFFICIENT_ERROR: return "NPP_COEFFICIENT_ERROR"; - case NPP_RECTANGLE_ERROR: return "NPP_RECTANGLE_ERROR"; - case NPP_QUADRANGLE_ERROR: return "NPP_QUADRANGLE_ERROR"; - case NPP_MEMORY_ALLOCATION_ERR: return "NPP_MEMORY_ALLOCATION_ERROR"; - case NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR: return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR"; - case NPP_INVALID_HOST_POINTER_ERROR: return "NPP_INVALID_HOST_POINTER_ERROR"; - case NPP_INVALID_DEVICE_POINTER_ERROR: return "NPP_INVALID_DEVICE_POINTER_ERROR"; -# endif - case NPP_LUT_NUMBER_OF_LEVELS_ERROR: return "NPP_LUT_NUMBER_OF_LEVELS_ERROR"; - case NPP_TEXTURE_BIND_ERROR: return "NPP_TEXTURE_BIND_ERROR"; - case NPP_WRONG_INTERSECTION_ROI_ERROR: return "NPP_WRONG_INTERSECTION_ROI_ERROR"; - case NPP_NOT_EVEN_STEP_ERROR: return "NPP_NOT_EVEN_STEP_ERROR"; - case NPP_INTERPOLATION_ERROR: return "NPP_INTERPOLATION_ERROR"; - case NPP_RESIZE_FACTOR_ERROR: return "NPP_RESIZE_FACTOR_ERROR"; - case NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR: return "NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR"; + // These are for CUDA 5.5 or higher + case NPP_BAD_ARGUMENT_ERROR: return "NPP_BAD_ARGUMENT_ERROR"; + case NPP_COEFFICIENT_ERROR: return "NPP_COEFFICIENT_ERROR"; + case NPP_RECTANGLE_ERROR: return "NPP_RECTANGLE_ERROR"; + case NPP_QUADRANGLE_ERROR: return "NPP_QUADRANGLE_ERROR"; + case NPP_MEMORY_ALLOCATION_ERR: return "NPP_MEMORY_ALLOCATION_ERROR"; + case NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR: return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR"; + case NPP_INVALID_HOST_POINTER_ERROR: return "NPP_INVALID_HOST_POINTER_ERROR"; + case NPP_INVALID_DEVICE_POINTER_ERROR: return "NPP_INVALID_DEVICE_POINTER_ERROR"; +# endif + case 
NPP_LUT_NUMBER_OF_LEVELS_ERROR: return "NPP_LUT_NUMBER_OF_LEVELS_ERROR"; + case NPP_TEXTURE_BIND_ERROR: return "NPP_TEXTURE_BIND_ERROR"; + case NPP_WRONG_INTERSECTION_ROI_ERROR: return "NPP_WRONG_INTERSECTION_ROI_ERROR"; + case NPP_NOT_EVEN_STEP_ERROR: return "NPP_NOT_EVEN_STEP_ERROR"; + case NPP_INTERPOLATION_ERROR: return "NPP_INTERPOLATION_ERROR"; + case NPP_RESIZE_FACTOR_ERROR: return "NPP_RESIZE_FACTOR_ERROR"; + case NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR: return "NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR"; -# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000 - case NPP_MEMFREE_ERR: return "NPP_MEMFREE_ERR"; - case NPP_MEMSET_ERR: return "NPP_MEMSET_ERR"; - case NPP_MEMCPY_ERR: return "NPP_MEMCPY_ERROR"; - case NPP_MIRROR_FLIP_ERR: return "NPP_MIRROR_FLIP_ERR"; -# else - case NPP_MEMFREE_ERROR: return "NPP_MEMFREE_ERROR"; - case NPP_MEMSET_ERROR: return "NPP_MEMSET_ERROR"; - case NPP_MEMCPY_ERROR: return "NPP_MEMCPY_ERROR"; - case NPP_MIRROR_FLIP_ERROR: return "NPP_MIRROR_FLIP_ERROR"; -# endif - case NPP_ALIGNMENT_ERROR: return "NPP_ALIGNMENT_ERROR"; - case NPP_STEP_ERROR: return "NPP_STEP_ERROR"; - case NPP_SIZE_ERROR: return "NPP_SIZE_ERROR"; - case NPP_NULL_POINTER_ERROR: return "NPP_NULL_POINTER_ERROR"; - case NPP_CUDA_KERNEL_EXECUTION_ERROR: return "NPP_CUDA_KERNEL_EXECUTION_ERROR"; - case NPP_NOT_IMPLEMENTED_ERROR: return "NPP_NOT_IMPLEMENTED_ERROR"; - case NPP_ERROR: return "NPP_ERROR"; - case NPP_SUCCESS: return "NPP_SUCCESS"; - case NPP_WRONG_INTERSECTION_QUAD_WARNING: return "NPP_WRONG_INTERSECTION_QUAD_WARNING"; - case NPP_MISALIGNED_DST_ROI_WARNING: return "NPP_MISALIGNED_DST_ROI_WARNING"; - case NPP_AFFINE_QUAD_INCORRECT_WARNING: return "NPP_AFFINE_QUAD_INCORRECT_WARNING"; - case NPP_DOUBLE_SIZE_WARNING: return "NPP_DOUBLE_SIZE_WARNING"; - case NPP_WRONG_INTERSECTION_ROI_WARNING: return "NPP_WRONG_INTERSECTION_ROI_WARNING"; +# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000 + case NPP_MEMFREE_ERR: return 
"NPP_MEMFREE_ERR"; + case NPP_MEMSET_ERR: return "NPP_MEMSET_ERR"; + case NPP_MEMCPY_ERR: return "NPP_MEMCPY_ERROR"; + case NPP_MIRROR_FLIP_ERR: return "NPP_MIRROR_FLIP_ERR"; +# else + case NPP_MEMFREE_ERROR: return "NPP_MEMFREE_ERROR"; + case NPP_MEMSET_ERROR: return "NPP_MEMSET_ERROR"; + case NPP_MEMCPY_ERROR: return "NPP_MEMCPY_ERROR"; + case NPP_MIRROR_FLIP_ERROR: return "NPP_MIRROR_FLIP_ERROR"; +# endif + case NPP_ALIGNMENT_ERROR: return "NPP_ALIGNMENT_ERROR"; + case NPP_STEP_ERROR: return "NPP_STEP_ERROR"; + case NPP_SIZE_ERROR: return "NPP_SIZE_ERROR"; + case NPP_NULL_POINTER_ERROR: return "NPP_NULL_POINTER_ERROR"; + case NPP_CUDA_KERNEL_EXECUTION_ERROR: return "NPP_CUDA_KERNEL_EXECUTION_ERROR"; + case NPP_NOT_IMPLEMENTED_ERROR: return "NPP_NOT_IMPLEMENTED_ERROR"; + case NPP_ERROR: return "NPP_ERROR"; + case NPP_SUCCESS: return "NPP_SUCCESS"; + case NPP_WRONG_INTERSECTION_QUAD_WARNING: return "NPP_WRONG_INTERSECTION_QUAD_WARNING"; + case NPP_MISALIGNED_DST_ROI_WARNING: return "NPP_MISALIGNED_DST_ROI_WARNING"; + case NPP_AFFINE_QUAD_INCORRECT_WARNING: return "NPP_AFFINE_QUAD_INCORRECT_WARNING"; + case NPP_DOUBLE_SIZE_WARNING: return "NPP_DOUBLE_SIZE_WARNING"; + case NPP_WRONG_INTERSECTION_ROI_WARNING: return "NPP_WRONG_INTERSECTION_ROI_WARNING"; -# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x6000 - /* These are 6.0 or higher */ - case NPP_LUT_PALETTE_BITSIZE_ERROR: return "NPP_LUT_PALETTE_BITSIZE_ERROR"; - case NPP_ZC_MODE_NOT_SUPPORTED_ERROR: return "NPP_ZC_MODE_NOT_SUPPORTED_ERROR"; - case NPP_QUALITY_INDEX_ERROR: return "NPP_QUALITY_INDEX_ERROR"; - case NPP_CHANNEL_ORDER_ERROR: return "NPP_CHANNEL_ORDER_ERROR"; - case NPP_ZERO_MASK_VALUE_ERROR: return "NPP_ZERO_MASK_VALUE_ERROR"; - case NPP_NUMBER_OF_CHANNELS_ERROR: return "NPP_NUMBER_OF_CHANNELS_ERROR"; - case NPP_COI_ERROR: return "NPP_COI_ERROR"; - case NPP_DIVISOR_ERROR: return "NPP_DIVISOR_ERROR"; - case NPP_CHANNEL_ERROR: return "NPP_CHANNEL_ERROR"; - case NPP_STRIDE_ERROR: return 
"NPP_STRIDE_ERROR"; - case NPP_ANCHOR_ERROR: return "NPP_ANCHOR_ERROR"; - case NPP_MASK_SIZE_ERROR: return "NPP_MASK_SIZE_ERROR"; - case NPP_MOMENT_00_ZERO_ERROR: return "NPP_MOMENT_00_ZERO_ERROR"; - case NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR: return "NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR"; - case NPP_THRESHOLD_ERROR: return "NPP_THRESHOLD_ERROR"; - case NPP_CONTEXT_MATCH_ERROR: return "NPP_CONTEXT_MATCH_ERROR"; - case NPP_FFT_FLAG_ERROR: return "NPP_FFT_FLAG_ERROR"; - case NPP_FFT_ORDER_ERROR: return "NPP_FFT_ORDER_ERROR"; - case NPP_SCALE_RANGE_ERROR: return "NPP_SCALE_RANGE_ERROR"; - case NPP_DATA_TYPE_ERROR: return "NPP_DATA_TYPE_ERROR"; - case NPP_OUT_OFF_RANGE_ERROR: return "NPP_OUT_OFF_RANGE_ERROR"; - case NPP_DIVIDE_BY_ZERO_ERROR: return "NPP_DIVIDE_BY_ZERO_ERROR"; - case NPP_RANGE_ERROR: return "NPP_RANGE_ERROR"; - case NPP_NO_MEMORY_ERROR: return "NPP_NO_MEMORY_ERROR"; - case NPP_ERROR_RESERVED: return "NPP_ERROR_RESERVED"; - case NPP_NO_OPERATION_WARNING: return "NPP_NO_OPERATION_WARNING"; - case NPP_DIVIDE_BY_ZERO_WARNING: return "NPP_DIVIDE_BY_ZERO_WARNING"; -# endif +# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x6000 + /* These are 6.0 or higher */ + case NPP_LUT_PALETTE_BITSIZE_ERROR: return "NPP_LUT_PALETTE_BITSIZE_ERROR"; + case NPP_ZC_MODE_NOT_SUPPORTED_ERROR: return "NPP_ZC_MODE_NOT_SUPPORTED_ERROR"; + case NPP_QUALITY_INDEX_ERROR: return "NPP_QUALITY_INDEX_ERROR"; + case NPP_CHANNEL_ORDER_ERROR: return "NPP_CHANNEL_ORDER_ERROR"; + case NPP_ZERO_MASK_VALUE_ERROR: return "NPP_ZERO_MASK_VALUE_ERROR"; + case NPP_NUMBER_OF_CHANNELS_ERROR: return "NPP_NUMBER_OF_CHANNELS_ERROR"; + case NPP_COI_ERROR: return "NPP_COI_ERROR"; + case NPP_DIVISOR_ERROR: return "NPP_DIVISOR_ERROR"; + case NPP_CHANNEL_ERROR: return "NPP_CHANNEL_ERROR"; + case NPP_STRIDE_ERROR: return "NPP_STRIDE_ERROR"; + case NPP_ANCHOR_ERROR: return "NPP_ANCHOR_ERROR"; + case NPP_MASK_SIZE_ERROR: return "NPP_MASK_SIZE_ERROR"; + case NPP_MOMENT_00_ZERO_ERROR: return 
"NPP_MOMENT_00_ZERO_ERROR"; + case NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR: return "NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR"; + case NPP_THRESHOLD_ERROR: return "NPP_THRESHOLD_ERROR"; + case NPP_CONTEXT_MATCH_ERROR: return "NPP_CONTEXT_MATCH_ERROR"; + case NPP_FFT_FLAG_ERROR: return "NPP_FFT_FLAG_ERROR"; + case NPP_FFT_ORDER_ERROR: return "NPP_FFT_ORDER_ERROR"; + case NPP_SCALE_RANGE_ERROR: return "NPP_SCALE_RANGE_ERROR"; + case NPP_DATA_TYPE_ERROR: return "NPP_DATA_TYPE_ERROR"; + case NPP_OUT_OFF_RANGE_ERROR: return "NPP_OUT_OFF_RANGE_ERROR"; + case NPP_DIVIDE_BY_ZERO_ERROR: return "NPP_DIVIDE_BY_ZERO_ERROR"; + case NPP_RANGE_ERROR: return "NPP_RANGE_ERROR"; + case NPP_NO_MEMORY_ERROR: return "NPP_NO_MEMORY_ERROR"; + case NPP_ERROR_RESERVED: return "NPP_ERROR_RESERVED"; + case NPP_NO_OPERATION_WARNING: return "NPP_NO_OPERATION_WARNING"; + case NPP_DIVIDE_BY_ZERO_WARNING: return "NPP_DIVIDE_BY_ZERO_WARNING"; +# endif -# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x7000 - /* These are 7.0 or higher */ - case NPP_OVERFLOW_ERROR: return "NPP_OVERFLOW_ERROR"; - case NPP_CORRUPTED_DATA_ERROR: return "NPP_CORRUPTED_DATA_ERROR"; -# endif - } +# if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x7000 + /* These are 7.0 or higher */ + case NPP_OVERFLOW_ERROR: return "NPP_OVERFLOW_ERROR"; + case NPP_CORRUPTED_DATA_ERROR: return "NPP_CORRUPTED_DATA_ERROR"; +# endif + } - return ""; + return ""; } #endif diff --git a/librapid/src/literals.cpp b/librapid/src/literals.cpp index 9881263a..1b4485a1 100644 --- a/librapid/src/literals.cpp +++ b/librapid/src/literals.cpp @@ -2,6 +2,6 @@ namespace librapid::literals { #if defined(LIBRAPID_USE_MULTIPREC) - ::librapid::mpfr operator""_f(const char *str, size_t) { return {str}; } + ::librapid::mpfr operator""_f(const char *str, size_t) { return {str}; } #endif // LIBRAPID_USE_MULTIPREC } // namespace librapid::literals diff --git a/librapid/src/multiprecCasting.cpp b/librapid/src/multiprecCasting.cpp index 
7b113d69..e6cf063e 100644 --- a/librapid/src/multiprecCasting.cpp +++ b/librapid/src/multiprecCasting.cpp @@ -3,25 +3,25 @@ #if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - mpz toMpz(const mpz &other) { return other; } - mpz toMpz(const mpf &other) { return mpz(other); } - mpz toMpz(const mpq &other) { return mpz(other); } - mpz toMpz(const mpfr &other) { return mpz(mpf_class(str(other))); } + mpz toMpz(const mpz &other) { return other; } + mpz toMpz(const mpf &other) { return mpz(other); } + mpz toMpz(const mpq &other) { return mpz(other); } + mpz toMpz(const mpfr &other) { return mpz(mpf_class(str(other))); } - mpf toMpf(const mpz &other) { return mpf(other); } - mpf toMpf(const mpf &other) { return other; } - mpf toMpf(const mpq &other) { return mpf(other); } - mpf toMpf(const mpfr &other) { return mpf(str(other)); } + mpf toMpf(const mpz &other) { return mpf(other); } + mpf toMpf(const mpf &other) { return other; } + mpf toMpf(const mpq &other) { return mpf(other); } + mpf toMpf(const mpfr &other) { return mpf(str(other)); } - mpq toMpq(const mpz &other) { return {other}; } - mpq toMpq(const mpf &other) { return mpq(other); } - mpq toMpq(const mpq &other) { return other; } - mpq toMpq(const mpfr &other) { return mpq(mpf_class(str(other))); } + mpq toMpq(const mpz &other) { return {other}; } + mpq toMpq(const mpf &other) { return mpq(other); } + mpq toMpq(const mpq &other) { return other; } + mpq toMpq(const mpfr &other) { return mpq(mpf_class(str(other))); } - mpfr toMpfr(const mpz &other) { return {str(other)}; } - mpfr toMpfr(const mpf &other) { return {str(other)}; } - mpfr toMpfr(const mpq &other) { return {str(mpf_class(other))}; } - mpfr toMpfr(const mpfr &other) { return other; } + mpfr toMpfr(const mpz &other) { return {str(other)}; } + mpfr toMpfr(const mpf &other) { return {str(other)}; } + mpfr toMpfr(const mpq &other) { return {str(mpf_class(other))}; } + mpfr toMpfr(const mpfr &other) { return other; } } // namespace librapid #endif // 
LIBRAPID_USE_MULTIPREC \ No newline at end of file diff --git a/librapid/src/multiprecExpLogPow.cpp b/librapid/src/multiprecExpLogPow.cpp index fabcc831..4c283931 100644 --- a/librapid/src/multiprecExpLogPow.cpp +++ b/librapid/src/multiprecExpLogPow.cpp @@ -3,19 +3,19 @@ #if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - mpfr sqrt(const mpfr &val) { return ::mpfr::sqrt(val); } - mpfr cbrt(const mpfr &val) { return ::mpfr::cbrt(val); } - mpfr pow(const mpfr &base, const mpfr &pow) { return ::mpfr::pow(base, pow); } - mpfr exp(const mpfr &val) { return ::mpfr::exp(val); } - mpfr exp2(const mpfr &val) { return ::mpfr::exp2(val); } - mpfr exp10(const mpfr &val) { return ::mpfr::exp10(val); } - mpfr ldexp(const mpfr &val, int exponent) { return ::mpfr::ldexp(val, exponent); } - mpfr log(const mpfr &val) { return ::mpfr::log(val); } - mpfr log(const mpfr &val, const mpfr &base) { - return ::mpfr::operator/(::mpfr::log(val), ::mpfr::log(base)); - } - mpfr log2(const mpfr &val) { return ::mpfr::log2(val); } - mpfr log10(const mpfr &val) { return ::mpfr::log10(val); } + mpfr sqrt(const mpfr &val) { return ::mpfr::sqrt(val); } + mpfr cbrt(const mpfr &val) { return ::mpfr::cbrt(val); } + mpfr pow(const mpfr &base, const mpfr &pow) { return ::mpfr::pow(base, pow); } + mpfr exp(const mpfr &val) { return ::mpfr::exp(val); } + mpfr exp2(const mpfr &val) { return ::mpfr::exp2(val); } + mpfr exp10(const mpfr &val) { return ::mpfr::exp10(val); } + mpfr ldexp(const mpfr &val, int exponent) { return ::mpfr::ldexp(val, exponent); } + mpfr log(const mpfr &val) { return ::mpfr::log(val); } + mpfr log(const mpfr &val, const mpfr &base) { + return ::mpfr::operator/(::mpfr::log(val), ::mpfr::log(base)); + } + mpfr log2(const mpfr &val) { return ::mpfr::log2(val); } + mpfr log10(const mpfr &val) { return ::mpfr::log10(val); } } // namespace librapid #endif // LIBRAPID_USE_MULTIPREC diff --git a/librapid/src/multiprecFloorCeil.cpp b/librapid/src/multiprecFloorCeil.cpp index 
7e535196..77ae74b8 100644 --- a/librapid/src/multiprecFloorCeil.cpp +++ b/librapid/src/multiprecFloorCeil.cpp @@ -3,8 +3,8 @@ #if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - mpfr floor(const mpfr &val) { return ::mpfr::floor(val); } - mpfr ceil(const mpfr &val) { return ::mpfr::ceil(val); } -} + mpfr floor(const mpfr &val) { return ::mpfr::floor(val); } + mpfr ceil(const mpfr &val) { return ::mpfr::ceil(val); } +} // namespace librapid #endif // LIBRAPID_USE_MULTIPREC diff --git a/librapid/src/multiprecHypot.cpp b/librapid/src/multiprecHypot.cpp index 147b1179..70a65f52 100644 --- a/librapid/src/multiprecHypot.cpp +++ b/librapid/src/multiprecHypot.cpp @@ -1,9 +1,9 @@ #include -# if defined(LIBRAPID_USE_MULTIPREC) +#if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - mpfr hypot(const mpfr &a, const mpfr &b) { return ::mpfr::hypot(a, b); } + mpfr hypot(const mpfr &a, const mpfr &b) { return ::mpfr::hypot(a, b); } } // namespace librapid #endif // LIBRAPID_USE_MULTIPREC diff --git a/librapid/src/multiprecModAbs.cpp b/librapid/src/multiprecModAbs.cpp index 172e55f0..89609126 100644 --- a/librapid/src/multiprecModAbs.cpp +++ b/librapid/src/multiprecModAbs.cpp @@ -3,30 +3,30 @@ #if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - mpfr abs(const mpfr &val) { return ::mpfr::abs(val); } + mpfr abs(const mpfr &val) { return ::mpfr::abs(val); } - mpf abs(const mpf &val) { - if (val >= 0) - return val; - else - return -val; - } + mpf abs(const mpf &val) { + if (val >= 0) + return val; + else + return -val; + } - mpz abs(const mpz &val) { - if (val >= 0) - return val; - else - return -val; - } + mpz abs(const mpz &val) { + if (val >= 0) + return val; + else + return -val; + } - mpq abs(const mpq &val) { - if (val >= 0) - return val; - else - return -val; - } + mpq abs(const mpq &val) { + if (val >= 0) + return val; + else + return -val; + } - mpfr mod(const mpfr &val, const mpfr &mod) { return ::mpfr::fmod(val, mod); } + mpfr mod(const mpfr &val, const 
mpfr &mod) { return ::mpfr::fmod(val, mod); } } // namespace librapid #endif // LIBRAPID_USE_MULTIPREC diff --git a/librapid/src/multiprecToString.cpp b/librapid/src/multiprecToString.cpp index bd61332d..070849da 100644 --- a/librapid/src/multiprecToString.cpp +++ b/librapid/src/multiprecToString.cpp @@ -3,44 +3,44 @@ #if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - std::string str(const mpz &val, int64_t, int base) { return val.get_str(base); } - - std::string str(const mpf &val, int64_t digits, int base) { - mp_exp_t exp; - std::string res = val.get_str(exp, base, digits); - bool sign = false; - - if (res[0] == '-') { - sign = true; - res = std::string(res.begin() + 1, res.end()); - } - - if (exp > 0) { - if (static_cast(exp) >= res.length()) - res += std::string(static_cast(exp) - res.length() + 1, '0'); - res.insert(exp, "."); - } else { - std::string tmp(-exp + 1, '0'); - tmp += res; - tmp.insert(1, "."); - res = tmp; - } - - if (sign) res = "-" + res; - return res; - } - - std::string str(const mpq &val, int64_t, int base) { return val.get_str(base); } - - std::string str(const mpfr &val, int64_t digits, int) { - std::stringstream ss; - ss << std::fixed; - mp_prec_t dig2 = val.getPrecision(); - dig2 = ::mpfr::bits2digits(digits < 0 ? 
dig2 : mp_prec_t(digits)); - ss.precision(dig2); - ss << val; - return ss.str(); - } + std::string str(const mpz &val, int64_t, int base) { return val.get_str(base); } + + std::string str(const mpf &val, int64_t digits, int base) { + mp_exp_t exp; + std::string res = val.get_str(exp, base, digits); + bool sign = false; + + if (res[0] == '-') { + sign = true; + res = std::string(res.begin() + 1, res.end()); + } + + if (exp > 0) { + if (static_cast(exp) >= res.length()) + res += std::string(static_cast(exp) - res.length() + 1, '0'); + res.insert(exp, "."); + } else { + std::string tmp(-exp + 1, '0'); + tmp += res; + tmp.insert(1, "."); + res = tmp; + } + + if (sign) res = "-" + res; + return res; + } + + std::string str(const mpq &val, int64_t, int base) { return val.get_str(base); } + + std::string str(const mpfr &val, int64_t digits, int) { + std::stringstream ss; + ss << std::fixed; + mp_prec_t dig2 = val.getPrecision(); + dig2 = ::mpfr::bits2digits(digits < 0 ? dig2 : mp_prec_t(digits)); + ss.precision(dig2); + ss << val; + return ss.str(); + } } // namespace librapid #endif // LIBRAPID_USE_MULTIPREC \ No newline at end of file diff --git a/librapid/src/multiprecTrig.cpp b/librapid/src/multiprecTrig.cpp index f31d30a6..2f714ca1 100644 --- a/librapid/src/multiprecTrig.cpp +++ b/librapid/src/multiprecTrig.cpp @@ -3,38 +3,38 @@ #if defined(LIBRAPID_USE_MULTIPREC) namespace librapid { - mpfr sin(const mpfr &val) { return ::mpfr::sin(val); } - mpfr cos(const mpfr &val) { return ::mpfr::cos(val); } - mpfr tan(const mpfr &val) { return ::mpfr::tan(val); } - - mpfr asin(const mpfr &val) { return ::mpfr::asin(val); } - mpfr acos(const mpfr &val) { return ::mpfr::acos(val); } - mpfr atan(const mpfr &val) { return ::mpfr::atan(val); } - mpfr atan2(const mpfr &dy, const mpfr &dx) { return ::mpfr::atan2(dy, dx); } - - mpfr csc(const mpfr &val) { return ::mpfr::csc(val); } - mpfr sec(const mpfr &val) { return ::mpfr::sec(val); } - mpfr cot(const mpfr &val) { return 
::mpfr::cot(val); } - - mpfr acsc(const mpfr &val) { return ::mpfr::acsc(val); } - mpfr asec(const mpfr &val) { return ::mpfr::asec(val); } - mpfr acot(const mpfr &val) { return ::mpfr::acot(val); } - - mpfr sinh(const mpfr &val) { return ::mpfr::sinh(val); } - mpfr cosh(const mpfr &val) { return ::mpfr::cosh(val); } - mpfr tanh(const mpfr &val) { return ::mpfr::tanh(val); } - - mpfr asinh(const mpfr &val) { return ::mpfr::asinh(val); } - mpfr acosh(const mpfr &val) { return ::mpfr::acosh(val); } - mpfr atanh(const mpfr &val) { return ::mpfr::atanh(val); } - - mpfr csch(const mpfr &val) { return ::mpfr::csch(val); } - mpfr sech(const mpfr &val) { return ::mpfr::sech(val); } - mpfr coth(const mpfr &val) { return ::mpfr::coth(val); } - - mpfr acsch(const mpfr &val) { return ::mpfr::acsch(val); } - mpfr asech(const mpfr &val) { return ::mpfr::asech(val); } - mpfr acoth(const mpfr &val) { return ::mpfr::acoth(val); } + mpfr sin(const mpfr &val) { return ::mpfr::sin(val); } + mpfr cos(const mpfr &val) { return ::mpfr::cos(val); } + mpfr tan(const mpfr &val) { return ::mpfr::tan(val); } + + mpfr asin(const mpfr &val) { return ::mpfr::asin(val); } + mpfr acos(const mpfr &val) { return ::mpfr::acos(val); } + mpfr atan(const mpfr &val) { return ::mpfr::atan(val); } + mpfr atan2(const mpfr &dy, const mpfr &dx) { return ::mpfr::atan2(dy, dx); } + + mpfr csc(const mpfr &val) { return ::mpfr::csc(val); } + mpfr sec(const mpfr &val) { return ::mpfr::sec(val); } + mpfr cot(const mpfr &val) { return ::mpfr::cot(val); } + + mpfr acsc(const mpfr &val) { return ::mpfr::acsc(val); } + mpfr asec(const mpfr &val) { return ::mpfr::asec(val); } + mpfr acot(const mpfr &val) { return ::mpfr::acot(val); } + + mpfr sinh(const mpfr &val) { return ::mpfr::sinh(val); } + mpfr cosh(const mpfr &val) { return ::mpfr::cosh(val); } + mpfr tanh(const mpfr &val) { return ::mpfr::tanh(val); } + + mpfr asinh(const mpfr &val) { return ::mpfr::asinh(val); } + mpfr acosh(const mpfr &val) { return 
::mpfr::acosh(val); } + mpfr atanh(const mpfr &val) { return ::mpfr::atanh(val); } + + mpfr csch(const mpfr &val) { return ::mpfr::csch(val); } + mpfr sech(const mpfr &val) { return ::mpfr::sech(val); } + mpfr coth(const mpfr &val) { return ::mpfr::coth(val); } + + mpfr acsch(const mpfr &val) { return ::mpfr::acsch(val); } + mpfr asech(const mpfr &val) { return ::mpfr::asech(val); } + mpfr acoth(const mpfr &val) { return ::mpfr::acoth(val); } } // namespace librapid #endif // LIBRAPID_USE_MULTIPREC diff --git a/librapid/src/openclConfigure.cpp b/librapid/src/openclConfigure.cpp index ffe63bdf..319a364d 100644 --- a/librapid/src/openclConfigure.cpp +++ b/librapid/src/openclConfigure.cpp @@ -3,290 +3,290 @@ namespace librapid { #if defined(LIBRAPID_HAS_OPENCL) - struct OpenCLTestResult { - bool pass; - int64_t err; - std::string errStr; - std::string buildLog; - }; - - OpenCLTestResult testOpenCLDevice(const cl::Device &device) { - try { - cl::Context context(device); - cl::CommandQueue queue(context, device); - - std::string source = R"V0G0N( + struct OpenCLTestResult { + bool pass; + int64_t err; + std::string errStr; + std::string buildLog; + }; + + OpenCLTestResult testOpenCLDevice(const cl::Device &device) { + try { + cl::Context context(device); + cl::CommandQueue queue(context, device); + + std::string source = R"V0G0N( __kernel void testAddition(__global const float *a, __global const float *b, __global float *c) { const int i = get_global_id(0); c[i] = a[i] + b[i]; } )V0G0N"; - cl::Program::Sources sources; - sources.emplace_back(source.c_str(), source.length() + 1); - - cl_int err; - cl::Program program(context, sources); - err = program.build(); - - // if (err != CL_SUCCESS) { - // auto format = fmt::fg(fmt::color::red) | fmt::emphasis::bold; - // fmt::print(format, - // "Error compiling test program: {}\n", - // program.getBuildInfo(device)); - // fmt::print(format, "Error Code [{}]: {}\n", err, opencl::getOpenCLErrorString(err)); - // return false; - // 
} - - // Check the build status - cl_build_status buildStatus = program.getBuildInfo(device); - - if (buildStatus != CL_BUILD_SUCCESS) { - return {false, - err, - opencl::getOpenCLErrorString(err), - program.getBuildInfo(device)}; - } - - std::vector srcA = {1, 2, 3, 4, 5}; - std::vector srcB = {5, 4, 3, 2, 1}; - std::vector dst(5); - size_t numElements = srcA.size(); - cl::Buffer bufA(context, CL_MEM_READ_ONLY, numElements * sizeof(float)); - cl::Buffer bufB(context, CL_MEM_READ_ONLY, numElements * sizeof(float)); - cl::Buffer bufC(context, CL_MEM_WRITE_ONLY, numElements * sizeof(float)); - - queue.enqueueWriteBuffer(bufA, CL_TRUE, 0, numElements * sizeof(float), srcA.data()); - queue.enqueueWriteBuffer(bufB, CL_TRUE, 0, numElements * sizeof(float), srcB.data()); - - cl::Kernel kernel(program, "testAddition"); - kernel.setArg(0, bufA); - kernel.setArg(1, bufB); - kernel.setArg(2, bufC); - - cl::NDRange global_size(numElements); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, global_size, cl::NullRange); - queue.enqueueReadBuffer(bufC, CL_TRUE, 0, numElements * sizeof(float), dst.data()); - - bool pass = dst == std::vector({6, 6, 6, 6, 6}); - return {pass, 0, "UNKNOWN_ERROR", ""}; - } catch (const std::exception &e) { - return { - false, - -1, - e.what(), - "UNKNOWN_ERROR", - }; - } - } - - int64_t openclDeviceCompute(const cl::Device &device) { - cl_uint computeUnits = device.getInfo(); - cl_uint clockFreq = device.getInfo(); - cl_ulong globalMemSize = device.getInfo(); - cl_device_type deviceType = device.getInfo(); - std::string vendorName = device.getInfo(); - - int64_t typeScore = (deviceType == CL_DEVICE_TYPE_GPU) ? 1000000 : 0; - int64_t cudaScore = (vendorName.find("NVIDIA") != std::string::npos) ? 
1000000 : 0; - int64_t memScore = globalMemSize / (1024 * 1024); - - return static_cast(computeUnits * clockFreq) + typeScore + cudaScore + memScore; - } - - void updateOpenCLDevices(bool verbose) { - std::vector platforms; - cl::Platform::get(&platforms); - - for (const auto &platform : platforms) { - std::vector devices; - platform.getDevices(CL_DEVICE_TYPE_ALL, &devices); - if (!devices.empty()) { - if (verbose) { - fmt::print("Platform: {}\n", platform.getInfo()); - - fmt::print(fmt::fg(fmt::color::gray), - " Vendor : {}\n Version: {}\n", - platform.getInfo(), - platform.getInfo()); - } - - for (auto &device : devices) { - // Test the device to check it works - auto [pass, err, errStr, buildLog] = testOpenCLDevice(device); - - fmt::text_style format; - if (pass) - format = fmt::emphasis::bold | fmt::fg(fmt::color::green); - else - format = fmt::emphasis::bold | fmt::fg(fmt::color::red); - - if (verbose) { - fmt::print(format, - "\tDevice [id={}]: {}{}\n", - global::openclDevices.size(), - device.getInfo(), - pass ? 
"" : " [ FAILED ]"); - - auto computeUnits = device.getInfo(); - auto clocFreq = device.getInfo(); - auto memory = - (device.getInfo() + (1 << 30)) / (1 << 30); - auto version = device.getInfo(); - auto profile = device.getInfo(); - fmt::print(format, "\t\tCompute Units: {}\n", computeUnits); - fmt::print(format, "\t\tClock: {}MHz\n", clocFreq); - fmt::print(format, "\t\tMemory: {}GB\n", memory); - fmt::print(format, "\t\tVersion: {}\n", version); - fmt::print(format, "\t\tProfile: {}\n", profile); - fmt::print(format, "\t\tCompute Score: {}\n", openclDeviceCompute(device)); - - if (!pass) { - fmt::print(format, "\t\tError: {}\n", errStr); - fmt::print(format, "\t\tBuild Log: "); - fmt::print(fmt::fg(fmt::color::gray), "{}\n", buildLog); - } - } - - if (!pass) continue; - - global::openclDevices.push_back(device); - } - } - } - } - - cl::Device findFastestDevice(const std::vector &devices) { - LIBRAPID_ASSERT(!devices.empty(), "No OpenCL devices found"); - cl::Device fastest; - int64_t fastestCompute = 0; - for (auto &device : devices) { - int64_t compute = openclDeviceCompute(device); - if (compute > fastestCompute) { - fastestCompute = compute; - fastest = device; - } - } - return fastest; - } - - void addOpenCLKernelSource(const std::string &source) { - global::openCLSources.emplace_back(source.c_str(), source.size()); - } - - void addOpenCLKernelFile(const std::string &filename) { - std::ifstream file(filename); - std::string source((std::istreambuf_iterator(file)), - std::istreambuf_iterator()); - source += "\n\n\n"; - char *cstr = new char[source.length() + 1]; - strcpy(cstr, source.c_str()); - global::openCLSources.emplace_back(cstr, source.size()); - } - - void compileOpenCLKernels(bool verbose) { - bool finished = false; - std::thread printer; - - if (verbose) { - printer = std::thread([&]() { - int64_t dots = 0; - auto fmtInProg = fmt::fg(fmt::color::orange) | fmt::emphasis::italic; - auto fmtDone = fmt::fg(fmt::color::green) | fmt::emphasis::bold; - 
fmt::print(fmtInProg, "Compiling OpenCL kernels..."); - while (!finished) { - if (verbose) { - fmt::print(fmtInProg, "."); - sleep(0.5); - ++dots; - } - } - std::string padding(dots + 10, ' '); - fmt::print(fmtDone, "\rOpenCL Kernels Compiled{}", padding); - fmt::print("\n\n"); - }); - } - - global::openCLProgram = cl::Program(global::openCLContext, global::openCLSources); - global::openCLProgram.build({global::openCLDevice}); - cl_build_status status = - global::openCLProgram.getBuildInfo(global::openCLDevice); - - finished = true; - if (verbose) printer.join(); - - if (status != CL_BUILD_SUCCESS) { - std::string buildLog = - global::openCLProgram.getBuildInfo(global::openCLDevice); - std::string errorMsg = fmt::format("OpenCL kernel compilation failed:\n{}", buildLog); - fmt::print(stderr, "{}\n", errorMsg); - std::cout << std::endl; - throw std::runtime_error(errorMsg); - } - } - - void configureOpenCL(bool verbose, bool ask) { - LIBRAPID_ASSERT(!global::openCLConfigured, "OpenCL already configured"); - - if (verbose) { - auto format = fmt::emphasis::bold | fmt::fg(fmt::color::orange); - fmt::print(format, "============== OpenCL Configuration ==============\n"); - } - updateOpenCLDevices(verbose); - - if (!ask) { - // Select the fastest device by default - global::openCLDevice = findFastestDevice(global::openclDevices); - } else { - // Otherwise, prompt the user to select a device - int64_t deviceIndex = -1; - while (deviceIndex < 0 || deviceIndex >= int64_t(global::openclDevices.size())) { - std::string prompt = - fmt::format("Select OpenCL device [0-{}]: ", global::openclDevices.size() - 1); - scn::prompt(prompt.c_str(), "{}", deviceIndex); - } - - global::openCLDevice = global::openclDevices[deviceIndex]; - } - - if (verbose) { - auto format = fmt::emphasis::bold | fmt::fg(fmt::color::gold); - - std::string deviceDetails = - fmt::format("Selected Device: {}", global::openCLDevice.getInfo()); - fmt::print(format, - "\n{:=^{}}\n# {} #\n{:=^{}}\n\n", - "", - 
deviceDetails.length() + 6, - deviceDetails, - "", - deviceDetails.length() + 6); - } - - global::openCLContext = cl::Context(global::openCLDevice); - global::openCLQueue = cl::CommandQueue(global::openCLContext, global::openCLDevice); - - // Add kernel files - auto basePath = fmt::format("{}/include/librapid/OpenCL/kernels/", LIBRAPID_SOURCE); - addOpenCLKernelFile(basePath + "core.cl"); - addOpenCLKernelFile(basePath + "dual.cl"); - addOpenCLKernelFile(basePath + "fill.cl"); - addOpenCLKernelFile(basePath + "negate.cl"); - addOpenCLKernelFile(basePath + "arithmetic.cl"); - addOpenCLKernelFile(basePath + "abs.cl"); - addOpenCLKernelFile(basePath + "floorCeilRound.cl"); - addOpenCLKernelFile(basePath + "trigonometry.cl"); - addOpenCLKernelFile(basePath + "expLogPow.cl"); - addOpenCLKernelFile(basePath + "transpose.cl"); - addOpenCLKernelFile( - fmt::format("{}/include/librapid/array/linalg/level3/gemm.cl", LIBRAPID_SOURCE)); - addOpenCLKernelFile( - fmt::format("{}/include/librapid/array/linalg/level2/gemv.cl", LIBRAPID_SOURCE)); - addOpenCLKernelFile(basePath + "activations.cl"); - - // Compile kernels - compileOpenCLKernels(verbose); - - global::openCLConfigured = true; - } + cl::Program::Sources sources; + sources.emplace_back(source.c_str(), source.length() + 1); + + cl_int err; + cl::Program program(context, sources); + err = program.build(); + + // if (err != CL_SUCCESS) { + // auto format = fmt::fg(fmt::color::red) | fmt::emphasis::bold; + // fmt::print(format, + // "Error compiling test program: {}\n", + // program.getBuildInfo(device)); + // fmt::print(format, "Error Code [{}]: {}\n", err, opencl::getOpenCLErrorString(err)); + // return false; + // } + + // Check the build status + cl_build_status buildStatus = program.getBuildInfo(device); + + if (buildStatus != CL_BUILD_SUCCESS) { + return {false, + err, + opencl::getOpenCLErrorString(err), + program.getBuildInfo(device)}; + } + + std::vector srcA = {1, 2, 3, 4, 5}; + std::vector srcB = {5, 4, 3, 2, 1}; 
+ std::vector dst(5); + size_t numElements = srcA.size(); + cl::Buffer bufA(context, CL_MEM_READ_ONLY, numElements * sizeof(float)); + cl::Buffer bufB(context, CL_MEM_READ_ONLY, numElements * sizeof(float)); + cl::Buffer bufC(context, CL_MEM_WRITE_ONLY, numElements * sizeof(float)); + + queue.enqueueWriteBuffer(bufA, CL_TRUE, 0, numElements * sizeof(float), srcA.data()); + queue.enqueueWriteBuffer(bufB, CL_TRUE, 0, numElements * sizeof(float), srcB.data()); + + cl::Kernel kernel(program, "testAddition"); + kernel.setArg(0, bufA); + kernel.setArg(1, bufB); + kernel.setArg(2, bufC); + + cl::NDRange global_size(numElements); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, global_size, cl::NullRange); + queue.enqueueReadBuffer(bufC, CL_TRUE, 0, numElements * sizeof(float), dst.data()); + + bool pass = dst == std::vector({6, 6, 6, 6, 6}); + return {pass, 0, "UNKNOWN_ERROR", ""}; + } catch (const std::exception &e) { + return { + false, + -1, + e.what(), + "UNKNOWN_ERROR", + }; + } + } + + int64_t openclDeviceCompute(const cl::Device &device) { + cl_uint computeUnits = device.getInfo(); + cl_uint clockFreq = device.getInfo(); + cl_ulong globalMemSize = device.getInfo(); + cl_device_type deviceType = device.getInfo(); + std::string vendorName = device.getInfo(); + + int64_t typeScore = (deviceType == CL_DEVICE_TYPE_GPU) ? 1000000 : 0; + int64_t cudaScore = (vendorName.find("NVIDIA") != std::string::npos) ? 
1000000 : 0; + int64_t memScore = globalMemSize / (1024 * 1024); + + return static_cast(computeUnits * clockFreq) + typeScore + cudaScore + memScore; + } + + void updateOpenCLDevices(bool verbose) { + std::vector platforms; + cl::Platform::get(&platforms); + + for (const auto &platform : platforms) { + std::vector devices; + platform.getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (!devices.empty()) { + if (verbose) { + fmt::print("Platform: {}\n", platform.getInfo()); + + fmt::print(fmt::fg(fmt::color::gray), + " Vendor : {}\n Version: {}\n", + platform.getInfo(), + platform.getInfo()); + } + + for (auto &device : devices) { + // Test the device to check it works + auto [pass, err, errStr, buildLog] = testOpenCLDevice(device); + + fmt::text_style format; + if (pass) + format = fmt::emphasis::bold | fmt::fg(fmt::color::green); + else + format = fmt::emphasis::bold | fmt::fg(fmt::color::red); + + if (verbose) { + fmt::print(format, + "\tDevice [id={}]: {}{}\n", + global::openclDevices.size(), + device.getInfo(), + pass ? 
"" : " [ FAILED ]"); + + auto computeUnits = device.getInfo(); + auto clocFreq = device.getInfo(); + auto memory = + (device.getInfo() + (1 << 30)) / (1 << 30); + auto version = device.getInfo(); + auto profile = device.getInfo(); + fmt::print(format, "\t\tCompute Units: {}\n", computeUnits); + fmt::print(format, "\t\tClock: {}MHz\n", clocFreq); + fmt::print(format, "\t\tMemory: {}GB\n", memory); + fmt::print(format, "\t\tVersion: {}\n", version); + fmt::print(format, "\t\tProfile: {}\n", profile); + fmt::print(format, "\t\tCompute Score: {}\n", openclDeviceCompute(device)); + + if (!pass) { + fmt::print(format, "\t\tError: {}\n", errStr); + fmt::print(format, "\t\tBuild Log: "); + fmt::print(fmt::fg(fmt::color::gray), "{}\n", buildLog); + } + } + + if (!pass) continue; + + global::openclDevices.push_back(device); + } + } + } + } + + cl::Device findFastestDevice(const std::vector &devices) { + LIBRAPID_ASSERT(!devices.empty(), "No OpenCL devices found"); + cl::Device fastest; + int64_t fastestCompute = 0; + for (auto &device : devices) { + int64_t compute = openclDeviceCompute(device); + if (compute > fastestCompute) { + fastestCompute = compute; + fastest = device; + } + } + return fastest; + } + + void addOpenCLKernelSource(const std::string &source) { + global::openCLSources.emplace_back(source.c_str(), source.size()); + } + + void addOpenCLKernelFile(const std::string &filename) { + std::ifstream file(filename); + std::string source((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + source += "\n\n\n"; + char *cstr = new char[source.length() + 1]; + strcpy(cstr, source.c_str()); + global::openCLSources.emplace_back(cstr, source.size()); + } + + void compileOpenCLKernels(bool verbose) { + bool finished = false; + std::thread printer; + + if (verbose) { + printer = std::thread([&]() { + int64_t dots = 0; + auto fmtInProg = fmt::fg(fmt::color::orange) | fmt::emphasis::italic; + auto fmtDone = fmt::fg(fmt::color::green) | fmt::emphasis::bold; + 
fmt::print(fmtInProg, "Compiling OpenCL kernels..."); + while (!finished) { + if (verbose) { + fmt::print(fmtInProg, "."); + sleep(0.5); + ++dots; + } + } + std::string padding(dots + 10, ' '); + fmt::print(fmtDone, "\rOpenCL Kernels Compiled{}", padding); + fmt::print("\n\n"); + }); + } + + global::openCLProgram = cl::Program(global::openCLContext, global::openCLSources); + global::openCLProgram.build({global::openCLDevice}); + cl_build_status status = + global::openCLProgram.getBuildInfo(global::openCLDevice); + + finished = true; + if (verbose) printer.join(); + + if (status != CL_BUILD_SUCCESS) { + std::string buildLog = + global::openCLProgram.getBuildInfo(global::openCLDevice); + std::string errorMsg = fmt::format("OpenCL kernel compilation failed:\n{}", buildLog); + fmt::print(stderr, "{}\n", errorMsg); + std::cout << std::endl; + throw std::runtime_error(errorMsg); + } + } + + void configureOpenCL(bool verbose, bool ask) { + LIBRAPID_ASSERT(!global::openCLConfigured, "OpenCL already configured"); + + if (verbose) { + auto format = fmt::emphasis::bold | fmt::fg(fmt::color::orange); + fmt::print(format, "============== OpenCL Configuration ==============\n"); + } + updateOpenCLDevices(verbose); + + if (!ask) { + // Select the fastest device by default + global::openCLDevice = findFastestDevice(global::openclDevices); + } else { + // Otherwise, prompt the user to select a device + int64_t deviceIndex = -1; + while (deviceIndex < 0 || deviceIndex >= int64_t(global::openclDevices.size())) { + std::string prompt = + fmt::format("Select OpenCL device [0-{}]: ", global::openclDevices.size() - 1); + scn::prompt(prompt.c_str(), "{}", deviceIndex); + } + + global::openCLDevice = global::openclDevices[deviceIndex]; + } + + if (verbose) { + auto format = fmt::emphasis::bold | fmt::fg(fmt::color::gold); + + std::string deviceDetails = + fmt::format("Selected Device: {}", global::openCLDevice.getInfo()); + fmt::print(format, + "\n{:=^{}}\n# {} #\n{:=^{}}\n\n", + "", + 
deviceDetails.length() + 6, + deviceDetails, + "", + deviceDetails.length() + 6); + } + + global::openCLContext = cl::Context(global::openCLDevice); + global::openCLQueue = cl::CommandQueue(global::openCLContext, global::openCLDevice); + + // Add kernel files + auto basePath = fmt::format("{}/include/librapid/OpenCL/kernels/", LIBRAPID_SOURCE); + addOpenCLKernelFile(basePath + "core.cl"); + addOpenCLKernelFile(basePath + "dual.cl"); + addOpenCLKernelFile(basePath + "fill.cl"); + addOpenCLKernelFile(basePath + "negate.cl"); + addOpenCLKernelFile(basePath + "arithmetic.cl"); + addOpenCLKernelFile(basePath + "abs.cl"); + addOpenCLKernelFile(basePath + "floorCeilRound.cl"); + addOpenCLKernelFile(basePath + "trigonometry.cl"); + addOpenCLKernelFile(basePath + "expLogPow.cl"); + addOpenCLKernelFile(basePath + "transpose.cl"); + addOpenCLKernelFile( + fmt::format("{}/include/librapid/array/linalg/level3/gemm.cl", LIBRAPID_SOURCE)); + addOpenCLKernelFile( + fmt::format("{}/include/librapid/array/linalg/level2/gemv.cl", LIBRAPID_SOURCE)); + addOpenCLKernelFile(basePath + "activations.cl"); + + // Compile kernels + compileOpenCLKernels(verbose); + + global::openCLConfigured = true; + } #endif // LIBRAPID_HAS_OPENCL } // namespace librapid diff --git a/librapid/src/openclErrorIdentifier.cpp b/librapid/src/openclErrorIdentifier.cpp index b1fee1b5..8d3a44e5 100644 --- a/librapid/src/openclErrorIdentifier.cpp +++ b/librapid/src/openclErrorIdentifier.cpp @@ -2,85 +2,85 @@ namespace librapid::opencl { #if defined(LIBRAPID_HAS_OPENCL) - std::string getOpenCLErrorString(int64_t error) { - static const char *strings[] = { // Error Codes - "CL_SUCCESS", // 0 - "CL_DEVICE_NOT_FOUND", // -1 - "CL_DEVICE_NOT_AVAILABLE", // -2 - "CL_COMPILER_NOT_AVAILABLE", // -3 - "CL_MEM_OBJECT_ALLOCATION_FAILURE", // -4 - "CL_OUT_OF_RESOURCES", // -5 - "CL_OUT_OF_HOST_MEMORY", // -6 - "CL_PROFILING_INFO_NOT_AVAILABLE", // -7 - "CL_MEM_COPY_OVERLAP", // -8 - "CL_IMAGE_FORMAT_MISMATCH", // -9 - 
"CL_IMAGE_FORMAT_NOT_SUPPORTED", // -10 - "CL_BUILD_PROGRAM_FAILURE", // -11 - "CL_MAP_FAILURE", // -12 + std::string getOpenCLErrorString(int64_t error) { + static const char *strings[] = { // Error Codes + "CL_SUCCESS", // 0 + "CL_DEVICE_NOT_FOUND", // -1 + "CL_DEVICE_NOT_AVAILABLE", // -2 + "CL_COMPILER_NOT_AVAILABLE", // -3 + "CL_MEM_OBJECT_ALLOCATION_FAILURE", // -4 + "CL_OUT_OF_RESOURCES", // -5 + "CL_OUT_OF_HOST_MEMORY", // -6 + "CL_PROFILING_INFO_NOT_AVAILABLE", // -7 + "CL_MEM_COPY_OVERLAP", // -8 + "CL_IMAGE_FORMAT_MISMATCH", // -9 + "CL_IMAGE_FORMAT_NOT_SUPPORTED", // -10 + "CL_BUILD_PROGRAM_FAILURE", // -11 + "CL_MAP_FAILURE", // -12 - "", // -13 - "", // -14 - "", // -15 - "", // -16 - "", // -17 - "", // -18 - "", // -19 + "", // -13 + "", // -14 + "", // -15 + "", // -16 + "", // -17 + "", // -18 + "", // -19 - "", // -20 - "", // -21 - "", // -22 - "", // -23 - "", // -24 - "", // -25 - "", // -26 - "", // -27 - "", // -28 - "", // -29 + "", // -20 + "", // -21 + "", // -22 + "", // -23 + "", // -24 + "", // -25 + "", // -26 + "", // -27 + "", // -28 + "", // -29 - "CL_INVALID_VALUE", // -30 - "CL_INVALID_DEVICE_TYPE", // -31 - "CL_INVALID_PLATFORM", // -32 - "CL_INVALID_DEVICE", // -33 - "CL_INVALID_CONTEXT", // -34 - "CL_INVALID_QUEUE_PROPERTIES", // -35 - "CL_INVALID_COMMAND_QUEUE", // -36 - "CL_INVALID_HOST_PTR", // -37 - "CL_INVALID_MEM_OBJECT", // -38 - "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR", // -39 - "CL_INVALID_IMAGE_SIZE", // -40 - "CL_INVALID_SAMPLER", // -41 - "CL_INVALID_BINARY", // -42 - "CL_INVALID_BUILD_OPTIONS", // -43 - "CL_INVALID_PROGRAM", // -44 - "CL_INVALID_PROGRAM_EXECUTABLE", // -45 - "CL_INVALID_KERNEL_NAME", // -46 - "CL_INVALID_KERNEL_DEFINITION", // -47 - "CL_INVALID_KERNEL", // -48 - "CL_INVALID_ARG_INDEX", // -49 - "CL_INVALID_ARG_VALUE", // -50 - "CL_INVALID_ARG_SIZE", // -51 - "CL_INVALID_KERNEL_ARGS", // -52 - "CL_INVALID_WORK_DIMENSION", // -53 - "CL_INVALID_WORK_GROUP_SIZE", // -54 - "CL_INVALID_WORK_ITEM_SIZE", // 
-55 - "CL_INVALID_GLOBAL_OFFSET", // -56 - "CL_INVALID_EVENT_WAIT_LIST", // -57 - "CL_INVALID_EVENT", // -58 - "CL_INVALID_OPERATION", // -59 - "CL_INVALID_GL_OBJECT", // -60 - "CL_INVALID_BUFFER_SIZE", // -61 - "CL_INVALID_MIP_LEVEL", // -62 - "CL_INVALID_GLOBAL_WORK_SIZE", // -63 - "CL_UNKNOWN_ERROR_CODE"}; + "CL_INVALID_VALUE", // -30 + "CL_INVALID_DEVICE_TYPE", // -31 + "CL_INVALID_PLATFORM", // -32 + "CL_INVALID_DEVICE", // -33 + "CL_INVALID_CONTEXT", // -34 + "CL_INVALID_QUEUE_PROPERTIES", // -35 + "CL_INVALID_COMMAND_QUEUE", // -36 + "CL_INVALID_HOST_PTR", // -37 + "CL_INVALID_MEM_OBJECT", // -38 + "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR", // -39 + "CL_INVALID_IMAGE_SIZE", // -40 + "CL_INVALID_SAMPLER", // -41 + "CL_INVALID_BINARY", // -42 + "CL_INVALID_BUILD_OPTIONS", // -43 + "CL_INVALID_PROGRAM", // -44 + "CL_INVALID_PROGRAM_EXECUTABLE", // -45 + "CL_INVALID_KERNEL_NAME", // -46 + "CL_INVALID_KERNEL_DEFINITION", // -47 + "CL_INVALID_KERNEL", // -48 + "CL_INVALID_ARG_INDEX", // -49 + "CL_INVALID_ARG_VALUE", // -50 + "CL_INVALID_ARG_SIZE", // -51 + "CL_INVALID_KERNEL_ARGS", // -52 + "CL_INVALID_WORK_DIMENSION", // -53 + "CL_INVALID_WORK_GROUP_SIZE", // -54 + "CL_INVALID_WORK_ITEM_SIZE", // -55 + "CL_INVALID_GLOBAL_OFFSET", // -56 + "CL_INVALID_EVENT_WAIT_LIST", // -57 + "CL_INVALID_EVENT", // -58 + "CL_INVALID_OPERATION", // -59 + "CL_INVALID_GL_OBJECT", // -60 + "CL_INVALID_BUFFER_SIZE", // -61 + "CL_INVALID_MIP_LEVEL", // -62 + "CL_INVALID_GLOBAL_WORK_SIZE", // -63 + "CL_UNKNOWN_ERROR_CODE"}; - if (error >= -63 && error <= 0) - return strings[-error]; - else - return strings[64]; - } + if (error >= -63 && error <= 0) + return strings[-error]; + else + return strings[64]; + } - std::string getCLBlastErrorString(clblast::StatusCode status) { - // clang-format off + std::string getCLBlastErrorString(clblast::StatusCode status) { + // clang-format off static const std::map statusMap = { {clblast::StatusCode::kSuccess, "CL_SUCCESS"}, 
{clblast::StatusCode::kOpenCLCompilerNotAvailable, "CL_COMPILER_NOT_AVAILABLE"}, @@ -142,13 +142,13 @@ namespace librapid::opencl { {clblast::StatusCode::kDatabaseError, "Entry for the device was not found in the database"}, {clblast::StatusCode::kUnknownError, "A catch-all error code representing an unspecified error"}, {clblast::StatusCode::kUnexpectedError, "A catch-all error code representing an unexpected exception"}}; - // clang-format on + // clang-format on - auto it = statusMap.find(status); - if (it != statusMap.end()) - return it->second; - else - return "Unknown error"; - } + auto it = statusMap.find(status); + if (it != statusMap.end()) + return it->second; + else + return "Unknown error"; + } #endif // LIBRAPID_HAS_OPENCL } // namespace librapid::opencl diff --git a/librapid/src/preMain.cpp b/librapid/src/preMain.cpp index 4eb89bda..c70b0678 100644 --- a/librapid/src/preMain.cpp +++ b/librapid/src/preMain.cpp @@ -1,42 +1,42 @@ #include namespace librapid::detail { - bool preMainRun = false; + bool preMainRun = false; - PreMain::PreMain() { - if (!preMainRun) { + PreMain::PreMain() { + if (!preMainRun) { #if defined(LIBRAPID_WINDOWS) // && !defined(LIBRAPID_NO_WINDOWS_H) - // Force the terminal to accept ANSI characters - system(("chcp " + std::to_string(CP_UTF8)).c_str()); + // Force the terminal to accept ANSI characters + system(("chcp " + std::to_string(CP_UTF8)).c_str()); #endif // LIBRAPID_WINDOWS - preMainRun = true; - global::cacheLineSize = cacheLineSize(); + preMainRun = true; + global::cacheLineSize = cacheLineSize(); - // OpenCL compatible devices are detected after this function is called, - // meaning nothing is found here. The user must call configureOpenCL() - // manually. + // OpenCL compatible devices are detected after this function is called, + // meaning nothing is found here. The user must call configureOpenCL() + // manually. 
- // #if defined(LIBRAPID_HAS_OPENCL) - // configureOpenCL(); - // #endif // LIBRAPID_HAS_OPENCL + // #if defined(LIBRAPID_HAS_OPENCL) + // configureOpenCL(); + // #endif // LIBRAPID_HAS_OPENCL #if defined(LIBRAPID_HAS_CUDA) - cudaSafeCall(cudaStreamCreate(&global::cudaStream)); - cublasSafeCall(cublasCreate(&global::cublasHandle)); - cublasSafeCall(cublasSetStream(global::cublasHandle, global::cudaStream)); + cudaSafeCall(cudaStreamCreate(&global::cudaStream)); + cublasSafeCall(cublasCreate(&global::cublasHandle)); + cublasSafeCall(cublasSetStream(global::cublasHandle, global::cudaStream)); - cudaSafeCall(cudaMallocAsync( - &global::cublasLtWorkspace, global::cublasLtWorkspaceSize, global::cudaStream)); + cudaSafeCall(cudaMallocAsync( + &global::cublasLtWorkspace, global::cublasLtWorkspaceSize, global::cudaStream)); - cublasSafeCall(cublasLtCreate(&global::cublasLtHandle)); - cublasSafeCall(cublasSetWorkspace( - global::cublasHandle, global::cublasLtWorkspace, global::cublasLtWorkspaceSize)); - // Stream is specified in the function calls + cublasSafeCall(cublasLtCreate(&global::cublasLtHandle)); + cublasSafeCall(cublasSetWorkspace( + global::cublasHandle, global::cublasLtWorkspace, global::cublasLtWorkspaceSize)); + // Stream is specified in the function calls #endif // LIBRAPID_HAS_CUDA - // Set the random seed to an initial value - global::randomSeed = (size_t) now(); - } - } + // Set the random seed to an initial value + global::randomSeed = (size_t)now(); + } + } } // namespace librapid::detail diff --git a/test/test-array.cpp b/test/test-array.cpp index 50dc12ec..cb64e078 100644 --- a/test/test-array.cpp +++ b/test/test-array.cpp @@ -3,638 +3,638 @@ #include #include -namespace lrc = librapid; +namespace lrc = librapid; constexpr double tolerance = 0.001; -using CPU = lrc::backend::CPU; -using OPENCL = lrc::backend::OpenCL; -using CUDA = lrc::backend::CUDA; +using CPU = lrc::backend::CPU; +using OPENCL = lrc::backend::OpenCL; +using CUDA = 
lrc::backend::CUDA; #define TEST_CONSTRUCTORS(SCALAR, BACKEND) \ - SECTION(fmt::format("Test Constructors [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array testA; \ - REQUIRE(testA.shape() == lrc::Array::ShapeType {0}); \ - \ - lrc::Array testB(lrc::Array::ShapeType {3, 4}); \ - REQUIRE(testB.shape() == lrc::Array::ShapeType {3, 4}); \ - \ - lrc::Array testC(lrc::Array::ShapeType {3, 4}, 5); \ - REQUIRE(testC.shape() == lrc::Array::ShapeType {3, 4}); \ - REQUIRE(testC.storage()[0] == 5); \ - REQUIRE(testC.storage()[1] == 5); \ - REQUIRE(testC.storage()[2] == 5); \ - REQUIRE(testC.storage()[9] == 5); \ - REQUIRE(testC.storage()[10] == 5); \ - REQUIRE(testC.storage()[11] == 5); \ - \ - lrc::ArrayF testD(3); \ - REQUIRE(testD.storage()[0] == 3); \ - REQUIRE(testD.storage()[1] == 3); \ - REQUIRE(testD.storage()[2] == 3); \ - REQUIRE(testD.storage()[3] == 3); \ - \ - lrc::Array::ShapeType tmpShape({2, 3}); \ - lrc::Array testE(std::move(tmpShape)); \ - REQUIRE(testE.shape() == lrc::Array::ShapeType {2, 3}); \ - \ - lrc::Array testF(testC); \ - REQUIRE(testF.shape() == lrc::Array::ShapeType {3, 4}); \ - REQUIRE(testF.storage()[0] == 5); \ - REQUIRE(testF.storage()[1] == 5); \ - REQUIRE(testF.storage()[2] == 5); \ - REQUIRE(testF.storage()[9] == 5); \ - REQUIRE(testF.storage()[10] == 5); \ - REQUIRE(testF.storage()[11] == 5); \ - \ - lrc::Array testG(lrc::Array::ShapeType {3, 4}, 10); \ - testC = testG; \ - REQUIRE(testC.storage()[0] == 10); \ - REQUIRE(testC.storage()[1] == 10); \ - REQUIRE(testC.storage()[2] == 10); \ - REQUIRE(testC.storage()[9] == 10); \ - REQUIRE(testC.storage()[10] == 10); \ - REQUIRE(testC.storage()[11] == 10); \ - \ - lrc::Array testH(lrc::Array::ShapeType {3, 3}); \ - testH << 1, 2, 3, 4, 5, 6, 7, 8, 9; \ - REQUIRE(testH.storage()[0] == 1); \ - REQUIRE(testH.storage()[1] == 2); \ - REQUIRE(testH.storage()[2] == 3); \ - REQUIRE(testH.storage()[3] == 4); \ - REQUIRE(testH.storage()[4] == 5); \ - REQUIRE(testH.storage()[5] == 6); 
\ - REQUIRE(testH.storage()[6] == 7); \ - REQUIRE(testH.storage()[7] == 8); \ - REQUIRE(testH.storage()[8] == 9); \ - \ - /* It is necessary to define the type of the data, otherwise bad things happen for the \ - * MPFR type */ \ - using InitList = \ - std::initializer_list>>; \ - using Vec = std::vector>>; \ - \ - /* Due to the way the code works, if this passes for a 3D array, it *must* pass for all \ - * other dimensions */ \ - auto testI = \ - lrc::Array::fromData(InitList({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); \ - REQUIRE(testI.str() == fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ - SCALAR(1), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6), \ - SCALAR(7), \ - SCALAR(8))); \ - \ - auto testJ = \ - lrc::Array::fromData(Vec({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); \ - REQUIRE(testJ.str() == fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ - SCALAR(1), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6), \ - SCALAR(7), \ - SCALAR(8))); \ - } + SECTION(fmt::format("Test Constructors [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array testA; \ + REQUIRE(testA.shape() == lrc::Array::ShapeType {0}); \ + \ + lrc::Array testB(lrc::Array::ShapeType {3, 4}); \ + REQUIRE(testB.shape() == lrc::Array::ShapeType {3, 4}); \ + \ + lrc::Array testC(lrc::Array::ShapeType {3, 4}, 5); \ + REQUIRE(testC.shape() == lrc::Array::ShapeType {3, 4}); \ + REQUIRE(testC.storage()[0] == 5); \ + REQUIRE(testC.storage()[1] == 5); \ + REQUIRE(testC.storage()[2] == 5); \ + REQUIRE(testC.storage()[9] == 5); \ + REQUIRE(testC.storage()[10] == 5); \ + REQUIRE(testC.storage()[11] == 5); \ + \ + lrc::ArrayF testD(3); \ + REQUIRE(testD.storage()[0] == 3); \ + REQUIRE(testD.storage()[1] == 3); \ + REQUIRE(testD.storage()[2] == 3); \ + REQUIRE(testD.storage()[3] == 3); \ + \ + lrc::Array::ShapeType tmpShape({2, 3}); \ + lrc::Array testE(std::move(tmpShape)); \ + REQUIRE(testE.shape() == 
lrc::Array::ShapeType {2, 3}); \ + \ + lrc::Array testF(testC); \ + REQUIRE(testF.shape() == lrc::Array::ShapeType {3, 4}); \ + REQUIRE(testF.storage()[0] == 5); \ + REQUIRE(testF.storage()[1] == 5); \ + REQUIRE(testF.storage()[2] == 5); \ + REQUIRE(testF.storage()[9] == 5); \ + REQUIRE(testF.storage()[10] == 5); \ + REQUIRE(testF.storage()[11] == 5); \ + \ + lrc::Array testG(lrc::Array::ShapeType {3, 4}, 10); \ + testC = testG; \ + REQUIRE(testC.storage()[0] == 10); \ + REQUIRE(testC.storage()[1] == 10); \ + REQUIRE(testC.storage()[2] == 10); \ + REQUIRE(testC.storage()[9] == 10); \ + REQUIRE(testC.storage()[10] == 10); \ + REQUIRE(testC.storage()[11] == 10); \ + \ + lrc::Array testH(lrc::Array::ShapeType {3, 3}); \ + testH << 1, 2, 3, 4, 5, 6, 7, 8, 9; \ + REQUIRE(testH.storage()[0] == 1); \ + REQUIRE(testH.storage()[1] == 2); \ + REQUIRE(testH.storage()[2] == 3); \ + REQUIRE(testH.storage()[3] == 4); \ + REQUIRE(testH.storage()[4] == 5); \ + REQUIRE(testH.storage()[5] == 6); \ + REQUIRE(testH.storage()[6] == 7); \ + REQUIRE(testH.storage()[7] == 8); \ + REQUIRE(testH.storage()[8] == 9); \ + \ + /* It is necessary to define the type of the data, otherwise bad things happen for the \ + * MPFR type */ \ + using InitList = \ + std::initializer_list>>; \ + using Vec = std::vector>>; \ + \ + /* Due to the way the code works, if this passes for a 3D array, it *must* pass for all \ + * other dimensions */ \ + auto testI = \ + lrc::Array::fromData(InitList({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); \ + REQUIRE(testI.str() == fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ + SCALAR(1), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6), \ + SCALAR(7), \ + SCALAR(8))); \ + \ + auto testJ = \ + lrc::Array::fromData(Vec({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); \ + REQUIRE(testJ.str() == fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ + SCALAR(1), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6), \ + 
SCALAR(7), \ + SCALAR(8))); \ + } #define TEST_INDEXING(SCALAR, BACKEND) \ - SECTION(fmt::format("Test Indexing [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array testA(lrc::Array::ShapeType({5, 3})); \ - testA << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; \ - std::string index0 = fmt::format("[{} {} {}]", SCALAR(1), SCALAR(2), SCALAR(3)); \ - std::string index1 = fmt::format("[{} {} {}]", SCALAR(4), SCALAR(5), SCALAR(6)); \ - std::string index2 = fmt::format("[{} {} {}]", SCALAR(7), SCALAR(8), SCALAR(9)); \ - std::string index3 = fmt::format("[{} {} {}]", SCALAR(10), SCALAR(11), SCALAR(12)); \ - std::string index4 = fmt::format("[{} {} {}]", SCALAR(13), SCALAR(14), SCALAR(15)); \ - REQUIRE(testA[0].str() == index0); \ - REQUIRE(testA[1].str() == index1); \ - REQUIRE(testA[2].str() == index2); \ - REQUIRE(testA[3].str() == index3); \ - REQUIRE(testA[4].str() == index4); \ - REQUIRE(testA[0][0].str() == fmt::format("{}", SCALAR(1))); \ - REQUIRE(testA[1][1].str() == fmt::format("{}", SCALAR(5))); \ - REQUIRE(testA[2][2].str() == fmt::format("{}", SCALAR(9))); \ - \ - testA[1][2] = 123; \ - \ - REQUIRE(testA[0][0].get() == SCALAR(1)); \ - REQUIRE(testA[1][1].get() == SCALAR(5)); \ - REQUIRE(testA[2][2].get() == SCALAR(9)); \ - REQUIRE(testA[1][2].get() == SCALAR(123)); \ - \ - testA[0][0] = 123; \ - testA[1][1] = 456; \ - testA[2][2] = 789; \ - REQUIRE((SCALAR)testA.storage()[0] == SCALAR(123)); \ - REQUIRE((SCALAR)testA.storage()[4] == SCALAR(456)); \ - REQUIRE((SCALAR)testA.storage()[8] == SCALAR(789)); \ - \ - lrc::Array testB(lrc::Array::ShapeType({10})); \ - testB << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10; \ - REQUIRE(testB[0].get() == SCALAR(1)); \ - REQUIRE(testB[9].get() == SCALAR(10)); \ - } + SECTION(fmt::format("Test Indexing [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array testA(lrc::Array::ShapeType({5, 3})); \ + testA << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; \ + std::string index0 = fmt::format("[{} {} 
{}]", SCALAR(1), SCALAR(2), SCALAR(3)); \ + std::string index1 = fmt::format("[{} {} {}]", SCALAR(4), SCALAR(5), SCALAR(6)); \ + std::string index2 = fmt::format("[{} {} {}]", SCALAR(7), SCALAR(8), SCALAR(9)); \ + std::string index3 = fmt::format("[{} {} {}]", SCALAR(10), SCALAR(11), SCALAR(12)); \ + std::string index4 = fmt::format("[{} {} {}]", SCALAR(13), SCALAR(14), SCALAR(15)); \ + REQUIRE(testA[0].str() == index0); \ + REQUIRE(testA[1].str() == index1); \ + REQUIRE(testA[2].str() == index2); \ + REQUIRE(testA[3].str() == index3); \ + REQUIRE(testA[4].str() == index4); \ + REQUIRE(testA[0][0].str() == fmt::format("{}", SCALAR(1))); \ + REQUIRE(testA[1][1].str() == fmt::format("{}", SCALAR(5))); \ + REQUIRE(testA[2][2].str() == fmt::format("{}", SCALAR(9))); \ + \ + testA[1][2] = 123; \ + \ + REQUIRE(testA[0][0].get() == SCALAR(1)); \ + REQUIRE(testA[1][1].get() == SCALAR(5)); \ + REQUIRE(testA[2][2].get() == SCALAR(9)); \ + REQUIRE(testA[1][2].get() == SCALAR(123)); \ + \ + testA[0][0] = 123; \ + testA[1][1] = 456; \ + testA[2][2] = 789; \ + REQUIRE((SCALAR)testA.storage()[0] == SCALAR(123)); \ + REQUIRE((SCALAR)testA.storage()[4] == SCALAR(456)); \ + REQUIRE((SCALAR)testA.storage()[8] == SCALAR(789)); \ + \ + lrc::Array testB(lrc::Array::ShapeType({10})); \ + testB << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10; \ + REQUIRE(testB[0].get() == SCALAR(1)); \ + REQUIRE(testB[9].get() == SCALAR(10)); \ + } #define TEST_STRING_FORMATTING(SCALAR, BACKEND) \ - SECTION( \ - fmt::format("Test String Formatting [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array testA(lrc::Array::ShapeType({2, 3})); \ - testA << 1, 2, 3, 4, 5, 6; \ - \ - REQUIRE(testA.str() == fmt::format("[[{} {} {}]\n [{} {} {}]]", \ - SCALAR(1), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6))); \ - \ - lrc::Array testB(lrc::Array::ShapeType({2, 3})); \ - testB << 10, 2, 3, 4, 5, 6; \ - \ - REQUIRE(testB.str() == fmt::format("[[{} {} {}]\n [ {} {} {}]]", \ - SCALAR(10), \ 
- SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6))); \ - \ - lrc::Array testC(lrc::Array::ShapeType({2, 2, 2})); \ - testC << 100, 2, 3, 4, 5, 6, 7, 8; \ - REQUIRE(testC.str() == \ - fmt::format("[[[{} {}]\n [ {} {}]]\n\n [[ {} {}]\n [ {} {}]]]", \ - SCALAR(100), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6), \ - SCALAR(7), \ - SCALAR(8))); \ - } + SECTION( \ + fmt::format("Test String Formatting [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array testA(lrc::Array::ShapeType({2, 3})); \ + testA << 1, 2, 3, 4, 5, 6; \ + \ + REQUIRE(testA.str() == fmt::format("[[{} {} {}]\n [{} {} {}]]", \ + SCALAR(1), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6))); \ + \ + lrc::Array testB(lrc::Array::ShapeType({2, 3})); \ + testB << 10, 2, 3, 4, 5, 6; \ + \ + REQUIRE(testB.str() == fmt::format("[[{} {} {}]\n [ {} {} {}]]", \ + SCALAR(10), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6))); \ + \ + lrc::Array testC(lrc::Array::ShapeType({2, 2, 2})); \ + testC << 100, 2, 3, 4, 5, 6, 7, 8; \ + REQUIRE(testC.str() == \ + fmt::format("[[[{} {}]\n [ {} {}]]\n\n [[ {} {}]\n [ {} {}]]]", \ + SCALAR(100), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6), \ + SCALAR(7), \ + SCALAR(8))); \ + } #define TEST_ARITHMETIC(SCALAR, BACKEND) \ - SECTION( \ - fmt::format("Test Array Operations [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array::ShapeType shape({37, 41}); \ - lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ - lrc::Array testB(shape); \ - \ - for (int64_t i = 0; i < shape[0]; ++i) { \ - for (int64_t j = 0; j < shape[1]; ++j) { \ - SCALAR a = j + i * shape[1] + 1; \ - SCALAR b = i + j * shape[0] + 1; \ - \ - testA[i][j] = a; \ - testB[i][j] = b != 0 ? 
b : 1; \ - } \ - } \ - \ - auto negResult = (-testA).eval(); \ - bool negValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(negResult.scalar(i) == -(testA.scalar(i)))) { \ - REQUIRE(lrc::isClose(negResult.scalar(i), -(testA.scalar(i)), tolerance)); \ - negValid = false; \ - } \ - } \ - REQUIRE(negValid); \ - \ - auto sumResult = (testA + testB).eval(); \ - bool sumValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(sumResult.scalar(i) == testA.scalar(i) + testB.scalar(i))) { \ - REQUIRE(lrc::isClose( \ - sumResult.scalar(i), testA.scalar(i) + testB.scalar(i), tolerance)); \ - sumValid = false; \ - } \ - } \ - REQUIRE(sumValid); \ - \ - auto diffResult = (testA - testB).eval(); \ - bool diffValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(diffResult.scalar(i) == testA.scalar(i) - testB.scalar(i))) { \ - REQUIRE(lrc::isClose( \ - diffResult.scalar(i), testA.scalar(i) - testB.scalar(i), tolerance)); \ - diffValid = false; \ - } \ - } \ - REQUIRE(diffValid); \ - \ - auto prodResult = (testA * testB).eval(); \ - bool prodValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(prodResult.scalar(i) == testA.scalar(i) * testB.scalar(i))) { \ - REQUIRE(lrc::isClose( \ - prodResult.scalar(i), testA.scalar(i) * testB.scalar(i), tolerance)); \ - prodValid = false; \ - } \ - } \ - REQUIRE(prodValid); \ - \ - auto divResult = (testA / testB).eval(); \ - bool divValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(divResult.scalar(i) == testA.scalar(i) / testB.scalar(i))) { \ - REQUIRE(lrc::isClose( \ - divResult.scalar(i), testA.scalar(i) / testB.scalar(i), tolerance)); \ - divValid = false; \ - } \ - } \ - REQUIRE(diffValid); \ - } \ - do { \ - } while (false) + SECTION( \ + fmt::format("Test Array Operations [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array::ShapeType shape({37, 41}); \ + lrc::Array testA(shape); /* 
Prime-dimensioned to force wrapping */ \ + lrc::Array testB(shape); \ + \ + for (int64_t i = 0; i < shape[0]; ++i) { \ + for (int64_t j = 0; j < shape[1]; ++j) { \ + SCALAR a = j + i * shape[1] + 1; \ + SCALAR b = i + j * shape[0] + 1; \ + \ + testA[i][j] = a; \ + testB[i][j] = b != 0 ? b : 1; \ + } \ + } \ + \ + auto negResult = (-testA).eval(); \ + bool negValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(negResult.scalar(i) == -(testA.scalar(i)))) { \ + REQUIRE(lrc::isClose(negResult.scalar(i), -(testA.scalar(i)), tolerance)); \ + negValid = false; \ + } \ + } \ + REQUIRE(negValid); \ + \ + auto sumResult = (testA + testB).eval(); \ + bool sumValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(sumResult.scalar(i) == testA.scalar(i) + testB.scalar(i))) { \ + REQUIRE(lrc::isClose( \ + sumResult.scalar(i), testA.scalar(i) + testB.scalar(i), tolerance)); \ + sumValid = false; \ + } \ + } \ + REQUIRE(sumValid); \ + \ + auto diffResult = (testA - testB).eval(); \ + bool diffValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(diffResult.scalar(i) == testA.scalar(i) - testB.scalar(i))) { \ + REQUIRE(lrc::isClose( \ + diffResult.scalar(i), testA.scalar(i) - testB.scalar(i), tolerance)); \ + diffValid = false; \ + } \ + } \ + REQUIRE(diffValid); \ + \ + auto prodResult = (testA * testB).eval(); \ + bool prodValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(prodResult.scalar(i) == testA.scalar(i) * testB.scalar(i))) { \ + REQUIRE(lrc::isClose( \ + prodResult.scalar(i), testA.scalar(i) * testB.scalar(i), tolerance)); \ + prodValid = false; \ + } \ + } \ + REQUIRE(prodValid); \ + \ + auto divResult = (testA / testB).eval(); \ + bool divValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(divResult.scalar(i) == testA.scalar(i) / testB.scalar(i))) { \ + REQUIRE(lrc::isClose( \ + divResult.scalar(i), testA.scalar(i) / testB.scalar(i), 
tolerance)); \ + divValid = false; \ + } \ + } \ + REQUIRE(diffValid); \ + } \ + do { \ + } while (false) #define TEST_ARITHMETIC_ARRAY_SCALAR(SCALAR, BACKEND) \ - SECTION(fmt::format( \ - "Test Array-Scalar Operations [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array::ShapeType shape({37, 41}); \ - lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ - \ - for (int64_t i = 0; i < shape[0]; ++i) { \ - for (int64_t j = 0; j < shape[1]; ++j) { \ - SCALAR a = j + i * shape[1] + SCALAR(1); \ - testA[i][j] = a; \ - } \ - } \ - \ - auto sumResult = (testA + SCALAR(1)).eval(); \ - bool sumValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(sumResult.scalar(i) == testA.scalar(i) + SCALAR(1))) { \ - REQUIRE( \ - lrc::isClose(sumResult.scalar(i), testA.scalar(i) + SCALAR(1), tolerance)); \ - sumValid = false; \ - } \ - } \ - REQUIRE(sumValid); \ - \ - auto diffResult = (testA - SCALAR(1)).eval(); \ - bool diffValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(diffResult.scalar(i) == testA.scalar(i) - SCALAR(1))) { \ - REQUIRE( \ - lrc::isClose(diffResult.scalar(i), testA.scalar(i) - SCALAR(1), tolerance)); \ - diffValid = false; \ - } \ - } \ - REQUIRE(diffValid); \ - \ - auto prodResult = (testA * SCALAR(2)).eval(); \ - bool prodValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(prodResult.scalar(i) == testA.scalar(i) * SCALAR(2))) { \ - REQUIRE( \ - lrc::isClose(prodResult.scalar(i), testA.scalar(i) * SCALAR(2), tolerance)); \ - prodValid = false; \ - } \ - } \ - REQUIRE(prodValid); \ - \ - auto divResult = (testA / SCALAR(2)).eval(); \ - bool divValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(divResult.scalar(i) == testA.scalar(i) / SCALAR(2))) { \ - REQUIRE( \ - lrc::isClose(divResult.scalar(i), testA.scalar(i) / SCALAR(2), tolerance)); \ - divValid = false; \ - } \ - } \ - REQUIRE(diffValid); \ - } \ - do { \ - } while 
(false) + SECTION(fmt::format( \ + "Test Array-Scalar Operations [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array::ShapeType shape({37, 41}); \ + lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ + \ + for (int64_t i = 0; i < shape[0]; ++i) { \ + for (int64_t j = 0; j < shape[1]; ++j) { \ + SCALAR a = j + i * shape[1] + SCALAR(1); \ + testA[i][j] = a; \ + } \ + } \ + \ + auto sumResult = (testA + SCALAR(1)).eval(); \ + bool sumValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(sumResult.scalar(i) == testA.scalar(i) + SCALAR(1))) { \ + REQUIRE( \ + lrc::isClose(sumResult.scalar(i), testA.scalar(i) + SCALAR(1), tolerance)); \ + sumValid = false; \ + } \ + } \ + REQUIRE(sumValid); \ + \ + auto diffResult = (testA - SCALAR(1)).eval(); \ + bool diffValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(diffResult.scalar(i) == testA.scalar(i) - SCALAR(1))) { \ + REQUIRE( \ + lrc::isClose(diffResult.scalar(i), testA.scalar(i) - SCALAR(1), tolerance)); \ + diffValid = false; \ + } \ + } \ + REQUIRE(diffValid); \ + \ + auto prodResult = (testA * SCALAR(2)).eval(); \ + bool prodValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(prodResult.scalar(i) == testA.scalar(i) * SCALAR(2))) { \ + REQUIRE( \ + lrc::isClose(prodResult.scalar(i), testA.scalar(i) * SCALAR(2), tolerance)); \ + prodValid = false; \ + } \ + } \ + REQUIRE(prodValid); \ + \ + auto divResult = (testA / SCALAR(2)).eval(); \ + bool divValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(divResult.scalar(i) == testA.scalar(i) / SCALAR(2))) { \ + REQUIRE( \ + lrc::isClose(divResult.scalar(i), testA.scalar(i) / SCALAR(2), tolerance)); \ + divValid = false; \ + } \ + } \ + REQUIRE(diffValid); \ + } \ + do { \ + } while (false) #define TEST_ARITHMETIC_SCALAR_ARRAY(SCALAR, BACKEND) \ - SECTION(fmt::format( \ - "Test Scalar-Array Operations [{} | {}]", STRINGIFY(SCALAR), 
STRINGIFY(BACKEND))) { \ - lrc::Array::ShapeType shape({37, 41}); \ - lrc::Array testB(shape); \ - \ - for (int64_t i = 0; i < shape[0]; ++i) { \ - for (int64_t j = 0; j < shape[1]; ++j) { \ - SCALAR b = i + j * shape[0] + 1; \ - testB[i][j] = b != 0 ? b : 1; \ - } \ - } \ - \ - auto sumResult = (SCALAR(1) + testB).eval(); \ - bool sumValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(sumResult.scalar(i) == SCALAR(1) + testB.scalar(i))) { \ - REQUIRE( \ - lrc::isClose(sumResult.scalar(i), SCALAR(1) + testB.scalar(i), tolerance)); \ - sumValid = false; \ - } \ - } \ - REQUIRE(sumValid); \ - \ - auto diffResult = (1 - testB).eval(); \ - bool diffValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(diffResult.scalar(i) == SCALAR(1) - testB.scalar(i))) { \ - REQUIRE( \ - lrc::isClose(diffResult.scalar(i), SCALAR(1) - testB.scalar(i), tolerance)); \ - diffValid = false; \ - } \ - } \ - REQUIRE(diffValid); \ - \ - auto prodResult = (SCALAR(2) * testB).eval(); \ - bool prodValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(prodResult.scalar(i) == SCALAR(2) * testB.scalar(i))) { \ - REQUIRE( \ - lrc::isClose(prodResult.scalar(i), SCALAR(2) * testB.scalar(i), tolerance)); \ - prodValid = false; \ - } \ - } \ - REQUIRE(prodValid); \ - \ - auto divResult = (SCALAR(2) / testB).eval(); \ - bool divValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(divResult.scalar(i) == SCALAR(2) / testB.scalar(i))) { \ - REQUIRE( \ - lrc::isClose(divResult.scalar(i), SCALAR(2) / testB.scalar(i), tolerance)); \ - divValid = false; \ - } \ - } \ - REQUIRE(diffValid); \ - } \ - do { \ - } while (false) + SECTION(fmt::format( \ + "Test Scalar-Array Operations [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array::ShapeType shape({37, 41}); \ + lrc::Array testB(shape); \ + \ + for (int64_t i = 0; i < shape[0]; ++i) { \ + for (int64_t j = 0; j < shape[1]; ++j) { \ + SCALAR 
b = i + j * shape[0] + 1; \ + testB[i][j] = b != 0 ? b : 1; \ + } \ + } \ + \ + auto sumResult = (SCALAR(1) + testB).eval(); \ + bool sumValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(sumResult.scalar(i) == SCALAR(1) + testB.scalar(i))) { \ + REQUIRE( \ + lrc::isClose(sumResult.scalar(i), SCALAR(1) + testB.scalar(i), tolerance)); \ + sumValid = false; \ + } \ + } \ + REQUIRE(sumValid); \ + \ + auto diffResult = (1 - testB).eval(); \ + bool diffValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(diffResult.scalar(i) == SCALAR(1) - testB.scalar(i))) { \ + REQUIRE( \ + lrc::isClose(diffResult.scalar(i), SCALAR(1) - testB.scalar(i), tolerance)); \ + diffValid = false; \ + } \ + } \ + REQUIRE(diffValid); \ + \ + auto prodResult = (SCALAR(2) * testB).eval(); \ + bool prodValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(prodResult.scalar(i) == SCALAR(2) * testB.scalar(i))) { \ + REQUIRE( \ + lrc::isClose(prodResult.scalar(i), SCALAR(2) * testB.scalar(i), tolerance)); \ + prodValid = false; \ + } \ + } \ + REQUIRE(prodValid); \ + \ + auto divResult = (SCALAR(2) / testB).eval(); \ + bool divValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(divResult.scalar(i) == SCALAR(2) / testB.scalar(i))) { \ + REQUIRE( \ + lrc::isClose(divResult.scalar(i), SCALAR(2) / testB.scalar(i), tolerance)); \ + divValid = false; \ + } \ + } \ + REQUIRE(diffValid); \ + } \ + do { \ + } while (false) #define TEST_COMPARISONS(SCALAR, BACKEND) \ - SECTION( \ - fmt::format("Test Array Comparisons [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array::ShapeType shape({53, 79}); \ - lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ - lrc::Array testB(shape); \ - \ - for (int64_t i = 0; i < shape[0]; ++i) { \ - for (int64_t j = 0; j < shape[1]; ++j) { \ - SCALAR a = j + i * shape[1] + 1; \ - SCALAR b = i + j * shape[0] + 1; \ - \ - testA[i][j] = a; \ 
- testB[i][j] = b != 0 ? b : 1; \ - } \ - } \ - \ - auto gtResult = (testA > testB).eval(); \ - bool gtValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(gtResult.scalar(i) == testA.scalar(i) > testB.scalar(i))) { \ - REQUIRE(gtResult.scalar(i) == testA.scalar(i) > testB.scalar(i)); \ - gtValid = false; \ - } \ - } \ - REQUIRE(gtValid); \ - \ - auto geResult = (testA >= testB).eval(); \ - bool geValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(geResult.scalar(i) == testA.scalar(i) >= testB.scalar(i))) { \ - REQUIRE(geResult.scalar(i) == testA.scalar(i) >= testB.scalar(i)); \ - geValid = false; \ - } \ - } \ - REQUIRE(geValid); \ - \ - auto ltResult = (testA < testB).eval(); \ - bool ltValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(ltResult.scalar(i) == testA.scalar(i) < testB.scalar(i))) { \ - REQUIRE(ltResult.scalar(i) == testA.scalar(i) < testB.scalar(i)); \ - ltValid = false; \ - } \ - } \ - REQUIRE(ltValid); \ - \ - auto leResult = (testA <= testB).eval(); \ - bool leValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(leResult.scalar(i) == testA.scalar(i) <= testB.scalar(i))) { \ - REQUIRE(leResult.scalar(i) == testA.scalar(i) <= testB.scalar(i)); \ - leValid = false; \ - } \ - } \ - REQUIRE(leValid); \ - \ - auto eqResult = (testA == testB).eval(); \ - bool eqValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(eqResult.scalar(i) == (testA.scalar(i) == testB.scalar(i)))) { \ - REQUIRE(eqResult.scalar(i) == (testA.scalar(i) == testB.scalar(i))); \ - eqValid = false; \ - } \ - } \ - REQUIRE(eqValid); \ - \ - auto neResult = (testA != testB).eval(); \ - bool neValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(neResult.scalar(i) == (testA.scalar(i) != testB.scalar(i)))) { \ - REQUIRE(neResult.scalar(i) == (testA.scalar(i) != testB.scalar(i))); \ - neValid = false; \ - } \ - } \ - 
REQUIRE(neValid); \ - } \ - do { \ - } while (false) + SECTION( \ + fmt::format("Test Array Comparisons [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array::ShapeType shape({53, 79}); \ + lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ + lrc::Array testB(shape); \ + \ + for (int64_t i = 0; i < shape[0]; ++i) { \ + for (int64_t j = 0; j < shape[1]; ++j) { \ + SCALAR a = j + i * shape[1] + 1; \ + SCALAR b = i + j * shape[0] + 1; \ + \ + testA[i][j] = a; \ + testB[i][j] = b != 0 ? b : 1; \ + } \ + } \ + \ + auto gtResult = (testA > testB).eval(); \ + bool gtValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(gtResult.scalar(i) == testA.scalar(i) > testB.scalar(i))) { \ + REQUIRE(gtResult.scalar(i) == testA.scalar(i) > testB.scalar(i)); \ + gtValid = false; \ + } \ + } \ + REQUIRE(gtValid); \ + \ + auto geResult = (testA >= testB).eval(); \ + bool geValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(geResult.scalar(i) == testA.scalar(i) >= testB.scalar(i))) { \ + REQUIRE(geResult.scalar(i) == testA.scalar(i) >= testB.scalar(i)); \ + geValid = false; \ + } \ + } \ + REQUIRE(geValid); \ + \ + auto ltResult = (testA < testB).eval(); \ + bool ltValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(ltResult.scalar(i) == testA.scalar(i) < testB.scalar(i))) { \ + REQUIRE(ltResult.scalar(i) == testA.scalar(i) < testB.scalar(i)); \ + ltValid = false; \ + } \ + } \ + REQUIRE(ltValid); \ + \ + auto leResult = (testA <= testB).eval(); \ + bool leValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(leResult.scalar(i) == testA.scalar(i) <= testB.scalar(i))) { \ + REQUIRE(leResult.scalar(i) == testA.scalar(i) <= testB.scalar(i)); \ + leValid = false; \ + } \ + } \ + REQUIRE(leValid); \ + \ + auto eqResult = (testA == testB).eval(); \ + bool eqValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if 
(!(eqResult.scalar(i) == (testA.scalar(i) == testB.scalar(i)))) { \ + REQUIRE(eqResult.scalar(i) == (testA.scalar(i) == testB.scalar(i))); \ + eqValid = false; \ + } \ + } \ + REQUIRE(eqValid); \ + \ + auto neResult = (testA != testB).eval(); \ + bool neValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(neResult.scalar(i) == (testA.scalar(i) != testB.scalar(i)))) { \ + REQUIRE(neResult.scalar(i) == (testA.scalar(i) != testB.scalar(i))); \ + neValid = false; \ + } \ + } \ + REQUIRE(neValid); \ + } \ + do { \ + } while (false) #define TEST_COMPARISONS_ARRAY_SCALAR(SCALAR, BACKEND) \ - SECTION(fmt::format( \ - "Test Array-Scalar Comparisons [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array::ShapeType shape({53, 79}); \ - lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ - \ - for (int64_t i = 0; i < shape[0]; ++i) { \ - for (int64_t j = 0; j < shape[1]; ++j) { \ - SCALAR a = j + i * shape[1] + 1; \ - testA[i][j] = a; \ - } \ - } \ - \ - auto gtResult = (testA > SCALAR(64)).eval(); \ - bool gtValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(gtResult.scalar(i) == testA.scalar(i) > SCALAR(64))) { \ - REQUIRE(gtResult.scalar(i) == testA.scalar(i) > SCALAR(64)); \ - gtValid = false; \ - } \ - } \ - REQUIRE(gtValid); \ - \ - auto geResult = (testA >= SCALAR(64)).eval(); \ - bool geValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(geResult.scalar(i) == testA.scalar(i) >= SCALAR(64))) { \ - REQUIRE(geResult.scalar(i) == testA.scalar(i) >= SCALAR(64)); \ - geValid = false; \ - } \ - } \ - REQUIRE(geValid); \ - \ - auto ltResult = (testA < SCALAR(64)).eval(); \ - bool ltValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(ltResult.scalar(i) == testA.scalar(i) < SCALAR(64))) { \ - REQUIRE(ltResult.scalar(i) == testA.scalar(i) < SCALAR(64)); \ - ltValid = false; \ - } \ - } \ - REQUIRE(ltValid); \ - \ - auto leResult = (testA 
<= SCALAR(64)).eval(); \ - bool leValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(leResult.scalar(i) == testA.scalar(i) <= SCALAR(64))) { \ - REQUIRE(leResult.scalar(i) == testA.scalar(i) <= SCALAR(64)); \ - leValid = false; \ - } \ - } \ - REQUIRE(leValid); \ - \ - auto eqResult = (testA == SCALAR(64)).eval(); \ - bool eqValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(eqResult.scalar(i) == (testA.scalar(i) == SCALAR(64)))) { \ - REQUIRE(eqResult.scalar(i) == (testA.scalar(i) == SCALAR(64))); \ - eqValid = false; \ - } \ - } \ - REQUIRE(eqValid); \ - \ - auto neResult = (testA != SCALAR(64)).eval(); \ - bool neValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(neResult.scalar(i) == (testA.scalar(i) != SCALAR(64)))) { \ - REQUIRE(neResult.scalar(i) == (testA.scalar(i) != SCALAR(64))); \ - neValid = false; \ - } \ - } \ - REQUIRE(neValid); \ - } \ - do { \ - } while (false) + SECTION(fmt::format( \ + "Test Array-Scalar Comparisons [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array::ShapeType shape({53, 79}); \ + lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ + \ + for (int64_t i = 0; i < shape[0]; ++i) { \ + for (int64_t j = 0; j < shape[1]; ++j) { \ + SCALAR a = j + i * shape[1] + 1; \ + testA[i][j] = a; \ + } \ + } \ + \ + auto gtResult = (testA > SCALAR(64)).eval(); \ + bool gtValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(gtResult.scalar(i) == testA.scalar(i) > SCALAR(64))) { \ + REQUIRE(gtResult.scalar(i) == testA.scalar(i) > SCALAR(64)); \ + gtValid = false; \ + } \ + } \ + REQUIRE(gtValid); \ + \ + auto geResult = (testA >= SCALAR(64)).eval(); \ + bool geValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(geResult.scalar(i) == testA.scalar(i) >= SCALAR(64))) { \ + REQUIRE(geResult.scalar(i) == testA.scalar(i) >= SCALAR(64)); \ + geValid = false; \ + } \ + } \ + 
REQUIRE(geValid); \ + \ + auto ltResult = (testA < SCALAR(64)).eval(); \ + bool ltValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(ltResult.scalar(i) == testA.scalar(i) < SCALAR(64))) { \ + REQUIRE(ltResult.scalar(i) == testA.scalar(i) < SCALAR(64)); \ + ltValid = false; \ + } \ + } \ + REQUIRE(ltValid); \ + \ + auto leResult = (testA <= SCALAR(64)).eval(); \ + bool leValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(leResult.scalar(i) == testA.scalar(i) <= SCALAR(64))) { \ + REQUIRE(leResult.scalar(i) == testA.scalar(i) <= SCALAR(64)); \ + leValid = false; \ + } \ + } \ + REQUIRE(leValid); \ + \ + auto eqResult = (testA == SCALAR(64)).eval(); \ + bool eqValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(eqResult.scalar(i) == (testA.scalar(i) == SCALAR(64)))) { \ + REQUIRE(eqResult.scalar(i) == (testA.scalar(i) == SCALAR(64))); \ + eqValid = false; \ + } \ + } \ + REQUIRE(eqValid); \ + \ + auto neResult = (testA != SCALAR(64)).eval(); \ + bool neValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(neResult.scalar(i) == (testA.scalar(i) != SCALAR(64)))) { \ + REQUIRE(neResult.scalar(i) == (testA.scalar(i) != SCALAR(64))); \ + neValid = false; \ + } \ + } \ + REQUIRE(neValid); \ + } \ + do { \ + } while (false) #define TEST_COMPARISONS_SCALAR_ARRAY(SCALAR, BACKEND) \ - SECTION(fmt::format( \ - "Test Scalar-Array Comparisons [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ - lrc::Array::ShapeType shape({53, 79}); \ - lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ - \ - for (int64_t i = 0; i < shape[0]; ++i) { \ - for (int64_t j = 0; j < shape[1]; ++j) { \ - SCALAR a = j + i * shape[1] + 1; \ - testA[i][j] = a; \ - } \ - } \ - \ - auto gtResult = (SCALAR(64) > testA).eval(); \ - bool gtValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(gtResult.scalar(i) == SCALAR(64) > testA.scalar(i))) { \ - 
REQUIRE(gtResult.scalar(i) == SCALAR(64) > testA.scalar(i)); \ - gtValid = false; \ - } \ - } \ - REQUIRE(gtValid); \ - \ - auto geResult = (SCALAR(64) >= testA).eval(); \ - bool geValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(geResult.scalar(i) == SCALAR(64) >= testA.scalar(i))) { \ - REQUIRE(geResult.scalar(i) == SCALAR(64) >= testA.scalar(i)); \ - geValid = false; \ - } \ - } \ - REQUIRE(geValid); \ - \ - auto ltResult = (SCALAR(64) < testA).eval(); \ - bool ltValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(ltResult.scalar(i) == SCALAR(64) < testA.scalar(i))) { \ - REQUIRE(ltResult.scalar(i) == SCALAR(64) < testA.scalar(i)); \ - ltValid = false; \ - } \ - } \ - REQUIRE(ltValid); \ - \ - auto leResult = (SCALAR(64) <= testA).eval(); \ - bool leValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(leResult.scalar(i) == SCALAR(64) <= testA.scalar(i))) { \ - REQUIRE(leResult.scalar(i) == SCALAR(64) <= testA.scalar(i)); \ - leValid = false; \ - } \ - } \ - REQUIRE(leValid); \ - \ - auto eqResult = (SCALAR(64) == testA).eval(); \ - bool eqValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(eqResult.scalar(i) == (SCALAR(64) == testA.scalar(i)))) { \ - REQUIRE(eqResult.scalar(i) == (SCALAR(64) == testA.scalar(i))); \ - eqValid = false; \ - } \ - } \ - REQUIRE(eqValid); \ - \ - auto neResult = (SCALAR(64) != testA).eval(); \ - bool neValid = true; \ - for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ - if (!(neResult.scalar(i) == (SCALAR(64) != testA.scalar(i)))) { \ - REQUIRE(neResult.scalar(i) == (SCALAR(64) != testA.scalar(i))); \ - neValid = false; \ - } \ - } \ - REQUIRE(neValid); \ - } \ - do { \ - } while (false) + SECTION(fmt::format( \ + "Test Scalar-Array Comparisons [{} | {}]", STRINGIFY(SCALAR), STRINGIFY(BACKEND))) { \ + lrc::Array::ShapeType shape({53, 79}); \ + lrc::Array testA(shape); /* Prime-dimensioned to force wrapping */ \ + \ + 
for (int64_t i = 0; i < shape[0]; ++i) { \ + for (int64_t j = 0; j < shape[1]; ++j) { \ + SCALAR a = j + i * shape[1] + 1; \ + testA[i][j] = a; \ + } \ + } \ + \ + auto gtResult = (SCALAR(64) > testA).eval(); \ + bool gtValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(gtResult.scalar(i) == SCALAR(64) > testA.scalar(i))) { \ + REQUIRE(gtResult.scalar(i) == SCALAR(64) > testA.scalar(i)); \ + gtValid = false; \ + } \ + } \ + REQUIRE(gtValid); \ + \ + auto geResult = (SCALAR(64) >= testA).eval(); \ + bool geValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(geResult.scalar(i) == SCALAR(64) >= testA.scalar(i))) { \ + REQUIRE(geResult.scalar(i) == SCALAR(64) >= testA.scalar(i)); \ + geValid = false; \ + } \ + } \ + REQUIRE(geValid); \ + \ + auto ltResult = (SCALAR(64) < testA).eval(); \ + bool ltValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(ltResult.scalar(i) == SCALAR(64) < testA.scalar(i))) { \ + REQUIRE(ltResult.scalar(i) == SCALAR(64) < testA.scalar(i)); \ + ltValid = false; \ + } \ + } \ + REQUIRE(ltValid); \ + \ + auto leResult = (SCALAR(64) <= testA).eval(); \ + bool leValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(leResult.scalar(i) == SCALAR(64) <= testA.scalar(i))) { \ + REQUIRE(leResult.scalar(i) == SCALAR(64) <= testA.scalar(i)); \ + leValid = false; \ + } \ + } \ + REQUIRE(leValid); \ + \ + auto eqResult = (SCALAR(64) == testA).eval(); \ + bool eqValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(eqResult.scalar(i) == (SCALAR(64) == testA.scalar(i)))) { \ + REQUIRE(eqResult.scalar(i) == (SCALAR(64) == testA.scalar(i))); \ + eqValid = false; \ + } \ + } \ + REQUIRE(eqValid); \ + \ + auto neResult = (SCALAR(64) != testA).eval(); \ + bool neValid = true; \ + for (int64_t i = 0; i < shape[0] * shape[1]; ++i) { \ + if (!(neResult.scalar(i) == (SCALAR(64) != testA.scalar(i)))) { \ + REQUIRE(neResult.scalar(i) 
== (SCALAR(64) != testA.scalar(i))); \ + neValid = false; \ + } \ + } \ + REQUIRE(neValid); \ + } \ + do { \ + } while (false) #define TEST_ALL(SCALAR, BACKEND) \ - TEST_CONSTRUCTORS(SCALAR, BACKEND); \ - TEST_INDEXING(SCALAR, BACKEND); \ - TEST_STRING_FORMATTING(SCALAR, BACKEND); \ - TEST_ARITHMETIC(SCALAR, BACKEND); \ - TEST_ARITHMETIC_ARRAY_SCALAR(SCALAR, BACKEND); \ - TEST_ARITHMETIC_SCALAR_ARRAY(SCALAR, BACKEND); \ - TEST_COMPARISONS(SCALAR, BACKEND); \ - TEST_COMPARISONS_ARRAY_SCALAR(SCALAR, BACKEND); \ - TEST_COMPARISONS_SCALAR_ARRAY(SCALAR, BACKEND); + TEST_CONSTRUCTORS(SCALAR, BACKEND); \ + TEST_INDEXING(SCALAR, BACKEND); \ + TEST_STRING_FORMATTING(SCALAR, BACKEND); \ + TEST_ARITHMETIC(SCALAR, BACKEND); \ + TEST_ARITHMETIC_ARRAY_SCALAR(SCALAR, BACKEND); \ + TEST_ARITHMETIC_SCALAR_ARRAY(SCALAR, BACKEND); \ + TEST_COMPARISONS(SCALAR, BACKEND); \ + TEST_COMPARISONS_ARRAY_SCALAR(SCALAR, BACKEND); \ + TEST_COMPARISONS_SCALAR_ARRAY(SCALAR, BACKEND); TEST_CASE("Test Array -- int8_t CPU", "[array-lib]") { - TEST_CONSTRUCTORS(int8_t, CPU); - TEST_INDEXING(int8_t, CPU); - TEST_STRING_FORMATTING(int8_t, CPU); + TEST_CONSTRUCTORS(int8_t, CPU); + TEST_INDEXING(int8_t, CPU); + TEST_STRING_FORMATTING(int8_t, CPU); } TEST_CASE("Test Array -- uint8_t CPU", "[array-lib]") { - TEST_CONSTRUCTORS(uint8_t, CPU); - TEST_INDEXING(uint8_t, CPU); - TEST_STRING_FORMATTING(uint8_t, CPU); + TEST_CONSTRUCTORS(uint8_t, CPU); + TEST_INDEXING(uint8_t, CPU); + TEST_STRING_FORMATTING(uint8_t, CPU); } TEST_CASE("Test Array -- int16_t CPU", "[array-lib]") { - TEST_CONSTRUCTORS(int16_t, CPU); - TEST_INDEXING(int16_t, CPU); - TEST_STRING_FORMATTING(int16_t, CPU); + TEST_CONSTRUCTORS(int16_t, CPU); + TEST_INDEXING(int16_t, CPU); + TEST_STRING_FORMATTING(int16_t, CPU); } TEST_CASE("Test Array -- uint16_t CPU", "[array-lib]") { - TEST_CONSTRUCTORS(uint16_t, CPU); - TEST_INDEXING(uint16_t, CPU); - TEST_STRING_FORMATTING(uint16_t, CPU); + TEST_CONSTRUCTORS(uint16_t, CPU); + TEST_INDEXING(uint16_t, 
CPU); + TEST_STRING_FORMATTING(uint16_t, CPU); } TEST_CASE("Test Array -- int32_t CPU", "[array-lib]") { TEST_ALL(int32_t, CPU); } @@ -664,27 +664,27 @@ TEST_CASE("Test Array -- double OpenCL", "[array-lib]") { TEST_ALL(double, OPENC #if defined(LIBRAPID_HAS_CUDA) TEST_CASE("Test Array -- int8_t CUDA", "[array-lib]") { - TEST_CONSTRUCTORS(int8_t, CUDA); - TEST_INDEXING(int8_t, CUDA); - TEST_STRING_FORMATTING(int8_t, CUDA); + TEST_CONSTRUCTORS(int8_t, CUDA); + TEST_INDEXING(int8_t, CUDA); + TEST_STRING_FORMATTING(int8_t, CUDA); } TEST_CASE("Test Array -- uint8_t CUDA", "[array-lib]") { - TEST_CONSTRUCTORS(uint8_t, CUDA); - TEST_INDEXING(uint8_t, CUDA); - TEST_STRING_FORMATTING(uint8_t, CUDA); + TEST_CONSTRUCTORS(uint8_t, CUDA); + TEST_INDEXING(uint8_t, CUDA); + TEST_STRING_FORMATTING(uint8_t, CUDA); } TEST_CASE("Test Array -- int16_t CUDA", "[array-lib]") { - TEST_CONSTRUCTORS(int16_t, CUDA); - TEST_INDEXING(int16_t, CUDA); - TEST_STRING_FORMATTING(int16_t, CUDA); + TEST_CONSTRUCTORS(int16_t, CUDA); + TEST_INDEXING(int16_t, CUDA); + TEST_STRING_FORMATTING(int16_t, CUDA); } TEST_CASE("Test Array -- uint16_t CUDA", "[array-lib]") { - TEST_CONSTRUCTORS(uint16_t, CUDA); - TEST_INDEXING(uint16_t, CUDA); - TEST_STRING_FORMATTING(uint16_t, CUDA); + TEST_CONSTRUCTORS(uint16_t, CUDA); + TEST_INDEXING(uint16_t, CUDA); + TEST_STRING_FORMATTING(uint16_t, CUDA); } // TEST_CASE("Test Array -- int8_t CUDA", "[array-lib]") { TEST_ALL(int8_t, CUDA); } diff --git a/test/test-arrayOps.cpp b/test/test-arrayOps.cpp index a70958d4..3c331c40 100644 --- a/test/test-arrayOps.cpp +++ b/test/test-arrayOps.cpp @@ -3,47 +3,47 @@ #include #include -namespace lrc = librapid; +namespace lrc = librapid; constexpr double tolerance = 1e-5; -using CPU = lrc::backend::CPU; -using OPENCL = lrc::backend::OpenCL; -using CUDA = lrc::backend::CUDA; +using CPU = lrc::backend::CPU; +using OPENCL = lrc::backend::OpenCL; +using CUDA = lrc::backend::CUDA; // #define SCALAR float // #define BACKEND CPU #define 
TEST_OP(NAME, SCALAR) \ - auto NAME##X = lrc::NAME(x).eval(); \ - for (int i = 0; i < NAME##X.shape().size(); ++i) { \ - REQUIRE(lrc::isClose((SCALAR)NAME##X(i), (SCALAR)lrc::NAME((SCALAR)x(i)), tolerance)); \ - } + auto NAME##X = lrc::NAME(x).eval(); \ + for (int i = 0; i < NAME##X.shape().size(); ++i) { \ + REQUIRE(lrc::isClose((SCALAR)NAME##X(i), (SCALAR)lrc::NAME((SCALAR)x(i)), tolerance)); \ + } #define TRIG_TEST_IMPL(SCALAR, BACKEND) \ - TEST_CASE(fmt::format("Test Trigonometry -- {} {}", STRINGIFY(SCALAR), STRINGIFY(BACKEND)), \ - "[array-lib]") { \ - /* Valid range for all functions */ \ - auto x = lrc::linspace(0.1, 0.5, 100, false); \ + TEST_CASE(fmt::format("Test Trigonometry -- {} {}", STRINGIFY(SCALAR), STRINGIFY(BACKEND)), \ + "[array-lib]") { \ + /* Valid range for all functions */ \ + auto x = lrc::linspace(0.1, 0.5, 100, false); \ \ - TEST_OP(sin, SCALAR); \ - TEST_OP(cos, SCALAR); \ - TEST_OP(tan, SCALAR); \ - TEST_OP(asin, SCALAR); \ - TEST_OP(acos, SCALAR); \ - TEST_OP(atan, SCALAR); \ - TEST_OP(sinh, SCALAR); \ - TEST_OP(cosh, SCALAR); \ - TEST_OP(tanh, SCALAR); \ + TEST_OP(sin, SCALAR); \ + TEST_OP(cos, SCALAR); \ + TEST_OP(tan, SCALAR); \ + TEST_OP(asin, SCALAR); \ + TEST_OP(acos, SCALAR); \ + TEST_OP(atan, SCALAR); \ + TEST_OP(sinh, SCALAR); \ + TEST_OP(cosh, SCALAR); \ + TEST_OP(tanh, SCALAR); \ \ - TEST_OP(exp, SCALAR); \ - TEST_OP(log, SCALAR); \ - TEST_OP(log2, SCALAR); \ - TEST_OP(log10, SCALAR); \ - TEST_OP(sqrt, SCALAR); \ - TEST_OP(cbrt, SCALAR); \ - TEST_OP(abs, SCALAR); \ - TEST_OP(floor, SCALAR); \ - TEST_OP(ceil, SCALAR); \ - } + TEST_OP(exp, SCALAR); \ + TEST_OP(log, SCALAR); \ + TEST_OP(log2, SCALAR); \ + TEST_OP(log10, SCALAR); \ + TEST_OP(sqrt, SCALAR); \ + TEST_OP(cbrt, SCALAR); \ + TEST_OP(abs, SCALAR); \ + TEST_OP(floor, SCALAR); \ + TEST_OP(ceil, SCALAR); \ + } TRIG_TEST_IMPL(float, CPU) TRIG_TEST_IMPL(double, CPU) diff --git a/test/test-arrayView.cpp b/test/test-arrayView.cpp index 5ce57b1b..a47f00f0 100644 --- 
a/test/test-arrayView.cpp +++ b/test/test-arrayView.cpp @@ -3,76 +3,76 @@ #include #include -namespace lrc = librapid; +namespace lrc = librapid; constexpr double tolerance = 0.001; // #define SCALAR float // #define BACKEND lrc::backend::CPU #define TEST_ARRAY_VIEW(SCALAR, BACKEND) \ - TEST_CASE(fmt::format("Test ArrayView -- {} {}", STRINGIFY(SCALAR), STRINGIFY(BACKEND)), \ - "[array-lib]") { \ - lrc::Shape shape({7, 11}); \ - lrc::Array testArr(shape); \ + TEST_CASE(fmt::format("Test ArrayView -- {} {}", STRINGIFY(SCALAR), STRINGIFY(BACKEND)), \ + "[array-lib]") { \ + lrc::Shape shape({7, 11}); \ + lrc::Array testArr(shape); \ \ - for (int64_t i = 0; i < testArr.shape().size(); ++i) { testArr.storage()[i] = i; } \ + for (int64_t i = 0; i < testArr.shape().size(); ++i) { testArr.storage()[i] = i; } \ \ - auto testView = lrc::array::ArrayView(testArr); \ - auto testViewCopy = lrc::array::ArrayView(testView); \ - auto testViewMoveView = lrc::array::ArrayView(lrc::array::ArrayView(testArr)); \ + auto testView = lrc::array::ArrayView(testArr); \ + auto testViewCopy = lrc::array::ArrayView(testView); \ + auto testViewMoveView = lrc::array::ArrayView(lrc::array::ArrayView(testArr)); \ \ - REQUIRE(testView.ndim() == 2); \ - REQUIRE(testViewCopy.ndim() == 2); \ - REQUIRE(testViewMoveView.ndim() == 2); \ + REQUIRE(testView.ndim() == 2); \ + REQUIRE(testViewCopy.ndim() == 2); \ + REQUIRE(testViewMoveView.ndim() == 2); \ \ - REQUIRE(testView.shape() == shape); \ - REQUIRE(testViewCopy.shape() == shape); \ - REQUIRE(testViewMoveView.shape() == shape); \ + REQUIRE(testView.shape() == shape); \ + REQUIRE(testViewCopy.shape() == shape); \ + REQUIRE(testViewMoveView.shape() == shape); \ \ - auto checkValues = [](const auto &view) { \ - if (view.ndim() == 2) { \ - for (int64_t row = 0; row < view.shape()[0]; ++row) { \ - for (int64_t col = 0; col < view.shape()[1]; ++col) { \ - REQUIRE(view[row][col].get() == row * view.shape()[1] + col); \ - } \ - } \ - } else if (view.ndim() == 
3) { \ - for (int64_t row = 0; row < view.shape()[0]; ++row) { \ - for (int64_t col = 0; col < view.shape()[1]; ++col) { \ - for (int64_t depth = 0; depth < view.shape()[2]; ++depth) { \ - REQUIRE(view[row][col][depth].get() == \ - row * view.shape()[1] * view.shape()[2] + \ - col * view.shape()[2] + depth); \ - } \ - } \ - } \ - } else { \ - REQUIRE(true); \ - } \ - }; \ + auto checkValues = [](const auto &view) { \ + if (view.ndim() == 2) { \ + for (int64_t row = 0; row < view.shape()[0]; ++row) { \ + for (int64_t col = 0; col < view.shape()[1]; ++col) { \ + REQUIRE(view[row][col].get() == row * view.shape()[1] + col); \ + } \ + } \ + } else if (view.ndim() == 3) { \ + for (int64_t row = 0; row < view.shape()[0]; ++row) { \ + for (int64_t col = 0; col < view.shape()[1]; ++col) { \ + for (int64_t depth = 0; depth < view.shape()[2]; ++depth) { \ + REQUIRE(view[row][col][depth].get() == \ + row * view.shape()[1] * view.shape()[2] + \ + col * view.shape()[2] + depth); \ + } \ + } \ + } \ + } else { \ + REQUIRE(true); \ + } \ + }; \ \ - checkValues(testView); \ - checkValues(testViewCopy); \ - checkValues(testViewMoveView); \ + checkValues(testView); \ + checkValues(testViewCopy); \ + checkValues(testViewMoveView); \ \ - auto evalTest = testView.eval(); \ - auto evalTestCopy = testViewCopy.eval(); \ - auto evalTestMoveView = testViewMoveView.eval(); \ + auto evalTest = testView.eval(); \ + auto evalTestCopy = testViewCopy.eval(); \ + auto evalTestMoveView = testViewMoveView.eval(); \ \ - REQUIRE(evalTest.ndim() == 2); \ - REQUIRE(evalTestCopy.ndim() == 2); \ - REQUIRE(evalTestMoveView.ndim() == 2); \ + REQUIRE(evalTest.ndim() == 2); \ + REQUIRE(evalTestCopy.ndim() == 2); \ + REQUIRE(evalTestMoveView.ndim() == 2); \ \ - REQUIRE(evalTest.shape() == shape); \ - REQUIRE(evalTestCopy.shape() == shape); \ - REQUIRE(evalTestMoveView.shape() == shape); \ + REQUIRE(evalTest.shape() == shape); \ + REQUIRE(evalTestCopy.shape() == shape); \ + REQUIRE(evalTestMoveView.shape() == 
shape); \ \ - for (int64_t i = 0; i < evalTest.shape().size(); ++i) { \ - REQUIRE(evalTest.storage()[i] == i); \ - REQUIRE(evalTestCopy.storage()[i] == i); \ - REQUIRE(evalTestMoveView.storage()[i] == i); \ - } \ - } + for (int64_t i = 0; i < evalTest.shape().size(); ++i) { \ + REQUIRE(evalTest.storage()[i] == i); \ + REQUIRE(evalTestCopy.storage()[i] == i); \ + REQUIRE(evalTestMoveView.storage()[i] == i); \ + } \ + } // TEST_ARRAY_VIEW(int8_t, lrc::backend::CPU) TEST_ARRAY_VIEW(int16_t, lrc::backend::CPU) diff --git a/test/test-complex.cpp b/test/test-complex.cpp index f4d163a1..d69eb7b5 100644 --- a/test/test-complex.cpp +++ b/test/test-complex.cpp @@ -3,333 +3,333 @@ #include #include -namespace lrc = librapid; +namespace lrc = librapid; static double tolerance = 1e-5; // using SCALAR = double; #define TEST_COMPLEX(SCALAR) \ - TEST_CASE(fmt::format("Test Complex {}", STRINGIFY(SCALAR)), "[math]") { \ - SECTION("Constructors") { \ - lrc::Complex z1; \ - REQUIRE(z1.real() == 0); \ - REQUIRE(z1.imag() == 0); \ - \ - lrc::Complex z2(1, 2); \ - REQUIRE(z2.real() == 1); \ - REQUIRE(z2.imag() == 2); \ - \ - lrc::Complex z3(z2); \ - REQUIRE(z3.real() == 1); \ - REQUIRE(z3.imag() == 2); \ - \ - lrc::Complex z4 = z2; \ - REQUIRE(z4.real() == 1); \ - REQUIRE(z4.imag() == 2); \ - \ - lrc::Complex z5 = {1, 2}; \ - REQUIRE(z5.real() == 1); \ - REQUIRE(z5.imag() == 2); \ - \ - lrc::Complex z6(1); \ - REQUIRE(z6.real() == 1); \ - REQUIRE(z6.imag() == 0); \ - \ - lrc::Complex z7(lrc::Complex(1, 2)); \ - REQUIRE(z7.real() == 1); \ - REQUIRE(z7.imag() == 2); \ - \ - z7 = 123; \ - REQUIRE(z7.real() == 123); \ - REQUIRE(z7.imag() == 0); \ - \ - z1.real(5); \ - z1.imag(10); \ - REQUIRE(z1.real() == 5); \ - REQUIRE(z1.imag() == 10); \ - \ - REQUIRE(lrc::real(z1) == z1.real()); \ - REQUIRE(lrc::imag(z1) == z1.imag()); \ - } \ - \ - SECTION("Inplace Arithmetic") { \ - lrc::Complex z1(1, 2); \ - lrc::Complex z2(3, 4); \ - \ - z1 += SCALAR(1); \ - REQUIRE(z1.real() == 2); \ - 
REQUIRE(z1.imag() == 2); \ - \ - z1 -= SCALAR(1); \ - REQUIRE(z1.real() == 1); \ - REQUIRE(z1.imag() == 2); \ - \ - z1 *= SCALAR(2); \ - REQUIRE(z1.real() == 2); \ - REQUIRE(z1.imag() == 4); \ - \ - z1 /= SCALAR(2); \ - REQUIRE(z1.real() == 1); \ - REQUIRE(z1.imag() == 2); \ - \ - z1 += z2; \ - REQUIRE(z1.real() == 4); \ - REQUIRE(z1.imag() == 6); \ - \ - z1 -= z2; \ - REQUIRE(z1.real() == 1); \ - REQUIRE(z1.imag() == 2); \ - \ - z1 *= z2; \ - REQUIRE(z1.real() == -5); \ - REQUIRE(z1.imag() == 10); \ - \ - z1 /= z2; \ - REQUIRE(z1.real() == 1); \ - REQUIRE(z1.imag() == 2); \ - } \ - \ - SECTION("Casting") { \ - lrc::Complex z1(1, 2); \ - lrc::Complex z2(3, 4); \ - \ - REQUIRE((int)z1 == 1); \ - REQUIRE((int)z2 == 3); \ - \ - REQUIRE(lrc::Complex(z1) == lrc::Complex(1, 2)); \ - REQUIRE(lrc::Complex(z2) == lrc::Complex(3, 4)); \ - \ - REQUIRE(z1.str() == fmt::format("({}+{}j)", z1.real(), z1.imag())); \ - REQUIRE(z2.str() == fmt::format("({}+{}j)", z2.real(), z2.imag())); \ - REQUIRE((-z1).str() == fmt::format("(-{}-{}j)", z1.real(), z1.imag())); \ - REQUIRE((-z2).str() == fmt::format("(-{}-{}j)", z2.real(), z2.imag())); \ - } \ - \ - SECTION("Out-of-Place Arithmetic") { \ - lrc::Complex z1(1, 2); \ - lrc::Complex z2(3, 4); \ - \ - auto neg = -z1; \ - REQUIRE(neg.real() == -1); \ - REQUIRE(neg.imag() == -2); \ - \ - auto add1 = z1 + z2; \ - REQUIRE(add1.real() == 4); \ - REQUIRE(add1.imag() == 6); \ - \ - auto sub1 = z1 - z2; \ - REQUIRE(sub1.real() == -2); \ - REQUIRE(sub1.imag() == -2); \ - \ - auto mul1 = z1 * z2; \ - REQUIRE(mul1.real() == -5); \ - REQUIRE(mul1.imag() == 10); \ - \ - auto div1 = z1 / z2; \ - REQUIRE(lrc::isClose(div1.real(), 0.44, tolerance)); \ - REQUIRE(lrc::isClose(div1.imag(), 0.08, tolerance)); \ - \ - auto add2 = z1 + 1; \ - REQUIRE(add2.real() == 2); \ - REQUIRE(add2.imag() == 2); \ - \ - auto sub2 = z1 - 1; \ - REQUIRE(sub2.real() == 0); \ - REQUIRE(sub2.imag() == 2); \ - \ - auto mul2 = z1 * 2; \ - REQUIRE(mul2.real() == 2); \ - 
REQUIRE(mul2.imag() == 4); \ - \ - auto div2 = z1 / 2; \ - REQUIRE(lrc::isClose(div2.real(), 0.5, tolerance)); \ - REQUIRE(lrc::isClose(div2.imag(), 1.0, tolerance)); \ - \ - auto add3 = 1 + z1; \ - REQUIRE(add3.real() == 2); \ - REQUIRE(add3.imag() == 2); \ - \ - auto sub3 = 1 - z1; \ - REQUIRE(sub3.real() == 0); \ - REQUIRE(sub3.imag() == -2); \ - \ - auto mul3 = 2 * z1; \ - REQUIRE(mul3.real() == 2); \ - REQUIRE(mul3.imag() == 4); \ - \ - auto div3 = 2 / z1; \ - REQUIRE(lrc::isClose(div3.real(), 0.4, tolerance)); \ - REQUIRE(lrc::isClose(div3.imag(), -0.8, tolerance)); \ - } \ - \ - SECTION("Complex Functions") { \ - lrc::Complex z1(1, 2); \ - lrc::Complex z2(-3, 4); \ - \ - REQUIRE(lrc::sqrt(z2) == lrc::Complex(1, 2)); \ - REQUIRE(lrc::isClose(lrc::abs(z1), lrc::sqrt(SCALAR(5)), tolerance)); \ - REQUIRE(lrc::isClose(lrc::abs(z2), 5, tolerance)); \ - REQUIRE(lrc::conj(z1) == lrc::Complex(1, -2)); \ - REQUIRE(lrc::conj(z2) == lrc::Complex(-3, -4)); \ - \ - auto acos = lrc::acos(z1); \ - REQUIRE( \ - lrc::isClose(acos.real(), 1.143717740402420493750674808320794582795, tolerance)); \ - REQUIRE( \ - lrc::isClose(acos.imag(), -1.528570919480998161272456184793673393288, tolerance)); \ - \ - auto acosh = lrc::acosh(z1); \ - REQUIRE( \ - lrc::isClose(acosh.real(), 1.528570919480998161272456184793673393288, tolerance)); \ - REQUIRE( \ - lrc::isClose(acosh.imag(), 1.143717740402420493750674808320794582795, tolerance)); \ - \ - auto asinh = lrc::asinh(z1); \ - REQUIRE( \ - lrc::isClose(asinh.real(), 1.469351744368185273255844317361647616787, tolerance)); \ - REQUIRE( \ - lrc::isClose(asinh.imag(), 1.063440023577752056189491997089551002851, tolerance)); \ - \ - auto asin = lrc::asin(z1); \ - REQUIRE( \ - lrc::isClose(asin.real(), 0.42707858639247612548064688331895685930333, tolerance)); \ - REQUIRE( \ - lrc::isClose(asin.imag(), 1.52857091948099816127245618479367339328868, tolerance)); \ - \ - auto atanh = lrc::atanh(z1); \ - REQUIRE( \ - lrc::isClose(atanh.real(), 
0.173286795139986351536318642871984660455, tolerance)); \ - REQUIRE( \ - lrc::isClose(atanh.imag(), 1.178097245096172464423491268729813577364, tolerance)); \ - \ - auto atan = lrc::atan(z1); \ - REQUIRE( \ - lrc::isClose(atan.real(), 1.338972522294493561202819911642758892643, tolerance)); \ - REQUIRE( \ - lrc::isClose(atan.imag(), 0.402359478108525093650936383865827688755, tolerance)); \ - \ - auto cosh = lrc::cosh(z1); \ - REQUIRE( \ - lrc::isClose(cosh.real(), -0.64214812471551996484480068696227878947, tolerance)); \ - REQUIRE( \ - lrc::isClose(cosh.imag(), 1.068607421382778339597440033783951588665, tolerance)); \ - \ - auto exp = lrc::exp(z1); \ - REQUIRE( \ - lrc::isClose(exp.real(), -1.131204383756813638431255255510794710628, tolerance)); \ - REQUIRE( \ - lrc::isClose(exp.imag(), 2.4717266720048189276169308935516645327361, tolerance)); \ - \ - auto exp2 = lrc::exp2(z1); \ - REQUIRE( \ - lrc::isClose(exp2.real(), 0.366913949486603353679882473618470209036, tolerance)); \ - REQUIRE( \ - lrc::isClose(exp2.imag(), 1.966055480822487441172329700685456305222, tolerance)); \ - \ - auto exp10 = lrc::exp10(z1); \ - REQUIRE( \ - lrc::isClose(exp10.real(), -1.0701348355877020772086517528518239460, tolerance)); \ - REQUIRE( \ - lrc::isClose(exp10.imag(), -9.9425756941378968736161937190915602112, tolerance)); \ - \ - auto log = lrc::log(z1); \ - REQUIRE( \ - lrc::isClose(log.real(), 0.804718956217050314461503047313945610162, tolerance)); \ - REQUIRE( \ - lrc::isClose(log.imag(), 1.107148717794090503017065460178537040070, tolerance)); \ - \ - auto log2 = lrc::log2(z1); \ - REQUIRE( \ - lrc::isClose(log2.real(), 1.1609640474436811739351597147446950879, tolerance)); \ - REQUIRE( \ - lrc::isClose(log2.imag(), 1.5972779646881088066382317418569791182, tolerance)); \ - \ - auto pow3 = lrc::pow(z1, 3); \ - REQUIRE(lrc::isClose(pow3.real(), -11, tolerance)); \ - REQUIRE(lrc::isClose(pow3.imag(), -2, tolerance)); \ - \ - auto realPow = lrc::pow(SCALAR(5), z1); \ - REQUIRE( \ - 
lrc::isClose(realPow.real(), -4.98507570899023509256310961483534856535, tolerance)); \ - REQUIRE( \ - lrc::isClose(realPow.imag(), -0.38603131431984235432537596434762968808, tolerance)); \ - \ - auto sinh = lrc::sinh(z1); \ - REQUIRE( \ - lrc::isClose(sinh.real(), -0.48905625904129372065865106274460904854, tolerance)); \ - REQUIRE( \ - lrc::isClose(sinh.imag(), 1.4031192506220405511576005806225837627, tolerance)); \ - \ - auto sqrt = lrc::sqrt(z1); \ - REQUIRE( \ - lrc::isClose(sqrt.real(), 1.272019649514068964252422461737491491715, tolerance)); \ - REQUIRE( \ - lrc::isClose(sqrt.imag(), 0.786151377757423286069558585842958929523, tolerance)); \ - \ - auto tanh = lrc::tanh(z1); \ - REQUIRE( \ - lrc::isClose(tanh.real(), 1.166736257240919881810070397144984248593, tolerance)); \ - REQUIRE( \ - lrc::isClose(tanh.imag(), -0.24345820118572525270261038865215160145, tolerance)); \ - \ - auto arg = lrc::arg(z1); \ - REQUIRE(lrc::isClose(arg, lrc::atan(2), tolerance)); \ - \ - auto cos = lrc::cos(z1); \ - REQUIRE( \ - lrc::isClose(cos.real(), 2.0327230070196655294363434484995142637319, tolerance)); \ - REQUIRE( \ - lrc::isClose(cos.imag(), -3.051897799151800057512115686895105452888, tolerance)); \ - \ - auto csc = lrc::csc(z1); \ - REQUIRE( \ - lrc::isClose(csc.real(), 0.22837506559968659341093330251058291161553, tolerance)); \ - REQUIRE( \ - lrc::isClose(csc.imag(), -0.1413630216124078007231203906301757072451, tolerance)); \ - \ - auto sec = lrc::sec(z1); \ - REQUIRE( \ - lrc::isClose(sec.real(), 0.15117629826557722714368596016961254310795, tolerance)); \ - REQUIRE( \ - lrc::isClose(sec.imag(), 0.22697367539372159536972826811917694791070, tolerance)); \ - \ - auto cot = lrc::cot(z1); \ - REQUIRE( \ - lrc::isClose(cot.real(), 0.032797755533752594062764546576583062934, tolerance)); \ - REQUIRE( \ - lrc::isClose(cot.imag(), -0.984329226458191029471888181689464448193, tolerance)); \ - \ - auto acsc = lrc::acsc(lrc::csc(z1)); \ - REQUIRE(lrc::isClose(acsc.real(), z1.real(), 
tolerance)); \ - REQUIRE(lrc::isClose(acsc.imag(), z1.imag(), tolerance)); \ - \ - auto asec = lrc::asec(lrc::sec(z1)); \ - REQUIRE(lrc::isClose(asec.real(), z1.real(), tolerance)); \ - REQUIRE(lrc::isClose(asec.imag(), z1.imag(), tolerance)); \ - \ - auto acot = lrc::acot(lrc::cot(z1)); \ - REQUIRE(lrc::isClose(acot.real(), z1.real(), tolerance)); \ - REQUIRE(lrc::isClose(acot.imag(), z1.imag(), tolerance)); \ - \ - REQUIRE(lrc::norm(z1) == 5); \ - \ - auto polar = lrc::polar(lrc::sqrt(SCALAR(5)), lrc::atan(SCALAR(2))); \ - REQUIRE(lrc::isClose(polar.real(), 1, tolerance)); \ - REQUIRE(lrc::isClose(polar.imag(), 2, tolerance)); \ - \ - auto sin = lrc::sin(z1); \ - REQUIRE( \ - lrc::isClose(sin.real(), 3.165778513216168146740734617191905538379110, tolerance)); \ - REQUIRE( \ - lrc::isClose(sin.imag(), 1.959601041421605897070352049989358278436320, tolerance)); \ - \ - auto floor = lrc::floor(lrc::Complex(1.5, 2.5)); \ - REQUIRE(floor == lrc::Complex(1, 2)); \ - \ - auto ceil = lrc::ceil(lrc::Complex(1.5, 2.5)); \ - REQUIRE(ceil == lrc::Complex(2, 3)); \ - } \ - } + TEST_CASE(fmt::format("Test Complex {}", STRINGIFY(SCALAR)), "[math]") { \ + SECTION("Constructors") { \ + lrc::Complex z1; \ + REQUIRE(z1.real() == 0); \ + REQUIRE(z1.imag() == 0); \ + \ + lrc::Complex z2(1, 2); \ + REQUIRE(z2.real() == 1); \ + REQUIRE(z2.imag() == 2); \ + \ + lrc::Complex z3(z2); \ + REQUIRE(z3.real() == 1); \ + REQUIRE(z3.imag() == 2); \ + \ + lrc::Complex z4 = z2; \ + REQUIRE(z4.real() == 1); \ + REQUIRE(z4.imag() == 2); \ + \ + lrc::Complex z5 = {1, 2}; \ + REQUIRE(z5.real() == 1); \ + REQUIRE(z5.imag() == 2); \ + \ + lrc::Complex z6(1); \ + REQUIRE(z6.real() == 1); \ + REQUIRE(z6.imag() == 0); \ + \ + lrc::Complex z7(lrc::Complex(1, 2)); \ + REQUIRE(z7.real() == 1); \ + REQUIRE(z7.imag() == 2); \ + \ + z7 = 123; \ + REQUIRE(z7.real() == 123); \ + REQUIRE(z7.imag() == 0); \ + \ + z1.real(5); \ + z1.imag(10); \ + REQUIRE(z1.real() == 5); \ + REQUIRE(z1.imag() == 10); \ + \ + 
REQUIRE(lrc::real(z1) == z1.real()); \ + REQUIRE(lrc::imag(z1) == z1.imag()); \ + } \ + \ + SECTION("Inplace Arithmetic") { \ + lrc::Complex z1(1, 2); \ + lrc::Complex z2(3, 4); \ + \ + z1 += SCALAR(1); \ + REQUIRE(z1.real() == 2); \ + REQUIRE(z1.imag() == 2); \ + \ + z1 -= SCALAR(1); \ + REQUIRE(z1.real() == 1); \ + REQUIRE(z1.imag() == 2); \ + \ + z1 *= SCALAR(2); \ + REQUIRE(z1.real() == 2); \ + REQUIRE(z1.imag() == 4); \ + \ + z1 /= SCALAR(2); \ + REQUIRE(z1.real() == 1); \ + REQUIRE(z1.imag() == 2); \ + \ + z1 += z2; \ + REQUIRE(z1.real() == 4); \ + REQUIRE(z1.imag() == 6); \ + \ + z1 -= z2; \ + REQUIRE(z1.real() == 1); \ + REQUIRE(z1.imag() == 2); \ + \ + z1 *= z2; \ + REQUIRE(z1.real() == -5); \ + REQUIRE(z1.imag() == 10); \ + \ + z1 /= z2; \ + REQUIRE(z1.real() == 1); \ + REQUIRE(z1.imag() == 2); \ + } \ + \ + SECTION("Casting") { \ + lrc::Complex z1(1, 2); \ + lrc::Complex z2(3, 4); \ + \ + REQUIRE((int)z1 == 1); \ + REQUIRE((int)z2 == 3); \ + \ + REQUIRE(lrc::Complex(z1) == lrc::Complex(1, 2)); \ + REQUIRE(lrc::Complex(z2) == lrc::Complex(3, 4)); \ + \ + REQUIRE(z1.str() == fmt::format("({}+{}j)", z1.real(), z1.imag())); \ + REQUIRE(z2.str() == fmt::format("({}+{}j)", z2.real(), z2.imag())); \ + REQUIRE((-z1).str() == fmt::format("(-{}-{}j)", z1.real(), z1.imag())); \ + REQUIRE((-z2).str() == fmt::format("(-{}-{}j)", z2.real(), z2.imag())); \ + } \ + \ + SECTION("Out-of-Place Arithmetic") { \ + lrc::Complex z1(1, 2); \ + lrc::Complex z2(3, 4); \ + \ + auto neg = -z1; \ + REQUIRE(neg.real() == -1); \ + REQUIRE(neg.imag() == -2); \ + \ + auto add1 = z1 + z2; \ + REQUIRE(add1.real() == 4); \ + REQUIRE(add1.imag() == 6); \ + \ + auto sub1 = z1 - z2; \ + REQUIRE(sub1.real() == -2); \ + REQUIRE(sub1.imag() == -2); \ + \ + auto mul1 = z1 * z2; \ + REQUIRE(mul1.real() == -5); \ + REQUIRE(mul1.imag() == 10); \ + \ + auto div1 = z1 / z2; \ + REQUIRE(lrc::isClose(div1.real(), 0.44, tolerance)); \ + REQUIRE(lrc::isClose(div1.imag(), 0.08, tolerance)); \ + \ + auto 
add2 = z1 + 1; \ + REQUIRE(add2.real() == 2); \ + REQUIRE(add2.imag() == 2); \ + \ + auto sub2 = z1 - 1; \ + REQUIRE(sub2.real() == 0); \ + REQUIRE(sub2.imag() == 2); \ + \ + auto mul2 = z1 * 2; \ + REQUIRE(mul2.real() == 2); \ + REQUIRE(mul2.imag() == 4); \ + \ + auto div2 = z1 / 2; \ + REQUIRE(lrc::isClose(div2.real(), 0.5, tolerance)); \ + REQUIRE(lrc::isClose(div2.imag(), 1.0, tolerance)); \ + \ + auto add3 = 1 + z1; \ + REQUIRE(add3.real() == 2); \ + REQUIRE(add3.imag() == 2); \ + \ + auto sub3 = 1 - z1; \ + REQUIRE(sub3.real() == 0); \ + REQUIRE(sub3.imag() == -2); \ + \ + auto mul3 = 2 * z1; \ + REQUIRE(mul3.real() == 2); \ + REQUIRE(mul3.imag() == 4); \ + \ + auto div3 = 2 / z1; \ + REQUIRE(lrc::isClose(div3.real(), 0.4, tolerance)); \ + REQUIRE(lrc::isClose(div3.imag(), -0.8, tolerance)); \ + } \ + \ + SECTION("Complex Functions") { \ + lrc::Complex z1(1, 2); \ + lrc::Complex z2(-3, 4); \ + \ + REQUIRE(lrc::sqrt(z2) == lrc::Complex(1, 2)); \ + REQUIRE(lrc::isClose(lrc::abs(z1), lrc::sqrt(SCALAR(5)), tolerance)); \ + REQUIRE(lrc::isClose(lrc::abs(z2), 5, tolerance)); \ + REQUIRE(lrc::conj(z1) == lrc::Complex(1, -2)); \ + REQUIRE(lrc::conj(z2) == lrc::Complex(-3, -4)); \ + \ + auto acos = lrc::acos(z1); \ + REQUIRE( \ + lrc::isClose(acos.real(), 1.143717740402420493750674808320794582795, tolerance)); \ + REQUIRE( \ + lrc::isClose(acos.imag(), -1.528570919480998161272456184793673393288, tolerance)); \ + \ + auto acosh = lrc::acosh(z1); \ + REQUIRE( \ + lrc::isClose(acosh.real(), 1.528570919480998161272456184793673393288, tolerance)); \ + REQUIRE( \ + lrc::isClose(acosh.imag(), 1.143717740402420493750674808320794582795, tolerance)); \ + \ + auto asinh = lrc::asinh(z1); \ + REQUIRE( \ + lrc::isClose(asinh.real(), 1.469351744368185273255844317361647616787, tolerance)); \ + REQUIRE( \ + lrc::isClose(asinh.imag(), 1.063440023577752056189491997089551002851, tolerance)); \ + \ + auto asin = lrc::asin(z1); \ + REQUIRE( \ + lrc::isClose(asin.real(), 
0.42707858639247612548064688331895685930333, tolerance)); \ + REQUIRE( \ + lrc::isClose(asin.imag(), 1.52857091948099816127245618479367339328868, tolerance)); \ + \ + auto atanh = lrc::atanh(z1); \ + REQUIRE( \ + lrc::isClose(atanh.real(), 0.173286795139986351536318642871984660455, tolerance)); \ + REQUIRE( \ + lrc::isClose(atanh.imag(), 1.178097245096172464423491268729813577364, tolerance)); \ + \ + auto atan = lrc::atan(z1); \ + REQUIRE( \ + lrc::isClose(atan.real(), 1.338972522294493561202819911642758892643, tolerance)); \ + REQUIRE( \ + lrc::isClose(atan.imag(), 0.402359478108525093650936383865827688755, tolerance)); \ + \ + auto cosh = lrc::cosh(z1); \ + REQUIRE( \ + lrc::isClose(cosh.real(), -0.64214812471551996484480068696227878947, tolerance)); \ + REQUIRE( \ + lrc::isClose(cosh.imag(), 1.068607421382778339597440033783951588665, tolerance)); \ + \ + auto exp = lrc::exp(z1); \ + REQUIRE( \ + lrc::isClose(exp.real(), -1.131204383756813638431255255510794710628, tolerance)); \ + REQUIRE( \ + lrc::isClose(exp.imag(), 2.4717266720048189276169308935516645327361, tolerance)); \ + \ + auto exp2 = lrc::exp2(z1); \ + REQUIRE( \ + lrc::isClose(exp2.real(), 0.366913949486603353679882473618470209036, tolerance)); \ + REQUIRE( \ + lrc::isClose(exp2.imag(), 1.966055480822487441172329700685456305222, tolerance)); \ + \ + auto exp10 = lrc::exp10(z1); \ + REQUIRE( \ + lrc::isClose(exp10.real(), -1.0701348355877020772086517528518239460, tolerance)); \ + REQUIRE( \ + lrc::isClose(exp10.imag(), -9.9425756941378968736161937190915602112, tolerance)); \ + \ + auto log = lrc::log(z1); \ + REQUIRE( \ + lrc::isClose(log.real(), 0.804718956217050314461503047313945610162, tolerance)); \ + REQUIRE( \ + lrc::isClose(log.imag(), 1.107148717794090503017065460178537040070, tolerance)); \ + \ + auto log2 = lrc::log2(z1); \ + REQUIRE( \ + lrc::isClose(log2.real(), 1.1609640474436811739351597147446950879, tolerance)); \ + REQUIRE( \ + lrc::isClose(log2.imag(), 
1.5972779646881088066382317418569791182, tolerance)); \ + \ + auto pow3 = lrc::pow(z1, 3); \ + REQUIRE(lrc::isClose(pow3.real(), -11, tolerance)); \ + REQUIRE(lrc::isClose(pow3.imag(), -2, tolerance)); \ + \ + auto realPow = lrc::pow(SCALAR(5), z1); \ + REQUIRE( \ + lrc::isClose(realPow.real(), -4.98507570899023509256310961483534856535, tolerance)); \ + REQUIRE( \ + lrc::isClose(realPow.imag(), -0.38603131431984235432537596434762968808, tolerance)); \ + \ + auto sinh = lrc::sinh(z1); \ + REQUIRE( \ + lrc::isClose(sinh.real(), -0.48905625904129372065865106274460904854, tolerance)); \ + REQUIRE( \ + lrc::isClose(sinh.imag(), 1.4031192506220405511576005806225837627, tolerance)); \ + \ + auto sqrt = lrc::sqrt(z1); \ + REQUIRE( \ + lrc::isClose(sqrt.real(), 1.272019649514068964252422461737491491715, tolerance)); \ + REQUIRE( \ + lrc::isClose(sqrt.imag(), 0.786151377757423286069558585842958929523, tolerance)); \ + \ + auto tanh = lrc::tanh(z1); \ + REQUIRE( \ + lrc::isClose(tanh.real(), 1.166736257240919881810070397144984248593, tolerance)); \ + REQUIRE( \ + lrc::isClose(tanh.imag(), -0.24345820118572525270261038865215160145, tolerance)); \ + \ + auto arg = lrc::arg(z1); \ + REQUIRE(lrc::isClose(arg, lrc::atan(2), tolerance)); \ + \ + auto cos = lrc::cos(z1); \ + REQUIRE( \ + lrc::isClose(cos.real(), 2.0327230070196655294363434484995142637319, tolerance)); \ + REQUIRE( \ + lrc::isClose(cos.imag(), -3.051897799151800057512115686895105452888, tolerance)); \ + \ + auto csc = lrc::csc(z1); \ + REQUIRE( \ + lrc::isClose(csc.real(), 0.22837506559968659341093330251058291161553, tolerance)); \ + REQUIRE( \ + lrc::isClose(csc.imag(), -0.1413630216124078007231203906301757072451, tolerance)); \ + \ + auto sec = lrc::sec(z1); \ + REQUIRE( \ + lrc::isClose(sec.real(), 0.15117629826557722714368596016961254310795, tolerance)); \ + REQUIRE( \ + lrc::isClose(sec.imag(), 0.22697367539372159536972826811917694791070, tolerance)); \ + \ + auto cot = lrc::cot(z1); \ + REQUIRE( \ + 
lrc::isClose(cot.real(), 0.032797755533752594062764546576583062934, tolerance)); \ + REQUIRE( \ + lrc::isClose(cot.imag(), -0.984329226458191029471888181689464448193, tolerance)); \ + \ + auto acsc = lrc::acsc(lrc::csc(z1)); \ + REQUIRE(lrc::isClose(acsc.real(), z1.real(), tolerance)); \ + REQUIRE(lrc::isClose(acsc.imag(), z1.imag(), tolerance)); \ + \ + auto asec = lrc::asec(lrc::sec(z1)); \ + REQUIRE(lrc::isClose(asec.real(), z1.real(), tolerance)); \ + REQUIRE(lrc::isClose(asec.imag(), z1.imag(), tolerance)); \ + \ + auto acot = lrc::acot(lrc::cot(z1)); \ + REQUIRE(lrc::isClose(acot.real(), z1.real(), tolerance)); \ + REQUIRE(lrc::isClose(acot.imag(), z1.imag(), tolerance)); \ + \ + REQUIRE(lrc::norm(z1) == 5); \ + \ + auto polar = lrc::polar(lrc::sqrt(SCALAR(5)), lrc::atan(SCALAR(2))); \ + REQUIRE(lrc::isClose(polar.real(), 1, tolerance)); \ + REQUIRE(lrc::isClose(polar.imag(), 2, tolerance)); \ + \ + auto sin = lrc::sin(z1); \ + REQUIRE( \ + lrc::isClose(sin.real(), 3.165778513216168146740734617191905538379110, tolerance)); \ + REQUIRE( \ + lrc::isClose(sin.imag(), 1.959601041421605897070352049989358278436320, tolerance)); \ + \ + auto floor = lrc::floor(lrc::Complex(1.5, 2.5)); \ + REQUIRE(floor == lrc::Complex(1, 2)); \ + \ + auto ceil = lrc::ceil(lrc::Complex(1.5, 2.5)); \ + REQUIRE(ceil == lrc::Complex(2, 3)); \ + } \ + } TEST_COMPLEX(float) TEST_COMPLEX(double) diff --git a/test/test-cudaStorage.cpp b/test/test-cudaStorage.cpp index 6c1b1e61..4db716f9 100644 --- a/test/test-cudaStorage.cpp +++ b/test/test-cudaStorage.cpp @@ -7,143 +7,143 @@ namespace lrc = librapid; #if defined(LIBRAPID_HAS_CUDA) -# define REGISTER_CASES(TYPE) \ - SECTION("Type: " STRINGIFY(TYPE)) { \ - using ScalarType = TYPE; \ - lrc::CudaStorage storage(5); \ - \ - REQUIRE(storage.size() == 5); \ - \ - storage[0] = 1; \ - storage[1] = 10; \ - \ - REQUIRE(storage[0] == 1); \ - REQUIRE(storage[1] == 10); \ - \ - lrc::CudaStorage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ - \ - 
REQUIRE(storage2.size() == 10); \ - REQUIRE(storage2[0] == 1); \ - REQUIRE(storage2[1] == 2); \ - REQUIRE(storage2[8] == 9); \ - REQUIRE(storage2[9] == 10); \ - \ - lrc::CudaStorage storage3(10, 1); \ - \ - REQUIRE(storage3.size() == 10); \ - REQUIRE(storage3[0] == 1); \ - REQUIRE(storage3[1] == 1); \ - REQUIRE(storage3[8] == 1); \ - REQUIRE(storage3[9] == 1); \ - \ - auto storage4 = lrc::CudaStorage(storage2); \ - \ - REQUIRE(storage4.size() == 10); \ - REQUIRE(storage4[0] == 1); \ - REQUIRE(storage4[1] == 2); \ - REQUIRE(storage4[8] == 9); \ - REQUIRE(storage4[9] == 10); \ - \ - /* storage4 = lrc::CudaStorage(100); */ \ - /* REQUIRE(storage4.size() == 100); */ \ - /* storage4[0] = 1; */ \ - /* storage4[1] = 2; */ \ - /* storage4[98] = 99; */ \ - /* storage4[99] = 100; */ \ - /* REQUIRE(storage4[0] == 1); */ \ - /* REQUIRE(storage4[1] == 2); */ \ - /* REQUIRE(storage4[98] == 99); */ \ - /* REQUIRE(storage4[99] == 100); */ \ - \ - storage4 = storage3; \ - \ - REQUIRE(storage4.size() == 10); \ - REQUIRE(storage4[0] == 1); \ - REQUIRE(storage4[1] == 1); \ - REQUIRE(storage4[8] == 1); \ - REQUIRE(storage4[9] == 1); \ - \ - lrc::CudaStorage storage6(20, 123); \ - REQUIRE(storage6.size() == 20); \ - storage6.resize(5); \ - REQUIRE(storage6.size() == 5); \ - REQUIRE(storage6[0] == 123); \ - REQUIRE(storage6[1] == 123); \ - REQUIRE(storage6[2] == 123); \ - REQUIRE(storage6[3] == 123); \ - REQUIRE(storage6[4] == 123); \ - \ - storage6.resize(10); \ - REQUIRE(storage6.size() == 10); \ - REQUIRE(storage6[0] == 123); \ - REQUIRE(storage6[1] == 123); \ - REQUIRE(storage6[2] == 123); \ - REQUIRE(storage6[3] == 123); \ - REQUIRE(storage6[4] == 123); \ - \ - storage6.resize(100, 0); \ - REQUIRE(storage6.size() == 100); \ - } +# define REGISTER_CASES(TYPE) \ + SECTION("Type: " STRINGIFY(TYPE)) { \ + using ScalarType = TYPE; \ + lrc::CudaStorage storage(5); \ + \ + REQUIRE(storage.size() == 5); \ + \ + storage[0] = 1; \ + storage[1] = 10; \ + \ + REQUIRE(storage[0] == 1); \ + 
REQUIRE(storage[1] == 10); \ + \ + lrc::CudaStorage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ + \ + REQUIRE(storage2.size() == 10); \ + REQUIRE(storage2[0] == 1); \ + REQUIRE(storage2[1] == 2); \ + REQUIRE(storage2[8] == 9); \ + REQUIRE(storage2[9] == 10); \ + \ + lrc::CudaStorage storage3(10, 1); \ + \ + REQUIRE(storage3.size() == 10); \ + REQUIRE(storage3[0] == 1); \ + REQUIRE(storage3[1] == 1); \ + REQUIRE(storage3[8] == 1); \ + REQUIRE(storage3[9] == 1); \ + \ + auto storage4 = lrc::CudaStorage(storage2); \ + \ + REQUIRE(storage4.size() == 10); \ + REQUIRE(storage4[0] == 1); \ + REQUIRE(storage4[1] == 2); \ + REQUIRE(storage4[8] == 9); \ + REQUIRE(storage4[9] == 10); \ + \ + /* storage4 = lrc::CudaStorage(100); */ \ + /* REQUIRE(storage4.size() == 100); */ \ + /* storage4[0] = 1; */ \ + /* storage4[1] = 2; */ \ + /* storage4[98] = 99; */ \ + /* storage4[99] = 100; */ \ + /* REQUIRE(storage4[0] == 1); */ \ + /* REQUIRE(storage4[1] == 2); */ \ + /* REQUIRE(storage4[98] == 99); */ \ + /* REQUIRE(storage4[99] == 100); */ \ + \ + storage4 = storage3; \ + \ + REQUIRE(storage4.size() == 10); \ + REQUIRE(storage4[0] == 1); \ + REQUIRE(storage4[1] == 1); \ + REQUIRE(storage4[8] == 1); \ + REQUIRE(storage4[9] == 1); \ + \ + lrc::CudaStorage storage6(20, 123); \ + REQUIRE(storage6.size() == 20); \ + storage6.resize(5); \ + REQUIRE(storage6.size() == 5); \ + REQUIRE(storage6[0] == 123); \ + REQUIRE(storage6[1] == 123); \ + REQUIRE(storage6[2] == 123); \ + REQUIRE(storage6[3] == 123); \ + REQUIRE(storage6[4] == 123); \ + \ + storage6.resize(10); \ + REQUIRE(storage6.size() == 10); \ + REQUIRE(storage6[0] == 123); \ + REQUIRE(storage6[1] == 123); \ + REQUIRE(storage6[2] == 123); \ + REQUIRE(storage6[3] == 123); \ + REQUIRE(storage6[4] == 123); \ + \ + storage6.resize(100, 0); \ + REQUIRE(storage6.size() == 100); \ + } -# define BENCHMARK_CONSTRUCTORS(TYPE_, FILL_) \ - BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 10") { \ - lrc::CudaStorage storage(10); \ - return 
storage.size(); \ - }; \ - \ - BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000") { \ - lrc::CudaStorage storage(1000); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000000") { \ - lrc::CudaStorage storage(1000000); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ - lrc::CudaStorage storage(10, FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000 FILLED") { \ - lrc::CudaStorage storage(1000, FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000000 FILLED") { \ - lrc::CudaStorage storage(1000000, FILL_); \ - return storage.size(); \ - } +# define BENCHMARK_CONSTRUCTORS(TYPE_, FILL_) \ + BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 10") { \ + lrc::CudaStorage storage(10); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000") { \ + lrc::CudaStorage storage(1000); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000000") { \ + lrc::CudaStorage storage(1000000); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ + lrc::CudaStorage storage(10, FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000 FILLED") { \ + lrc::CudaStorage storage(1000, FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("CudaStorage<" STRINGIFY(TYPE_) "> 1000000 FILLED") { \ + lrc::CudaStorage storage(1000000, FILL_); \ + return storage.size(); \ + } TEST_CASE("Test CudaStorage", "[storage]") { - SECTION("Test CudaStorage") { - REGISTER_CASES(char); - REGISTER_CASES(unsigned char); - REGISTER_CASES(short); - REGISTER_CASES(unsigned short); - REGISTER_CASES(int); - REGISTER_CASES(unsigned int); - REGISTER_CASES(long); - REGISTER_CASES(unsigned long); - REGISTER_CASES(long long); - REGISTER_CASES(unsigned long long); - 
REGISTER_CASES(float); - REGISTER_CASES(double); - REGISTER_CASES(long double); - } + SECTION("Test CudaStorage") { + REGISTER_CASES(char); + REGISTER_CASES(unsigned char); + REGISTER_CASES(short); + REGISTER_CASES(unsigned short); + REGISTER_CASES(int); + REGISTER_CASES(unsigned int); + REGISTER_CASES(long); + REGISTER_CASES(unsigned long); + REGISTER_CASES(long long); + REGISTER_CASES(unsigned long long); + REGISTER_CASES(float); + REGISTER_CASES(double); + REGISTER_CASES(long double); + } - SECTION("Benchmarks") { - BENCHMARK_CONSTRUCTORS(int, 123); - BENCHMARK_CONSTRUCTORS(double, 456); - } + SECTION("Benchmarks") { + BENCHMARK_CONSTRUCTORS(int, 123); + BENCHMARK_CONSTRUCTORS(double, 456); + } } #else TEST_CASE("Default", "[storage]") { - LIBRAPID_WARN("OpenCL not available, skipping tests"); - SECTION("Default") { REQUIRE(true); } + LIBRAPID_WARN("OpenCL not available, skipping tests"); + SECTION("Default") { REQUIRE(true); } } #endif // LIBRAPID_HAS_CUDA diff --git a/test/test-fixedStorage.cpp b/test/test-fixedStorage.cpp index d99aebc1..7d68772a 100644 --- a/test/test-fixedStorage.cpp +++ b/test/test-fixedStorage.cpp @@ -6,181 +6,181 @@ namespace lrc = librapid; #define REGISTER_CASES(TYPE) \ - SECTION("Type: " STRINGIFY(TYPE)) { \ - using ScalarType = TYPE; \ - lrc::FixedStorage storage; \ - \ - REQUIRE(storage.size() == 9); \ - \ - storage[0] = 1; \ - storage[1] = 10; \ - \ - REQUIRE(storage[0] == 1); \ - REQUIRE(storage[1] == 10); \ - \ - lrc::FixedStorage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ - \ - REQUIRE(storage2.size() == 10); \ - REQUIRE(storage2[0] == 1); \ - REQUIRE(storage2[1] == 2); \ - REQUIRE(storage2[8] == 9); \ - REQUIRE(storage2[9] == 10); \ - \ - lrc::FixedStorage storage3(1); \ - \ - REQUIRE(storage3.size() == 20); \ - REQUIRE(storage3[0] == 1); \ - REQUIRE(storage3[1] == 1); \ - REQUIRE(storage3[18] == 1); \ - REQUIRE(storage3[19] == 1); \ - \ - auto storage4 = lrc::FixedStorage(storage2); \ - \ - REQUIRE(storage4.size() == 10); \ 
- REQUIRE(storage4[0] == 1); \ - REQUIRE(storage4[1] == 2); \ - REQUIRE(storage4[8] == 9); \ - REQUIRE(storage4[9] == 10); \ - \ - SECTION("Const Iterator") { \ - ScalarType i = 1; \ - for (const auto &val : storage2) { \ - REQUIRE(val == i); \ - i += 1; \ - } \ - } \ - \ - SECTION("Non-Const Iterator") { \ - ScalarType i = 1; \ - for (auto &val : storage2) { \ - REQUIRE(val == i); \ - val += 1; \ - i += 1; \ - } \ - } \ - } + SECTION("Type: " STRINGIFY(TYPE)) { \ + using ScalarType = TYPE; \ + lrc::FixedStorage storage; \ + \ + REQUIRE(storage.size() == 9); \ + \ + storage[0] = 1; \ + storage[1] = 10; \ + \ + REQUIRE(storage[0] == 1); \ + REQUIRE(storage[1] == 10); \ + \ + lrc::FixedStorage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ + \ + REQUIRE(storage2.size() == 10); \ + REQUIRE(storage2[0] == 1); \ + REQUIRE(storage2[1] == 2); \ + REQUIRE(storage2[8] == 9); \ + REQUIRE(storage2[9] == 10); \ + \ + lrc::FixedStorage storage3(1); \ + \ + REQUIRE(storage3.size() == 20); \ + REQUIRE(storage3[0] == 1); \ + REQUIRE(storage3[1] == 1); \ + REQUIRE(storage3[18] == 1); \ + REQUIRE(storage3[19] == 1); \ + \ + auto storage4 = lrc::FixedStorage(storage2); \ + \ + REQUIRE(storage4.size() == 10); \ + REQUIRE(storage4[0] == 1); \ + REQUIRE(storage4[1] == 2); \ + REQUIRE(storage4[8] == 9); \ + REQUIRE(storage4[9] == 10); \ + \ + SECTION("Const Iterator") { \ + ScalarType i = 1; \ + for (const auto &val : storage2) { \ + REQUIRE(val == i); \ + i += 1; \ + } \ + } \ + \ + SECTION("Non-Const Iterator") { \ + ScalarType i = 1; \ + for (auto &val : storage2) { \ + REQUIRE(val == i); \ + val += 1; \ + i += 1; \ + } \ + } \ + } #define BENCHMARK_CONSTRUCTORS(TYPE_, FILL_) \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10") { \ - lrc::FixedStorage storage; \ - return storage.size(); \ - }; \ - \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 100") { \ - lrc::FixedStorage storage; \ - return storage.size(); \ - }; \ - \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 1000") { \ - 
lrc::FixedStorage storage; \ - return storage.size(); \ - }; \ - \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10000") { \ - lrc::FixedStorage storage; \ - return storage.size(); \ - }; \ - \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ - lrc::FixedStorage storage(FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 100 FILLED") { \ - lrc::FixedStorage storage(FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 1000 FILLED") { \ - lrc::FixedStorage storage(FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10000 FILLED") { \ - lrc::FixedStorage storage(FILL_); \ - return storage.size(); \ - } + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10") { \ + lrc::FixedStorage storage; \ + return storage.size(); \ + }; \ + \ + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 100") { \ + lrc::FixedStorage storage; \ + return storage.size(); \ + }; \ + \ + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 1000") { \ + lrc::FixedStorage storage; \ + return storage.size(); \ + }; \ + \ + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10000") { \ + lrc::FixedStorage storage; \ + return storage.size(); \ + }; \ + \ + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ + lrc::FixedStorage storage(FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 100 FILLED") { \ + lrc::FixedStorage storage(FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 1000 FILLED") { \ + lrc::FixedStorage storage(FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("FixedStorage<" STRINGIFY(TYPE_) "> 10000 FILLED") { \ + lrc::FixedStorage storage(FILL_); \ + return storage.size(); \ + } TEST_CASE("Test FixedStorage", "[fixed-storage]") { - SECTION("Trivially Constructible Storage") { - REGISTER_CASES(char); - REGISTER_CASES(unsigned char); - 
REGISTER_CASES(short); - REGISTER_CASES(unsigned short); - REGISTER_CASES(int); - REGISTER_CASES(unsigned int); - REGISTER_CASES(long); - REGISTER_CASES(unsigned long); - REGISTER_CASES(long long); - REGISTER_CASES(unsigned long long); - REGISTER_CASES(float); - REGISTER_CASES(double); - REGISTER_CASES(long double); - } + SECTION("Trivially Constructible Storage") { + REGISTER_CASES(char); + REGISTER_CASES(unsigned char); + REGISTER_CASES(short); + REGISTER_CASES(unsigned short); + REGISTER_CASES(int); + REGISTER_CASES(unsigned int); + REGISTER_CASES(long); + REGISTER_CASES(unsigned long); + REGISTER_CASES(long long); + REGISTER_CASES(unsigned long long); + REGISTER_CASES(float); + REGISTER_CASES(double); + REGISTER_CASES(long double); + } - SECTION("Non-Trivially Constructible Storage") { - // Can't use normal tests, so just test a few things - lrc::FixedStorage storage; - REQUIRE(storage.size() == 5); - storage[0] = "Hello"; - storage[1] = "World"; - REQUIRE(storage[0] == "Hello"); - REQUIRE(storage[1] == "World"); + SECTION("Non-Trivially Constructible Storage") { + // Can't use normal tests, so just test a few things + lrc::FixedStorage storage; + REQUIRE(storage.size() == 5); + storage[0] = "Hello"; + storage[1] = "World"; + REQUIRE(storage[0] == "Hello"); + REQUIRE(storage[1] == "World"); - lrc::FixedStorage storage2({"Hello", "World"}); - REQUIRE(storage2.size() == 2); - REQUIRE(storage2[0] == "Hello"); - REQUIRE(storage2[1] == "World"); + lrc::FixedStorage storage2({"Hello", "World"}); + REQUIRE(storage2.size() == 2); + REQUIRE(storage2[0] == "Hello"); + REQUIRE(storage2[1] == "World"); - lrc::FixedStorage storage3("Hello"); - REQUIRE(storage3.size() == 20); - REQUIRE(storage3[0] == "Hello"); - REQUIRE(storage3[1] == "Hello"); - REQUIRE(storage3[18] == "Hello"); - REQUIRE(storage3[19] == "Hello"); + lrc::FixedStorage storage3("Hello"); + REQUIRE(storage3.size() == 20); + REQUIRE(storage3[0] == "Hello"); + REQUIRE(storage3[1] == "Hello"); + 
REQUIRE(storage3[18] == "Hello"); + REQUIRE(storage3[19] == "Hello"); - auto storage4 = lrc::FixedStorage, 10>(); - REQUIRE(storage4.size() == 10); - storage4[0].push_back("Hello"); - storage4[0].push_back("World"); - REQUIRE(storage4[0][0] == "Hello"); - REQUIRE(storage4[0][1] == "World"); + auto storage4 = lrc::FixedStorage, 10>(); + REQUIRE(storage4.size() == 10); + storage4[0].push_back("Hello"); + storage4[0].push_back("World"); + REQUIRE(storage4[0][0] == "Hello"); + REQUIRE(storage4[0][1] == "World"); - struct Three { - int a; - int b; - int c; - }; + struct Three { + int a; + int b; + int c; + }; - auto storage5 = lrc::FixedStorage(); - REQUIRE(storage5.size() == 10); - storage5[0].a = 1; - storage5[0].b = 2; - storage5[0].c = 3; - storage5[1].a = 4; - storage5[1].b = 5; - storage5[1].c = 6; - REQUIRE(storage5[0].a == 1); - REQUIRE(storage5[0].b == 2); - REQUIRE(storage5[0].c == 3); - REQUIRE(storage5[1].a == 4); - REQUIRE(storage5[1].b == 5); - REQUIRE(storage5[1].c == 6); + auto storage5 = lrc::FixedStorage(); + REQUIRE(storage5.size() == 10); + storage5[0].a = 1; + storage5[0].b = 2; + storage5[0].c = 3; + storage5[1].a = 4; + storage5[1].b = 5; + storage5[1].c = 6; + REQUIRE(storage5[0].a == 1); + REQUIRE(storage5[0].b == 2); + REQUIRE(storage5[0].c == 3); + REQUIRE(storage5[1].a == 4); + REQUIRE(storage5[1].b == 5); + REQUIRE(storage5[1].c == 6); - auto storage6 = storage5; - REQUIRE(storage5[0].a == 1); - REQUIRE(storage5[0].b == 2); - REQUIRE(storage5[0].c == 3); - REQUIRE(storage5[1].a == 4); - REQUIRE(storage5[1].b == 5); - REQUIRE(storage5[1].c == 6); - } + auto storage6 = storage5; + REQUIRE(storage5[0].a == 1); + REQUIRE(storage5[0].b == 2); + REQUIRE(storage5[0].c == 3); + REQUIRE(storage5[1].a == 4); + REQUIRE(storage5[1].b == 5); + REQUIRE(storage5[1].c == 6); + } - SECTION("Benchmarks") { - BENCHMARK_CONSTRUCTORS(int, 123); - BENCHMARK_CONSTRUCTORS(double, 456); - BENCHMARK_CONSTRUCTORS(std::string, "Hello, World"); - 
BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); - BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); - } + SECTION("Benchmarks") { + BENCHMARK_CONSTRUCTORS(int, 123); + BENCHMARK_CONSTRUCTORS(double, 456); + BENCHMARK_CONSTRUCTORS(std::string, "Hello, World"); + BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); + BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); + } } diff --git a/test/test-mathUtilities.cpp b/test/test-mathUtilities.cpp index 5db8d197..59a80961 100644 --- a/test/test-mathUtilities.cpp +++ b/test/test-mathUtilities.cpp @@ -6,27 +6,27 @@ namespace lrc = librapid; TEST_CASE("Test Math Utilities", "[math]") { - REQUIRE(lrc::clamp(5.f, 0.f, 10.f) == 5.f); - REQUIRE(lrc::clamp(0.f, 0.f, 10.f) == 0.f); - REQUIRE(lrc::clamp(10.f, 0.f, 10.f) == 10.f); - REQUIRE(lrc::clamp(-10.f, 0.f, 10.f) == 0.f); - REQUIRE(lrc::clamp(20.f, 0.f, 10.f) == 10.f); + REQUIRE(lrc::clamp(5.f, 0.f, 10.f) == 5.f); + REQUIRE(lrc::clamp(0.f, 0.f, 10.f) == 0.f); + REQUIRE(lrc::clamp(10.f, 0.f, 10.f) == 10.f); + REQUIRE(lrc::clamp(-10.f, 0.f, 10.f) == 0.f); + REQUIRE(lrc::clamp(20.f, 0.f, 10.f) == 10.f); - REQUIRE(lrc::lerp(0.f, 0.f, 1.f) == 0.f); - REQUIRE(lrc::lerp(0.5f, 0.f, 1.f) == 0.5f); - REQUIRE(lrc::lerp(1.f, 0.f, 1.f) == 1.f); - REQUIRE(lrc::lerp(0.f, 0.f, 10.f) == 0.f); - REQUIRE(lrc::lerp(0.5f, 0.f, 10.f) == 5.f); - REQUIRE(lrc::lerp(1.f, 0.f, 10.f) == 10.f); - REQUIRE(lrc::lerp(2.f, 0.f, 10.f) == 20.f); + REQUIRE(lrc::lerp(0.f, 0.f, 1.f) == 0.f); + REQUIRE(lrc::lerp(0.5f, 0.f, 1.f) == 0.5f); + REQUIRE(lrc::lerp(1.f, 0.f, 1.f) == 1.f); + REQUIRE(lrc::lerp(0.f, 0.f, 10.f) == 0.f); + REQUIRE(lrc::lerp(0.5f, 0.f, 10.f) == 5.f); + REQUIRE(lrc::lerp(1.f, 0.f, 10.f) == 10.f); + REQUIRE(lrc::lerp(2.f, 0.f, 10.f) == 20.f); - REQUIRE(lrc::smoothStep(0.f) == 0.f); - REQUIRE(lrc::smoothStep(0.5f) == 0.5f); - REQUIRE(lrc::smoothStep(1.f) == 1.f); - REQUIRE(lrc::smoothStep(0.f, 0.f, 10.f) == 0.f); - 
REQUIRE(lrc::smoothStep(5.f, 0.f, 10.f) == 0.5f); - REQUIRE(lrc::smoothStep(10.f, 0.f, 10.f) == 1.f); - REQUIRE(lrc::smoothStep(20.f, 0.f, 10.f) == 1.f); - REQUIRE(lrc::smoothStep(0.25f, 0.f, 1.f) - 0.1035f < 0.001f); - REQUIRE(lrc::smoothStep(0.75f, 0.f, 1.f) - 0.8965f < 0.001f); + REQUIRE(lrc::smoothStep(0.f) == 0.f); + REQUIRE(lrc::smoothStep(0.5f) == 0.5f); + REQUIRE(lrc::smoothStep(1.f) == 1.f); + REQUIRE(lrc::smoothStep(0.f, 0.f, 10.f) == 0.f); + REQUIRE(lrc::smoothStep(5.f, 0.f, 10.f) == 0.5f); + REQUIRE(lrc::smoothStep(10.f, 0.f, 10.f) == 1.f); + REQUIRE(lrc::smoothStep(20.f, 0.f, 10.f) == 1.f); + REQUIRE(lrc::smoothStep(0.25f, 0.f, 1.f) - 0.1035f < 0.001f); + REQUIRE(lrc::smoothStep(0.75f, 0.f, 1.f) - 0.8965f < 0.001f); } diff --git a/test/test-multiprecision.cpp b/test/test-multiprecision.cpp index 05c0a94e..da0273a0 100644 --- a/test/test-multiprecision.cpp +++ b/test/test-multiprecision.cpp @@ -12,80 +12,80 @@ namespace lrc = librapid; #if defined(LIBRAPID_USE_MULTIPREC) TEST_CASE("Test Multiprecision", "[multiprecision]") { - lrc::prec(16); - REQUIRE(lrc::mpz(1) == 1); - REQUIRE(lrc::mpq(1) == 1); - REQUIRE(lrc::mpf(1) == 1); - REQUIRE(lrc::mpfr(1) == 1); - - REQUIRE(lrc::mpz(1) == lrc::mpz(1)); - REQUIRE(lrc::mpq(1) == lrc::mpq(1)); - REQUIRE(lrc::mpf(1) == lrc::mpf(1)); - REQUIRE(lrc::mpfr(1) == lrc::mpfr(1)); - - REQUIRE(fmt::format("{}", lrc::mpz(1)) == "1"); - REQUIRE(fmt::format("{}", lrc::mpq(1)) == "1"); - REQUIRE(fmt::format("{}", lrc::mpf(1)) == "1.0"); - - REQUIRE(fmt::format("{}", lrc::mpz(1234)) == "1234"); - REQUIRE(fmt::format("{}", lrc::mpq(1234)) == "1234"); - REQUIRE(fmt::format("{}", lrc::mpf(1234)) == "1234.0"); - - REQUIRE(lrc::mpz("1234") == 1234); - REQUIRE(lrc::mpq("1234") == 1234); - REQUIRE(lrc::mpq("1234/617") == 2); - REQUIRE(lrc::mpf("1234") == 1234); - REQUIRE(lrc::mpfr("1234") == 1234); - - REQUIRE(lrc::abs(lrc::sin(lrc::constPi()) - 0) < lrc::exp10(-15)); - REQUIRE(lrc::abs(lrc::cos(lrc::constPi()) + 1) < 
lrc::exp10(-15)); - REQUIRE(lrc::abs(lrc::tan(lrc::constPi()) - 0) < lrc::exp10(-15)); - - REQUIRE(lrc::mpz("10") + lrc::mpz("100") == 110); - REQUIRE(lrc::mpq("10") + lrc::mpq("100") == 110); - REQUIRE(lrc::mpf("10") + lrc::mpf("100") == 110); - REQUIRE(lrc::mpfr("10") + lrc::mpfr("100") == 110); - - lrc::prec(500); - REQUIRE( - lrc::constPi() == - "3." - "14159265358979323846264338327950288419716939937510582097494459230781640628620899862803482534" - "21170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446" - "22948954930381964428810975665933446128475648233786783165271201909145648566923460348610454326" - "64821339360726024914127372458700660631558817488152092096282925409171536436789259036001133053" - "05488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799" - "627495673518857527248912279381830119491304781324640566153215447811706812394506641467563306"); - - SECTION("Benchmarks") { - for (int64_t prec = 128; prec <= 1 << 24; prec <<= 1) { - lrc::prec2(prec); - - lrc::mpz bigInt(1); - bigInt <<= prec; - BENCHMARK(fmt::format("Integer Addition\n[{} bits]", prec)) { return bigInt + bigInt; }; - - BENCHMARK(fmt::format("Integer Multiplication\n[{} bits]", prec)) { - return bigInt * bigInt; - }; - - lrc::mpq bigRat = lrc::mpq(bigInt) / lrc::mpq(bigInt + 1); - BENCHMARK(fmt::format("Rational Addition\n[{} bits]", prec)) { - return bigRat + bigRat; - }; - - lrc::mpfr bigFloat = lrc::constPi(); - BENCHMARK(fmt::format("Floating Point Addition\n[{} bits]", prec)) { - return bigFloat + bigFloat; - }; - - BENCHMARK(fmt::format("Floating Point Multiplication\n[{} bits]", prec)) { - return bigFloat * bigFloat; - }; - - BENCHMARK(fmt::format("Pi Calculation\n[{} bits]", prec)) { return lrc::constPi(); }; - } - } + lrc::prec(16); + REQUIRE(lrc::mpz(1) == 1); + REQUIRE(lrc::mpq(1) == 1); + REQUIRE(lrc::mpf(1) == 1); + REQUIRE(lrc::mpfr(1) == 1); + + REQUIRE(lrc::mpz(1) == lrc::mpz(1)); + REQUIRE(lrc::mpq(1) == 
lrc::mpq(1)); + REQUIRE(lrc::mpf(1) == lrc::mpf(1)); + REQUIRE(lrc::mpfr(1) == lrc::mpfr(1)); + + REQUIRE(fmt::format("{}", lrc::mpz(1)) == "1"); + REQUIRE(fmt::format("{}", lrc::mpq(1)) == "1"); + REQUIRE(fmt::format("{}", lrc::mpf(1)) == "1.0"); + + REQUIRE(fmt::format("{}", lrc::mpz(1234)) == "1234"); + REQUIRE(fmt::format("{}", lrc::mpq(1234)) == "1234"); + REQUIRE(fmt::format("{}", lrc::mpf(1234)) == "1234.0"); + + REQUIRE(lrc::mpz("1234") == 1234); + REQUIRE(lrc::mpq("1234") == 1234); + REQUIRE(lrc::mpq("1234/617") == 2); + REQUIRE(lrc::mpf("1234") == 1234); + REQUIRE(lrc::mpfr("1234") == 1234); + + REQUIRE(lrc::abs(lrc::sin(lrc::constPi()) - 0) < lrc::exp10(-15)); + REQUIRE(lrc::abs(lrc::cos(lrc::constPi()) + 1) < lrc::exp10(-15)); + REQUIRE(lrc::abs(lrc::tan(lrc::constPi()) - 0) < lrc::exp10(-15)); + + REQUIRE(lrc::mpz("10") + lrc::mpz("100") == 110); + REQUIRE(lrc::mpq("10") + lrc::mpq("100") == 110); + REQUIRE(lrc::mpf("10") + lrc::mpf("100") == 110); + REQUIRE(lrc::mpfr("10") + lrc::mpfr("100") == 110); + + lrc::prec(500); + REQUIRE( + lrc::constPi() == + "3." 
+ "14159265358979323846264338327950288419716939937510582097494459230781640628620899862803482534" + "21170679821480865132823066470938446095505822317253594081284811174502841027019385211055596446" + "22948954930381964428810975665933446128475648233786783165271201909145648566923460348610454326" + "64821339360726024914127372458700660631558817488152092096282925409171536436789259036001133053" + "05488204665213841469519415116094330572703657595919530921861173819326117931051185480744623799" + "627495673518857527248912279381830119491304781324640566153215447811706812394506641467563306"); + + SECTION("Benchmarks") { + for (int64_t prec = 128; prec <= 1 << 24; prec <<= 1) { + lrc::prec2(prec); + + lrc::mpz bigInt(1); + bigInt <<= prec; + BENCHMARK(fmt::format("Integer Addition\n[{} bits]", prec)) { return bigInt + bigInt; }; + + BENCHMARK(fmt::format("Integer Multiplication\n[{} bits]", prec)) { + return bigInt * bigInt; + }; + + lrc::mpq bigRat = lrc::mpq(bigInt) / lrc::mpq(bigInt + 1); + BENCHMARK(fmt::format("Rational Addition\n[{} bits]", prec)) { + return bigRat + bigRat; + }; + + lrc::mpfr bigFloat = lrc::constPi(); + BENCHMARK(fmt::format("Floating Point Addition\n[{} bits]", prec)) { + return bigFloat + bigFloat; + }; + + BENCHMARK(fmt::format("Floating Point Multiplication\n[{} bits]", prec)) { + return bigFloat * bigFloat; + }; + + BENCHMARK(fmt::format("Pi Calculation\n[{} bits]", prec)) { return lrc::constPi(); }; + } + } } #else diff --git a/test/test-openCLStorage.cpp b/test/test-openCLStorage.cpp index 12ae1527..9a0f5180 100644 --- a/test/test-openCLStorage.cpp +++ b/test/test-openCLStorage.cpp @@ -7,147 +7,147 @@ namespace lrc = librapid; #if defined(LIBRAPID_HAS_OPENCL) -# define REGISTER_CASES(TYPE) \ - SECTION("Type: " STRINGIFY(TYPE)) { \ - using ScalarType = TYPE; \ - lrc::OpenCLStorage storage(5); \ - \ - REQUIRE(storage.size() == 5); \ - \ - storage[0] = 1; \ - storage[1] = 10; \ - \ - REQUIRE(storage[0] == 1); \ - REQUIRE(storage[1] == 10); \ - \ - 
lrc::OpenCLStorage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ - \ - REQUIRE(storage2.size() == 10); \ - REQUIRE(storage2[0] == 1); \ - REQUIRE(storage2[1] == 2); \ - REQUIRE(storage2[8] == 9); \ - REQUIRE(storage2[9] == 10); \ - \ - lrc::OpenCLStorage storage3(100, 1); \ - \ - REQUIRE(storage3.size() == 100); \ - REQUIRE(storage3[0] == 1); \ - REQUIRE(storage3[1] == 1); \ - REQUIRE(storage3[98] == 1); \ - REQUIRE(storage3[99] == 1); \ - \ - auto storage4 = lrc::OpenCLStorage(storage2); \ - \ - REQUIRE(storage4.size() == 10); \ - REQUIRE(storage4[0] == 1); \ - REQUIRE(storage4[1] == 2); \ - REQUIRE(storage4[8] == 9); \ - REQUIRE(storage4[9] == 10); \ - \ - /* storage4 = lrc::OpenCLStorage(100); */ \ - /* REQUIRE(storage4.size() == 100); */ \ - /* storage4[0] = 1; */ \ - /* storage4[1] = 2; */ \ - /* storage4[98] = 99; */ \ - /* storage4[99] = 100; */ \ - /* REQUIRE(storage4[0] == 1); */ \ - /* REQUIRE(storage4[1] == 2); */ \ - /* REQUIRE(storage4[98] == 99); */ \ - /* REQUIRE(storage4[99] == 100); */ \ - \ - storage4 = storage3; \ - \ - REQUIRE(storage4.size() == 100); \ - REQUIRE(storage4[0] == 1); \ - REQUIRE(storage4[1] == 1); \ - REQUIRE(storage4[98] == 1); \ - REQUIRE(storage4[99] == 1); \ - \ - lrc::OpenCLStorage storage6(20, 123); \ - REQUIRE(storage6.size() == 20); \ - storage6.resize(5); \ - REQUIRE(storage6.size() == 5); \ - REQUIRE(storage6[0] == 123); \ - REQUIRE(storage6[1] == 123); \ - REQUIRE(storage6[2] == 123); \ - REQUIRE(storage6[3] == 123); \ - REQUIRE(storage6[4] == 123); \ - \ - storage6.resize(10); \ - REQUIRE(storage6.size() == 10); \ - REQUIRE(storage6[0] == 123); \ - REQUIRE(storage6[1] == 123); \ - REQUIRE(storage6[2] == 123); \ - REQUIRE(storage6[3] == 123); \ - REQUIRE(storage6[4] == 123); \ - \ - storage6.resize(100, 0); \ - REQUIRE(storage6.size() == 100); \ - } +# define REGISTER_CASES(TYPE) \ + SECTION("Type: " STRINGIFY(TYPE)) { \ + using ScalarType = TYPE; \ + lrc::OpenCLStorage storage(5); \ + \ + REQUIRE(storage.size() == 5); \ 
+ \ + storage[0] = 1; \ + storage[1] = 10; \ + \ + REQUIRE(storage[0] == 1); \ + REQUIRE(storage[1] == 10); \ + \ + lrc::OpenCLStorage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ + \ + REQUIRE(storage2.size() == 10); \ + REQUIRE(storage2[0] == 1); \ + REQUIRE(storage2[1] == 2); \ + REQUIRE(storage2[8] == 9); \ + REQUIRE(storage2[9] == 10); \ + \ + lrc::OpenCLStorage storage3(100, 1); \ + \ + REQUIRE(storage3.size() == 100); \ + REQUIRE(storage3[0] == 1); \ + REQUIRE(storage3[1] == 1); \ + REQUIRE(storage3[98] == 1); \ + REQUIRE(storage3[99] == 1); \ + \ + auto storage4 = lrc::OpenCLStorage(storage2); \ + \ + REQUIRE(storage4.size() == 10); \ + REQUIRE(storage4[0] == 1); \ + REQUIRE(storage4[1] == 2); \ + REQUIRE(storage4[8] == 9); \ + REQUIRE(storage4[9] == 10); \ + \ + /* storage4 = lrc::OpenCLStorage(100); */ \ + /* REQUIRE(storage4.size() == 100); */ \ + /* storage4[0] = 1; */ \ + /* storage4[1] = 2; */ \ + /* storage4[98] = 99; */ \ + /* storage4[99] = 100; */ \ + /* REQUIRE(storage4[0] == 1); */ \ + /* REQUIRE(storage4[1] == 2); */ \ + /* REQUIRE(storage4[98] == 99); */ \ + /* REQUIRE(storage4[99] == 100); */ \ + \ + storage4 = storage3; \ + \ + REQUIRE(storage4.size() == 100); \ + REQUIRE(storage4[0] == 1); \ + REQUIRE(storage4[1] == 1); \ + REQUIRE(storage4[98] == 1); \ + REQUIRE(storage4[99] == 1); \ + \ + lrc::OpenCLStorage storage6(20, 123); \ + REQUIRE(storage6.size() == 20); \ + storage6.resize(5); \ + REQUIRE(storage6.size() == 5); \ + REQUIRE(storage6[0] == 123); \ + REQUIRE(storage6[1] == 123); \ + REQUIRE(storage6[2] == 123); \ + REQUIRE(storage6[3] == 123); \ + REQUIRE(storage6[4] == 123); \ + \ + storage6.resize(10); \ + REQUIRE(storage6.size() == 10); \ + REQUIRE(storage6[0] == 123); \ + REQUIRE(storage6[1] == 123); \ + REQUIRE(storage6[2] == 123); \ + REQUIRE(storage6[3] == 123); \ + REQUIRE(storage6[4] == 123); \ + \ + storage6.resize(100, 0); \ + REQUIRE(storage6.size() == 100); \ + } -# define BENCHMARK_CONSTRUCTORS(TYPE_, FILL_) \ - 
BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 10") { \ - lrc::OpenCLStorage storage(10); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000") { \ - lrc::OpenCLStorage storage(1000); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000000") { \ - lrc::OpenCLStorage storage(1000000); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ - lrc::OpenCLStorage storage(10, FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000 FILLED") { \ - lrc::OpenCLStorage storage(1000, FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000000 FILLED") { \ - lrc::OpenCLStorage storage(1000000, FILL_); \ - return storage.size(); \ - } +# define BENCHMARK_CONSTRUCTORS(TYPE_, FILL_) \ + BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 10") { \ + lrc::OpenCLStorage storage(10); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000") { \ + lrc::OpenCLStorage storage(1000); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000000") { \ + lrc::OpenCLStorage storage(1000000); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ + lrc::OpenCLStorage storage(10, FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000 FILLED") { \ + lrc::OpenCLStorage storage(1000, FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("OpenCLStorage<" STRINGIFY(TYPE_) "> 1000000 FILLED") { \ + lrc::OpenCLStorage storage(1000000, FILL_); \ + return storage.size(); \ + } TEST_CASE("Configure OpenCL", "[storage]") { - SECTION("Configure OpenCL") { lrc::configureOpenCL(true); } + SECTION("Configure OpenCL") { lrc::configureOpenCL(true); } } TEST_CASE("Test OpenCLStorage", "[storage]") { - SECTION("Test 
OpenCLStorage") { - REGISTER_CASES(char); - REGISTER_CASES(unsigned char); - REGISTER_CASES(short); - REGISTER_CASES(unsigned short); - REGISTER_CASES(int); - REGISTER_CASES(unsigned int); - REGISTER_CASES(long); - REGISTER_CASES(unsigned long); - REGISTER_CASES(long long); - REGISTER_CASES(unsigned long long); - REGISTER_CASES(float); - REGISTER_CASES(double); - REGISTER_CASES(long double); - } + SECTION("Test OpenCLStorage") { + REGISTER_CASES(char); + REGISTER_CASES(unsigned char); + REGISTER_CASES(short); + REGISTER_CASES(unsigned short); + REGISTER_CASES(int); + REGISTER_CASES(unsigned int); + REGISTER_CASES(long); + REGISTER_CASES(unsigned long); + REGISTER_CASES(long long); + REGISTER_CASES(unsigned long long); + REGISTER_CASES(float); + REGISTER_CASES(double); + REGISTER_CASES(long double); + } - SECTION("Benchmarks") { - BENCHMARK_CONSTRUCTORS(int, 123); - BENCHMARK_CONSTRUCTORS(double, 456); - } + SECTION("Benchmarks") { + BENCHMARK_CONSTRUCTORS(int, 123); + BENCHMARK_CONSTRUCTORS(double, 456); + } } #else TEST_CASE("Default", "[storage]") { - LIBRAPID_WARN("OpenCL not available, skipping tests"); - SECTION("Default") { REQUIRE(true); } + LIBRAPID_WARN("OpenCL not available, skipping tests"); + SECTION("Default") { REQUIRE(true); } } #endif // LIBRAPID_HAS_OPENCL diff --git a/test/test-pseudoConstructors.cpp b/test/test-pseudoConstructors.cpp index 4e8d1db4..473c4da4 100644 --- a/test/test-pseudoConstructors.cpp +++ b/test/test-pseudoConstructors.cpp @@ -3,66 +3,66 @@ #include #include -namespace lrc = librapid; +namespace lrc = librapid; constexpr double tolerance = 0.001; // #define SCALAR float // #define BACKEND lrc::backend::CPU TEST_CASE("Test Array Generation Methods", "[array-lib]") { - SECTION("Test zeros()") { - auto a = lrc::zeros({3, 4, 5}); - REQUIRE(a.shape() == lrc::Shape({3, 4, 5})); - REQUIRE(a.storage().size() == 60); - for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - 0 < tolerance); } - } + SECTION("Test 
zeros()") { + auto a = lrc::zeros({3, 4, 5}); + REQUIRE(a.shape() == lrc::Shape({3, 4, 5})); + REQUIRE(a.storage().size() == 60); + for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - 0 < tolerance); } + } - SECTION("Test ones()") { - auto a = lrc::ones({3, 4, 5}); - REQUIRE(a.shape() == lrc::Shape({3, 4, 5})); - REQUIRE(a.storage().size() == 60); - for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - 1 < tolerance); } - } + SECTION("Test ones()") { + auto a = lrc::ones({3, 4, 5}); + REQUIRE(a.shape() == lrc::Shape({3, 4, 5})); + REQUIRE(a.storage().size() == 60); + for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - 1 < tolerance); } + } - SECTION("Test ordered()") { - auto a = lrc::ordered({3, 4, 5}); - REQUIRE(a.shape() == lrc::Shape({3, 4, 5})); - REQUIRE(a.storage().size() == 60); - for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - i < tolerance); } - } + SECTION("Test ordered()") { + auto a = lrc::ordered({3, 4, 5}); + REQUIRE(a.shape() == lrc::Shape({3, 4, 5})); + REQUIRE(a.storage().size() == 60); + for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - i < tolerance); } + } - SECTION("Test arange()") { - auto a = lrc::arange(0, 10, 1); - REQUIRE(a.shape() == lrc::Shape({10})); - REQUIRE(a.storage().size() == 10); - for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - i < tolerance); } + SECTION("Test arange()") { + auto a = lrc::arange(0, 10, 1); + REQUIRE(a.shape() == lrc::Shape({10})); + REQUIRE(a.storage().size() == 10); + for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - i < tolerance); } - auto b = lrc::arange(0, 10, 2); - REQUIRE(b.shape() == lrc::Shape({5})); - REQUIRE(b.storage().size() == 5); - for (size_t i = 0; i < b.storage().size(); i++) { - REQUIRE(b.storage()[i] - i * 2 < tolerance); - } - } + auto b = lrc::arange(0, 10, 2); + REQUIRE(b.shape() == lrc::Shape({5})); + 
REQUIRE(b.storage().size() == 5); + for (size_t i = 0; i < b.storage().size(); i++) { + REQUIRE(b.storage()[i] - i * 2 < tolerance); + } + } - SECTION("Test linspace()") { - auto a = lrc::linspace(0, 10, 10, false); - REQUIRE(a.shape() == lrc::Shape({10})); - REQUIRE(a.storage().size() == 10); - for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - i < tolerance); } + SECTION("Test linspace()") { + auto a = lrc::linspace(0, 10, 10, false); + REQUIRE(a.shape() == lrc::Shape({10})); + REQUIRE(a.storage().size() == 10); + for (size_t i = 0; i < a.storage().size(); i++) { REQUIRE(a.storage()[i] - i < tolerance); } - auto b = lrc::linspace(0, 10, 100, false); - REQUIRE(b.shape() == lrc::Shape({100})); - REQUIRE(b.storage().size() == 100); - for (size_t i = 0; i < b.storage().size(); i++) { - REQUIRE(b.storage()[i] - static_cast(i) / 10 < tolerance); - } + auto b = lrc::linspace(0, 10, 100, false); + REQUIRE(b.shape() == lrc::Shape({100})); + REQUIRE(b.storage().size() == 100); + for (size_t i = 0; i < b.storage().size(); i++) { + REQUIRE(b.storage()[i] - static_cast(i) / 10 < tolerance); + } - auto c = lrc::linspace(0, 10, 10, true); - REQUIRE(c.shape() == lrc::Shape({10})); - REQUIRE(c.storage().size() == 10); - for (size_t i = 0; i < c.storage().size(); i++) { - REQUIRE(c.storage()[i] - static_cast(i) * (10.0 / 9.0) < tolerance); - } - } + auto c = lrc::linspace(0, 10, 10, true); + REQUIRE(c.shape() == lrc::Shape({10})); + REQUIRE(c.storage().size() == 10); + for (size_t i = 0; i < c.storage().size(); i++) { + REQUIRE(c.storage()[i] - static_cast(i) * (10.0 / 9.0) < tolerance); + } + } } diff --git a/test/test-sigmoid.cpp b/test/test-sigmoid.cpp index 54b9f486..9f853da1 100644 --- a/test/test-sigmoid.cpp +++ b/test/test-sigmoid.cpp @@ -3,53 +3,53 @@ #include #include -namespace lrc = librapid; +namespace lrc = librapid; constexpr double tolerance = 0.001; -using CPU = lrc::backend::CPU; -using OPENCL = lrc::backend::OpenCL; -using CUDA = 
lrc::backend::CUDA; +using CPU = lrc::backend::CPU; +using OPENCL = lrc::backend::OpenCL; +using CUDA = lrc::backend::CUDA; #define TEST_SIGMOID(SCALAR, BACKEND) \ - TEST_CASE(fmt::format("Test Sigmoid -- [ {} | {} ]", STRINGIFY(SCALAR), STRINGIFY(BACKEND)), \ - "[sigmoid]") { \ - SECTION("Forward Sigmoid") { \ - int64_t n = 100; \ - auto sigmoid = lrc::ml::Sigmoid(); \ - auto data = lrc::linspace(-10, 10, n); \ - auto f = [](SCALAR x) { return 1 / (1 + lrc::exp(-x)); }; \ + TEST_CASE(fmt::format("Test Sigmoid -- [ {} | {} ]", STRINGIFY(SCALAR), STRINGIFY(BACKEND)), \ + "[sigmoid]") { \ + SECTION("Forward Sigmoid") { \ + int64_t n = 100; \ + auto sigmoid = lrc::ml::Sigmoid(); \ + auto data = lrc::linspace(-10, 10, n); \ + auto f = [](SCALAR x) { return 1 / (1 + lrc::exp(-x)); }; \ \ - auto result = lrc::zeros(lrc::Shape({n})); \ - sigmoid.forward(result, data); \ - auto result2 = sigmoid(data); \ - auto result3 = sigmoid.forward(data); \ + auto result = lrc::zeros(lrc::Shape({n})); \ + sigmoid.forward(result, data); \ + auto result2 = sigmoid(data); \ + auto result3 = sigmoid.forward(data); \ \ - for (int64_t i = 0; i < n; ++i) { \ - REQUIRE(lrc::isClose((SCALAR)result(i), (SCALAR)f(data(i)), tolerance)); \ - REQUIRE(lrc::isClose((SCALAR)result2(i), (SCALAR)f(data(i)), tolerance)); \ - REQUIRE(lrc::isClose((SCALAR)result3(i), (SCALAR)f(data(i)), tolerance)); \ - } \ - } \ + for (int64_t i = 0; i < n; ++i) { \ + REQUIRE(lrc::isClose((SCALAR)result(i), (SCALAR)f(data(i)), tolerance)); \ + REQUIRE(lrc::isClose((SCALAR)result2(i), (SCALAR)f(data(i)), tolerance)); \ + REQUIRE(lrc::isClose((SCALAR)result3(i), (SCALAR)f(data(i)), tolerance)); \ + } \ + } \ \ - SECTION("Backward Sigmoid") { \ - int64_t n = 100; \ - auto sigmoid = lrc::ml::Sigmoid(); \ - auto data = lrc::linspace(-10, 10, n); \ - auto f = [](SCALAR x) { return 1 / (1 + lrc::exp(-x)); }; \ - auto fPrime = [](SCALAR x) { return x * (1 - x); }; \ + SECTION("Backward Sigmoid") { \ + int64_t n = 100; \ + auto 
sigmoid = lrc::ml::Sigmoid(); \ + auto data = lrc::linspace(-10, 10, n); \ + auto f = [](SCALAR x) { return 1 / (1 + lrc::exp(-x)); }; \ + auto fPrime = [](SCALAR x) { return x * (1 - x); }; \ \ - auto result = lrc::zeros(lrc::Shape({n})); \ - sigmoid.forward(result, data); \ - sigmoid.backward(result, result); \ - auto result2 = sigmoid.backward(sigmoid(data)); \ - auto result3 = sigmoid.backward(sigmoid.forward(data)); \ + auto result = lrc::zeros(lrc::Shape({n})); \ + sigmoid.forward(result, data); \ + sigmoid.backward(result, result); \ + auto result2 = sigmoid.backward(sigmoid(data)); \ + auto result3 = sigmoid.backward(sigmoid.forward(data)); \ \ - for (int64_t i = 0; i < n; ++i) { \ - REQUIRE(lrc::isClose((SCALAR)result(i), (SCALAR)fPrime(f(data(i))), tolerance)); \ - REQUIRE(lrc::isClose((SCALAR)result2(i), (SCALAR)fPrime(f(data(i))), tolerance)); \ - REQUIRE(lrc::isClose((SCALAR)result3(i), (SCALAR)fPrime(f(data(i))), tolerance)); \ - } \ - } \ - } + for (int64_t i = 0; i < n; ++i) { \ + REQUIRE(lrc::isClose((SCALAR)result(i), (SCALAR)fPrime(f(data(i))), tolerance)); \ + REQUIRE(lrc::isClose((SCALAR)result2(i), (SCALAR)fPrime(f(data(i))), tolerance)); \ + REQUIRE(lrc::isClose((SCALAR)result3(i), (SCALAR)fPrime(f(data(i))), tolerance)); \ + } \ + } \ + } TEST_SIGMOID(float, CPU) TEST_SIGMOID(double, CPU) diff --git a/test/test-sizetype.cpp b/test/test-sizetype.cpp index 4e783fd5..eb6730c1 100644 --- a/test/test-sizetype.cpp +++ b/test/test-sizetype.cpp @@ -6,53 +6,53 @@ namespace lrc = librapid; TEST_CASE("Test Storage", "[storage]") { - lrc::Shape shape1({1, 2, 3, 4}); - REQUIRE(shape1.str() == "(1, 2, 3, 4)"); - - lrc::Shape zero = lrc::Shape::zeros(3); - REQUIRE(zero.str() == "(0, 0, 0)"); - - lrc::Shape ones = lrc::Shape::ones(3); - REQUIRE(ones.str() == "(1, 1, 1)"); - - REQUIRE(shape1.ndim() == 4); - REQUIRE(shape1[0] == 1); - REQUIRE(shape1[1] == 2); - REQUIRE(shape1[2] == 3); - REQUIRE(shape1[3] == 4); - - REQUIRE(shape1.size() == 24); - 
REQUIRE(zero.size() == 0); - REQUIRE(ones.size() == 1); - - REQUIRE(shape1 == shape1); - REQUIRE_FALSE(shape1 != shape1); - REQUIRE_FALSE(shape1 == zero); - REQUIRE(shape1 != zero); - - REQUIRE(ones == lrc::Shape({1, 1, 1})); - REQUIRE(zero == lrc::Shape({0, 0, 0})); - REQUIRE(lrc::Shape({1, 2, 3, 4}) == lrc::Shape({1, 2, 3, 4})); - REQUIRE(lrc::Shape({1, 2, 3, 4}) != lrc::Shape({1, 2, 3, 5})); - - REQUIRE(lrc::shapesMatch(lrc::Shape({1, 2, 3, 4}), lrc::Shape({1, 2, 3, 4}))); - REQUIRE_FALSE(lrc::shapesMatch(lrc::Shape({1, 2, 3, 4}), lrc::Shape({1, 2, 3, 5}))); - REQUIRE(lrc::shapesMatch( - lrc::Shape({1, 2, 3, 4}), lrc::Shape({1, 2, 3, 4}), lrc::Shape({1, 2, 3, 4}))); - - SECTION("Benchmarks") { - BENCHMARK("Shape::zeros(5)") { - auto shape = lrc::Shape::zeros(5); - return shape.size(); - }; - - BENCHMARK("Shape::ones(5)") { - auto shape = lrc::Shape::ones(5); - return shape.size(); - }; - - auto lhs = lrc::Shape::ones(128); - auto rhs = lrc::Shape::ones(128); - BENCHMARK("Equality") { return lhs == rhs; }; - } + lrc::Shape shape1({1, 2, 3, 4}); + REQUIRE(shape1.str() == "(1, 2, 3, 4)"); + + lrc::Shape zero = lrc::Shape::zeros(3); + REQUIRE(zero.str() == "(0, 0, 0)"); + + lrc::Shape ones = lrc::Shape::ones(3); + REQUIRE(ones.str() == "(1, 1, 1)"); + + REQUIRE(shape1.ndim() == 4); + REQUIRE(shape1[0] == 1); + REQUIRE(shape1[1] == 2); + REQUIRE(shape1[2] == 3); + REQUIRE(shape1[3] == 4); + + REQUIRE(shape1.size() == 24); + REQUIRE(zero.size() == 0); + REQUIRE(ones.size() == 1); + + REQUIRE(shape1 == shape1); + REQUIRE_FALSE(shape1 != shape1); + REQUIRE_FALSE(shape1 == zero); + REQUIRE(shape1 != zero); + + REQUIRE(ones == lrc::Shape({1, 1, 1})); + REQUIRE(zero == lrc::Shape({0, 0, 0})); + REQUIRE(lrc::Shape({1, 2, 3, 4}) == lrc::Shape({1, 2, 3, 4})); + REQUIRE(lrc::Shape({1, 2, 3, 4}) != lrc::Shape({1, 2, 3, 5})); + + REQUIRE(lrc::shapesMatch(lrc::Shape({1, 2, 3, 4}), lrc::Shape({1, 2, 3, 4}))); + REQUIRE_FALSE(lrc::shapesMatch(lrc::Shape({1, 2, 3, 4}), 
lrc::Shape({1, 2, 3, 5}))); + REQUIRE(lrc::shapesMatch( + lrc::Shape({1, 2, 3, 4}), lrc::Shape({1, 2, 3, 4}), lrc::Shape({1, 2, 3, 4}))); + + SECTION("Benchmarks") { + BENCHMARK("Shape::zeros(5)") { + auto shape = lrc::Shape::zeros(5); + return shape.size(); + }; + + BENCHMARK("Shape::ones(5)") { + auto shape = lrc::Shape::ones(5); + return shape.size(); + }; + + auto lhs = lrc::Shape::ones(128); + auto rhs = lrc::Shape::ones(128); + BENCHMARK("Equality") { return lhs == rhs; }; + } } diff --git a/test/test-storage.cpp b/test/test-storage.cpp index 8b77e8a7..f39166a2 100644 --- a/test/test-storage.cpp +++ b/test/test-storage.cpp @@ -6,210 +6,210 @@ namespace lrc = librapid; #define REGISTER_CASES(TYPE) \ - SECTION("Type: " STRINGIFY(TYPE)) { \ - using ScalarType = TYPE; \ - lrc::Storage storage(5); \ - \ - REQUIRE(storage.size() == 5); \ - \ - storage[0] = 1; \ - storage[1] = 10; \ - \ - REQUIRE(storage[0] == 1); \ - REQUIRE(storage[1] == 10); \ - \ - lrc::Storage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ - \ - REQUIRE(storage2.size() == 10); \ - REQUIRE(storage2[0] == 1); \ - REQUIRE(storage2[1] == 2); \ - REQUIRE(storage2[8] == 9); \ - REQUIRE(storage2[9] == 10); \ - \ - lrc::Storage storage3(100, 1); \ - \ - REQUIRE(storage3.size() == 100); \ - REQUIRE(storage3[0] == 1); \ - REQUIRE(storage3[1] == 1); \ - REQUIRE(storage3[98] == 1); \ - REQUIRE(storage3[99] == 1); \ - \ - auto storage4 = lrc::Storage(storage2); \ - \ - REQUIRE(storage4.size() == 10); \ - REQUIRE(storage4[0] == 1); \ - REQUIRE(storage4[1] == 2); \ - REQUIRE(storage4[8] == 9); \ - REQUIRE(storage4[9] == 10); \ - \ - /* storage4 = lrc::Storage(100); */ \ - /* REQUIRE(storage4.size() == 100); */ \ - /* storage4[0] = 1; */ \ - /* storage4[1] = 2; */ \ - /* storage4[98] = 99; */ \ - /* storage4[99] = 100; */ \ - /* REQUIRE(storage4[0] == 1); */ \ - /* REQUIRE(storage4[1] == 2); */ \ - /* REQUIRE(storage4[98] == 99); */ \ - /* REQUIRE(storage4[99] == 100); */ \ - \ - storage4 = storage3; \ - \ - 
REQUIRE(storage4.size() == 100); \ - REQUIRE(storage4[0] == 1); \ - REQUIRE(storage4[1] == 1); \ - REQUIRE(storage4[98] == 1); \ - REQUIRE(storage4[99] == 1); \ - \ - SECTION("Const Iterator") { \ - ScalarType i = 1; \ - for (const auto &val : storage2) { \ - REQUIRE(val == i); \ - i += 1; \ - } \ - } \ - \ - SECTION("Non-Const Iterator") { \ - ScalarType i = 1; \ - for (auto &val : storage2) { \ - REQUIRE(val == i); \ - i += 1; \ - } \ - } \ - \ - lrc::Storage storage6(20, 123); \ - REQUIRE(storage6.size() == 20); \ - storage6.resize(5); \ - REQUIRE(storage6.size() == 5); \ - REQUIRE(storage6[0] == 123); \ - REQUIRE(storage6[1] == 123); \ - REQUIRE(storage6[2] == 123); \ - REQUIRE(storage6[3] == 123); \ - REQUIRE(storage6[4] == 123); \ - \ - storage6.resize(10); \ - REQUIRE(storage6.size() == 10); \ - REQUIRE(storage6[0] == 123); \ - REQUIRE(storage6[1] == 123); \ - REQUIRE(storage6[2] == 123); \ - REQUIRE(storage6[3] == 123); \ - REQUIRE(storage6[4] == 123); \ - \ - storage6.resize(100, 0); \ - REQUIRE(storage6.size() == 100); \ - } + SECTION("Type: " STRINGIFY(TYPE)) { \ + using ScalarType = TYPE; \ + lrc::Storage storage(5); \ + \ + REQUIRE(storage.size() == 5); \ + \ + storage[0] = 1; \ + storage[1] = 10; \ + \ + REQUIRE(storage[0] == 1); \ + REQUIRE(storage[1] == 10); \ + \ + lrc::Storage storage2({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); \ + \ + REQUIRE(storage2.size() == 10); \ + REQUIRE(storage2[0] == 1); \ + REQUIRE(storage2[1] == 2); \ + REQUIRE(storage2[8] == 9); \ + REQUIRE(storage2[9] == 10); \ + \ + lrc::Storage storage3(100, 1); \ + \ + REQUIRE(storage3.size() == 100); \ + REQUIRE(storage3[0] == 1); \ + REQUIRE(storage3[1] == 1); \ + REQUIRE(storage3[98] == 1); \ + REQUIRE(storage3[99] == 1); \ + \ + auto storage4 = lrc::Storage(storage2); \ + \ + REQUIRE(storage4.size() == 10); \ + REQUIRE(storage4[0] == 1); \ + REQUIRE(storage4[1] == 2); \ + REQUIRE(storage4[8] == 9); \ + REQUIRE(storage4[9] == 10); \ + \ + /* storage4 = lrc::Storage(100); */ \ + /* 
REQUIRE(storage4.size() == 100); */ \ + /* storage4[0] = 1; */ \ + /* storage4[1] = 2; */ \ + /* storage4[98] = 99; */ \ + /* storage4[99] = 100; */ \ + /* REQUIRE(storage4[0] == 1); */ \ + /* REQUIRE(storage4[1] == 2); */ \ + /* REQUIRE(storage4[98] == 99); */ \ + /* REQUIRE(storage4[99] == 100); */ \ + \ + storage4 = storage3; \ + \ + REQUIRE(storage4.size() == 100); \ + REQUIRE(storage4[0] == 1); \ + REQUIRE(storage4[1] == 1); \ + REQUIRE(storage4[98] == 1); \ + REQUIRE(storage4[99] == 1); \ + \ + SECTION("Const Iterator") { \ + ScalarType i = 1; \ + for (const auto &val : storage2) { \ + REQUIRE(val == i); \ + i += 1; \ + } \ + } \ + \ + SECTION("Non-Const Iterator") { \ + ScalarType i = 1; \ + for (auto &val : storage2) { \ + REQUIRE(val == i); \ + i += 1; \ + } \ + } \ + \ + lrc::Storage storage6(20, 123); \ + REQUIRE(storage6.size() == 20); \ + storage6.resize(5); \ + REQUIRE(storage6.size() == 5); \ + REQUIRE(storage6[0] == 123); \ + REQUIRE(storage6[1] == 123); \ + REQUIRE(storage6[2] == 123); \ + REQUIRE(storage6[3] == 123); \ + REQUIRE(storage6[4] == 123); \ + \ + storage6.resize(10); \ + REQUIRE(storage6.size() == 10); \ + REQUIRE(storage6[0] == 123); \ + REQUIRE(storage6[1] == 123); \ + REQUIRE(storage6[2] == 123); \ + REQUIRE(storage6[3] == 123); \ + REQUIRE(storage6[4] == 123); \ + \ + storage6.resize(100, 0); \ + REQUIRE(storage6.size() == 100); \ + } #define BENCHMARK_CONSTRUCTORS(TYPE_, FILL_) \ - BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 10") { \ - lrc::Storage storage(10); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000") { \ - lrc::Storage storage(1000); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000000") { \ - lrc::Storage storage(1000000); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ - lrc::Storage storage(10, FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000 
FILLED") { \ - lrc::Storage storage(1000, FILL_); \ - return storage.size(); \ - }; \ - \ - BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000000 FILLED") { \ - lrc::Storage storage(1000000, FILL_); \ - return storage.size(); \ - } + BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 10") { \ + lrc::Storage storage(10); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000") { \ + lrc::Storage storage(1000); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000000") { \ + lrc::Storage storage(1000000); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 10 FILLED") { \ + lrc::Storage storage(10, FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000 FILLED") { \ + lrc::Storage storage(1000, FILL_); \ + return storage.size(); \ + }; \ + \ + BENCHMARK("Storage<" STRINGIFY(TYPE_) "> 1000000 FILLED") { \ + lrc::Storage storage(1000000, FILL_); \ + return storage.size(); \ + } TEST_CASE("Test Storage", "[storage]") { - SECTION("Trivially Constructible Storage") { - REGISTER_CASES(char); - REGISTER_CASES(unsigned char); - REGISTER_CASES(short); - REGISTER_CASES(unsigned short); - REGISTER_CASES(int); - REGISTER_CASES(unsigned int); - REGISTER_CASES(long); - REGISTER_CASES(unsigned long); - REGISTER_CASES(long long); - REGISTER_CASES(unsigned long long); - REGISTER_CASES(float); - REGISTER_CASES(double); - REGISTER_CASES(long double); - } + SECTION("Trivially Constructible Storage") { + REGISTER_CASES(char); + REGISTER_CASES(unsigned char); + REGISTER_CASES(short); + REGISTER_CASES(unsigned short); + REGISTER_CASES(int); + REGISTER_CASES(unsigned int); + REGISTER_CASES(long); + REGISTER_CASES(unsigned long); + REGISTER_CASES(long long); + REGISTER_CASES(unsigned long long); + REGISTER_CASES(float); + REGISTER_CASES(double); + REGISTER_CASES(long double); + } - SECTION("Non-Trivially Constructible Storage") { - // Can't use normal tests, so just test 
a few things - lrc::Storage storage(5); - REQUIRE(storage.size() == 5); - storage[0] = "Hello"; - storage[1] = "World"; - REQUIRE(storage[0] == "Hello"); - REQUIRE(storage[1] == "World"); + SECTION("Non-Trivially Constructible Storage") { + // Can't use normal tests, so just test a few things + lrc::Storage storage(5); + REQUIRE(storage.size() == 5); + storage[0] = "Hello"; + storage[1] = "World"; + REQUIRE(storage[0] == "Hello"); + REQUIRE(storage[1] == "World"); - lrc::Storage storage2({"Hello", "World"}); - REQUIRE(storage2.size() == 2); - REQUIRE(storage2[0] == "Hello"); - REQUIRE(storage2[1] == "World"); + lrc::Storage storage2({"Hello", "World"}); + REQUIRE(storage2.size() == 2); + REQUIRE(storage2[0] == "Hello"); + REQUIRE(storage2[1] == "World"); - lrc::Storage storage3(20, "Hello"); - REQUIRE(storage3.size() == 20); - REQUIRE(storage3[0] == "Hello"); - REQUIRE(storage3[1] == "Hello"); - REQUIRE(storage3[18] == "Hello"); - REQUIRE(storage3[19] == "Hello"); + lrc::Storage storage3(20, "Hello"); + REQUIRE(storage3.size() == 20); + REQUIRE(storage3[0] == "Hello"); + REQUIRE(storage3[1] == "Hello"); + REQUIRE(storage3[18] == "Hello"); + REQUIRE(storage3[19] == "Hello"); - auto storage4 = lrc::Storage>(10); - REQUIRE(storage4.size() == 10); - storage4[0].push_back("Hello"); - storage4[0].push_back("World"); - REQUIRE(storage4[0][0] == "Hello"); - REQUIRE(storage4[0][1] == "World"); + auto storage4 = lrc::Storage>(10); + REQUIRE(storage4.size() == 10); + storage4[0].push_back("Hello"); + storage4[0].push_back("World"); + REQUIRE(storage4[0][0] == "Hello"); + REQUIRE(storage4[0][1] == "World"); - struct Three { - int a; - int b; - int c; - }; + struct Three { + int a; + int b; + int c; + }; - auto storage5 = lrc::Storage(10); - REQUIRE(storage5.size() == 10); - storage5[0].a = 1; - storage5[0].b = 2; - storage5[0].c = 3; - storage5[1].a = 4; - storage5[1].b = 5; - storage5[1].c = 6; - REQUIRE(storage5[0].a == 1); - REQUIRE(storage5[0].b == 2); - 
REQUIRE(storage5[0].c == 3); - REQUIRE(storage5[1].a == 4); - REQUIRE(storage5[1].b == 5); - REQUIRE(storage5[1].c == 6); + auto storage5 = lrc::Storage(10); + REQUIRE(storage5.size() == 10); + storage5[0].a = 1; + storage5[0].b = 2; + storage5[0].c = 3; + storage5[1].a = 4; + storage5[1].b = 5; + storage5[1].c = 6; + REQUIRE(storage5[0].a == 1); + REQUIRE(storage5[0].b == 2); + REQUIRE(storage5[0].c == 3); + REQUIRE(storage5[1].a == 4); + REQUIRE(storage5[1].b == 5); + REQUIRE(storage5[1].c == 6); - auto storage6 = storage5; - REQUIRE(storage5[0].a == 1); - REQUIRE(storage5[0].b == 2); - REQUIRE(storage5[0].c == 3); - REQUIRE(storage5[1].a == 4); - REQUIRE(storage5[1].b == 5); - REQUIRE(storage5[1].c == 6); - } + auto storage6 = storage5; + REQUIRE(storage5[0].a == 1); + REQUIRE(storage5[0].b == 2); + REQUIRE(storage5[0].c == 3); + REQUIRE(storage5[1].a == 4); + REQUIRE(storage5[1].b == 5); + REQUIRE(storage5[1].c == 6); + } - SECTION("Benchmarks") { - BENCHMARK_CONSTRUCTORS(int, 123); - BENCHMARK_CONSTRUCTORS(double, 456); - BENCHMARK_CONSTRUCTORS(std::string, "Hello, World"); - BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); - BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); - } + SECTION("Benchmarks") { + BENCHMARK_CONSTRUCTORS(int, 123); + BENCHMARK_CONSTRUCTORS(double, 456); + BENCHMARK_CONSTRUCTORS(std::string, "Hello, World"); + BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); + BENCHMARK_CONSTRUCTORS(std::vector, {1 COMMA 2 COMMA 3 COMMA 4}); + } } diff --git a/test/test-vector.cpp b/test/test-vector.cpp index d6767a0c..e4e4eda9 100644 --- a/test/test-vector.cpp +++ b/test/test-vector.cpp @@ -9,8 +9,6 @@ namespace lrc = librapid; // If the results are within this tolerance, they are likely correct. 
constexpr double tolerance = 1e-3; #define VEC_TYPE lrc::GenericVector -#define SCALAR double +#define SCALAR double -TEST_CASE("Temporary") { - REQUIRE(1 == 1); -} +TEST_CASE("Temporary") { REQUIRE(1 == 1); } From 2f9f8c7d7adff027c595db8dd762c90e09c48669 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 22:11:18 -0700 Subject: [PATCH 20/29] Continue updating to C++23 --- .../include/librapid/array/arrayContainer.hpp | 57 +------ .../include/librapid/array/arrayTypeDef.hpp | 22 +-- librapid/include/librapid/array/arrayView.hpp | 54 +------ .../librapid/array/arrayViewString.hpp | 6 +- librapid/include/librapid/array/function.hpp | 18 +-- .../librapid/array/linalg/arrayMultiply.hpp | 24 ++- .../include/librapid/array/linalg/linalg.hpp | 3 +- .../librapid/array/linalg/transpose.hpp | 30 ++-- librapid/include/librapid/array/sizetype.hpp | 40 +++-- librapid/include/librapid/autodiff/dual.hpp | 102 ++++++------ .../include/librapid/cuda/cudaStorage.hpp | 27 +++- librapid/include/librapid/math/complex.hpp | 49 ++++-- librapid/include/librapid/math/multiprec.hpp | 145 ++++++++---------- .../include/librapid/opencl/openclStorage.hpp | 25 ++- librapid/include/librapid/utils/time.hpp | 36 ++++- test/test-array.cpp | 86 ++++++----- test/test-complex.cpp | 8 +- test/test-sizetype.cpp | 6 +- 18 files changed, 367 insertions(+), 371 deletions(-) diff --git a/librapid/include/librapid/array/arrayContainer.hpp b/librapid/include/librapid/array/arrayContainer.hpp index a6c3bb21..2fff747e 100644 --- a/librapid/include/librapid/array/arrayContainer.hpp +++ b/librapid/include/librapid/array/arrayContainer.hpp @@ -845,61 +845,8 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -template -struct fmt::formatter> { - using Type = librapid::array::ArrayContainer; - using Scalar = typename librapid::typetraits::TypeInfo::Scalar; - using Formatter = fmt::formatter; - Formatter m_formatter; - char m_bracket = 's'; - char m_separator = ' '; - - template - 
FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { - // Custom format options: - // - "~r" for round brackets - // - "~s" for square brackets - // - "~c" for curly brackets - // - "~a" for angle brackets - // - "~p" for pipe brackets - // - "-," for comma separator - // - "-;" for semicolon separator - // - "-:" for colon separator - // - "-|" for pipe separator - // - "-_" for underscore separator - - auto it = ctx.begin(), end = ctx.end(); - if (it != end && *it == '~') { - ++it; - if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) { - m_bracket = *it++; - } - } - - if (it != end && *it == '-') { - ++it; - if (it != end) { m_separator = *it++; } - } - - ctx.advance_to(it); - - return m_formatter.parse(ctx); - } - - template - FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { - val.str(m_formatter, m_bracket, m_separator, ctx); - return ctx.out(); - } -}; - -template -auto operator<<(std::ostream &os, - const librapid::array::ArrayContainer &object) - -> std::ostream & { - os << fmt::format("{}", object); - return os; -} +ARRAY_TYPE_FMT_IML(typename ShapeType_ COMMA typename StorageType_, + librapid::array::ArrayContainer) LIBRAPID_SIMPLE_IO_NORANGE(typename ShapeType_ COMMA typename StorageType_, librapid::array::ArrayContainer) diff --git a/librapid/include/librapid/array/arrayTypeDef.hpp b/librapid/include/librapid/array/arrayTypeDef.hpp index fd00bada..587e5a27 100644 --- a/librapid/include/librapid/array/arrayTypeDef.hpp +++ b/librapid/include/librapid/array/arrayTypeDef.hpp @@ -93,17 +93,17 @@ namespace librapid { \ template \ FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { \ - /* Custom format options: */ \ - /* - "~r" for round brackets */ \ - /* - "~s" for square brackets */ \ - /* - "~c" for curly brackets */ \ - /* - "~a" for angle brackets */ \ - /* - "~p" for pipe brackets */ \ - /* - "-," for comma separator */ \ - /* - "-;" for semicolon 
separator */ \ - /* - "-:" for colon separator */ \ - /* - "-|" for pipe separator */ \ - /* - "-_" for underscore separator */ \ + /* Custom format options: */ \ + /* - "~r" for round brackets */ \ + /* - "~s" for square brackets */ \ + /* - "~c" for curly brackets */ \ + /* - "~a" for angle brackets */ \ + /* - "~p" for pipe brackets */ \ + /* - "-," for comma separator */ \ + /* - "-;" for semicolon separator */ \ + /* - "-:" for colon separator */ \ + /* - "-|" for pipe separator */ \ + /* - "-_" for underscore separator */ \ \ auto it = ctx.begin(), end = ctx.end(); \ if (it != end && *it == '~') { \ diff --git a/librapid/include/librapid/array/arrayView.hpp b/librapid/include/librapid/array/arrayView.hpp index b95eb3a5..76895c63 100644 --- a/librapid/include/librapid/array/arrayView.hpp +++ b/librapid/include/librapid/array/arrayView.hpp @@ -34,11 +34,11 @@ namespace librapid { /// Copy an ArrayView object /// \param array The array to copy - explicit ArrayView(ArrayViewType &array); + ArrayView(ArrayViewType &array); /// Copy an ArrayView object (not const) /// \param array The array to copy - explicit ArrayView(ArrayViewType &&array) = delete; + ArrayView(ArrayViewType &&array); /// Copy an ArrayView object (const) /// \param other The array to copy @@ -148,6 +148,10 @@ namespace librapid { ArrayView::ArrayView(ArrayViewType &array) : m_ref(array), m_shape(array.shape()), m_stride(array.shape()) {} + template + ArrayView::ArrayView(ArrayViewType &&array) : + m_ref(array), m_shape(array.shape()), m_stride(array.shape()) {} + template ArrayView &ArrayView::operator=(const Scalar &scalar) { LIBRAPID_ASSERT(m_shape.ndim() == 0, "Cannot assign to a non-scalar ArrayView."); @@ -322,51 +326,7 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -template -struct fmt::formatter> { - using Type = librapid::array::ArrayView; - using Scalar = typename librapid::typetraits::TypeInfo::Scalar; - using Formatter = fmt::formatter; - Formatter m_formatter; - char 
m_bracket = 's'; - char m_separator = ' '; - - template - FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { - // Same formatting options as for the ArrayContainer type - - auto it = ctx.begin(), end = ctx.end(); - if (it != end && *it == '~') { - ++it; - if (it != end && (*it == 'r' || *it == 's' || *it == 'c' || *it == 'a' || *it == 'p')) { - m_bracket = *it++; - } - } - - if (it != end && *it == '-') { - ++it; - if (it != end) { m_separator = *it++; } - } - - ctx.advance_to(it); - - return m_formatter.parse(ctx); - } - - template - FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { - val.str(m_formatter, m_bracket, m_separator, ctx); - return ctx.out(); - } -}; - -template -auto operator<<(std::ostream &os, const librapid::array::ArrayView &object) - -> std::ostream & { - os << fmt::format("{}", object); - return os; -} - +ARRAY_TYPE_FMT_IML(typename T, librapid::array::ArrayView) LIBRAPID_SIMPLE_IO_NORANGE(typename T, librapid::array::ArrayView) #endif // FMT_API diff --git a/librapid/include/librapid/array/arrayViewString.hpp b/librapid/include/librapid/array/arrayViewString.hpp index 488ffbca..1da10ff4 100644 --- a/librapid/include/librapid/array/arrayViewString.hpp +++ b/librapid/include/librapid/array/arrayViewString.hpp @@ -59,7 +59,11 @@ namespace librapid { if (i > 0) fmt::format_to(ctx.out(), "{}", std::string(indent + 1, ' ')); arrayViewToString(view[i], formatter, bracket, separator, indent + 1, ctx); if (i != view.shape()[0] - 1) { - fmt::format_to(ctx.out(), "{}\n", separator); + if (separator == ' ') { + fmt::format_to(ctx.out(), "\n"); + } else { + fmt::format_to(ctx.out(), "{}\n", separator); + } if (view.ndim() > 2) { fmt::format_to(ctx.out(), "\n"); } } } diff --git a/librapid/include/librapid/array/function.hpp b/librapid/include/librapid/array/function.hpp index 41463e50..0b54344a 100644 --- a/librapid/include/librapid/array/function.hpp +++ b/librapid/include/librapid/array/function.hpp 
@@ -174,10 +174,9 @@ namespace librapid { LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator begin() const; LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Iterator end() const; - /// Return a string representation of the Function - /// \param format The format to use. - /// \return A string representation of the Function - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + template + void str(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const; private: /// Implementation detail -- evaluates the function at the given index, @@ -264,17 +263,18 @@ namespace librapid { } template - std::string Function::str(const std::string &format) const { - return eval().str(format); + template + void Function::str(const fmt::formatter &format, + char bracket, char separator, Ctx &ctx) const { + array::ArrayView(*this).str(format, bracket, separator, ctx); } } // namespace detail } // namespace librapid // Support FMT printing #ifdef FMT_API -LIBRAPID_SIMPLE_IO_IMPL(typename desc COMMA typename Functor COMMA typename... Args, - librapid::detail::Function) - +ARRAY_TYPE_FMT_IML(typename desc COMMA typename Functor COMMA typename... Args, + librapid::detail::Function) LIBRAPID_SIMPLE_IO_NORANGE(typename desc COMMA typename Functor COMMA typename... 
Args, librapid::detail::Function) #endif // FMT_API diff --git a/librapid/include/librapid/array/linalg/arrayMultiply.hpp b/librapid/include/librapid/array/linalg/arrayMultiply.hpp index 80aeb4a4..334a10b0 100644 --- a/librapid/include/librapid/array/linalg/arrayMultiply.hpp +++ b/librapid/include/librapid/array/linalg/arrayMultiply.hpp @@ -184,10 +184,9 @@ namespace librapid { template void applyTo(array::ArrayContainer &out) const; - /// \brief String representation of the array multiplication - /// \param format Format string for each element - /// \return String representation of the array multiplication - LIBRAPID_NODISCARD std::string str(const std::string &format) const; + template + void str(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const; private: bool m_transA; // Transpose state of A @@ -497,10 +496,10 @@ namespace librapid { template - std::string - ArrayMultiply::str( - const std::string &format) const { - return eval().str(format); + template + void ArrayMultiply::str( + const fmt::formatter &format, char bracket, char separator, Ctx &ctx) const { + eval().str(format, bracket, separator, ctx); } } // namespace linalg @@ -693,11 +692,10 @@ namespace librapid { } // namespace typetraits } // namespace librapid -LIBRAPID_SIMPLE_IO_IMPL( - typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA - typename StorageTypeB COMMA typename Alpha COMMA typename Beta, - librapid::linalg::ArrayMultiply< - ShapeTypeA COMMA StorageTypeA COMMA ShapeTypeB COMMA StorageTypeB COMMA Alpha COMMA Beta>) +ARRAY_TYPE_FMT_IML(typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA + typename StorageTypeB COMMA typename Alpha COMMA typename Beta, + librapid::linalg::ArrayMultiply) LIBRAPID_SIMPLE_IO_NORANGE( typename ShapeTypeA COMMA typename StorageTypeA COMMA typename ShapeTypeB COMMA diff --git a/librapid/include/librapid/array/linalg/linalg.hpp b/librapid/include/librapid/array/linalg/linalg.hpp 
index b57f2fd6..11b76e21 100644 --- a/librapid/include/librapid/array/linalg/linalg.hpp +++ b/librapid/include/librapid/array/linalg/linalg.hpp @@ -23,10 +23,11 @@ namespace librapid::typetraits { #include "transpose.hpp" +#include "level3/gemm.hpp" // Included before gemv, since gemm is used in some gemv implementations + #include "level2/gemv.hpp" #include "level3/geam.hpp" -#include "level3/gemm.hpp" #include "arrayMultiply.hpp" diff --git a/librapid/include/librapid/array/linalg/transpose.hpp b/librapid/include/librapid/array/linalg/transpose.hpp index 4d2dd7c4..c2fc9643 100644 --- a/librapid/include/librapid/array/linalg/transpose.hpp +++ b/librapid/include/librapid/array/linalg/transpose.hpp @@ -422,11 +422,11 @@ namespace librapid { } // namespace detail namespace array { - template + template class Transpose { public: - using ArrayType = T; - using BaseType = typename std::decay_t; + using ArrayType = TransposeType; + using BaseType = typename std::decay_t; using Scalar = typename typetraits::TypeInfo::Scalar; using Reference = BaseType &; using ConstReference = const BaseType &; @@ -446,7 +446,7 @@ namespace librapid { /// Create a Transpose object from an array/operation /// \param array The array to copy /// \param axes The transposition axes - Transpose(const T &array, const ShapeType &axes, Scalar alpha = Scalar(1.0)); + Transpose(const TransposeType &array, const ShapeType &axes, Scalar alpha = Scalar(1.0)); /// Copy a Transpose object Transpose(const Transpose &other) = default; @@ -457,7 +457,7 @@ namespace librapid { /// Assign another Transpose object to this one /// \param other The Transpose to assign /// \return *this; - Transpose &operator=(const Transpose &other) = default; + auto operator=(const Transpose &other) -> Transpose & = default; /// Access sub-array of this Transpose object /// \param index Array index @@ -505,7 +505,9 @@ namespace librapid { /// the given format string /// \param format Format string /// \return Stringified object 
- LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + template + LIBRAPID_ALWAYS_INLINE void str(const fmt::formatter &format, char bracket, + char separator, Ctx &ctx) const; private: ArrayType m_array; @@ -537,6 +539,12 @@ namespace librapid { return m_outputShape.ndim(); } + template + auto Transpose::scalar(int64_t index) const -> auto { + // TODO: This is a heinously inefficient way of doing this. Fix it. + return eval().scalar(index); + } + template auto Transpose::axes() const -> const ShapeType & { return m_axes; @@ -617,9 +625,11 @@ namespace librapid { return res; } - template - std::string Transpose::str(const std::string &format) const { - return eval().str(format); + template + template + void Transpose::str(const fmt::formatter &format, char bracket, char separator, + Ctx &ctx) const { + eval().str(format, bracket, separator, ctx); } }; // namespace array @@ -716,7 +726,7 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::array::Transpose) +ARRAY_TYPE_FMT_IML(typename T, librapid::array::Transpose) LIBRAPID_SIMPLE_IO_NORANGE(typename T, librapid::array::Transpose) #endif // FMT_API diff --git a/librapid/include/librapid/array/sizetype.hpp b/librapid/include/librapid/array/sizetype.hpp index 5f7f1ec7..e3a22b3b 100644 --- a/librapid/include/librapid/array/sizetype.hpp +++ b/librapid/include/librapid/array/sizetype.hpp @@ -135,9 +135,9 @@ namespace librapid { /// \return Number of elements LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T size() const; - /// Convert a Shape object into a string representation - /// \return A string representation of the Shape object - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const; + template + void str(const fmt::formatter &format, + Ctx &ctx) const; protected: T m_dims; @@ -290,14 +290,16 @@ namespace librapid { return res; } - template - std::string Shape::str(const std::string &format) const { - std::string 
result("("); + template + template + void Shape::str(const fmt::formatter &format, + Ctx &ctx) const { + fmt::format_to(ctx.out(), "Shape("); for (size_t i = 0; i < m_dims; ++i) { - result += fmt::format(format, m_data[i]); - if (i < m_dims - 1) result += std::string(", "); + format.format(m_data[i], ctx); + if (i != m_dims - 1) fmt::format_to(ctx.out(), ", "); } - return std::operator+(result, std::string(")")); + fmt::format_to(ctx.out(), ")"); } /// Returns true if all inputs have the same shape @@ -350,7 +352,25 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -LIBRAPID_SIMPLE_IO_IMPL(typename T COMMA size_t N, librapid::Shape) +template +struct fmt::formatter> { +private: + using Base = fmt::formatter; + Base m_base; + +public: + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + return m_base.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const librapid::Shape &val, FormatContext &ctx) const + -> decltype(ctx.out()) { + val.str(m_base, ctx); + return ctx.out(); + } +}; #endif // FMT_API #endif // LIBRAPID_ARRAY_SIZETYPE_HPP \ No newline at end of file diff --git a/librapid/include/librapid/autodiff/dual.hpp b/librapid/include/librapid/autodiff/dual.hpp index 1ef4a8cf..15287c3d 100644 --- a/librapid/include/librapid/autodiff/dual.hpp +++ b/librapid/include/librapid/autodiff/dual.hpp @@ -14,9 +14,6 @@ namespace librapid { template class Dual { public: - T value; - T derivative; - #if defined(LIBRAPID_IN_JITIFY) using Scalar = T; using Packet = T; @@ -27,9 +24,12 @@ namespace librapid { static constexpr uint64_t packetWidth = typetraits::TypeInfo::packetWidth; #endif + Scalar value; + Scalar derivative; + Dual() = default; - explicit Dual(T value) : value(value), derivative(T()) {} - Dual(T value, T derivative) : value(value), derivative(derivative) {} + explicit Dual(Scalar value) : value(value), derivative(Scalar()) {} + Dual(Scalar value, Scalar derivative) : value(value), derivative(derivative) {} template 
explicit Dual(const Dual &other) : value(other.value), derivative(other.derivative) {} @@ -39,107 +39,76 @@ namespace librapid { value(std::move(other.value)), derivative(std::move(other.derivative)) {} template - Dual &operator=(const Dual &other) { + auto operator=(const Dual &other) -> Dual & { value = other.value; derivative = other.derivative; return *this; } template - Dual &operator=(Dual &&other) { + auto operator=(Dual &&other) -> Dual & { value = std::move(other.value); derivative = std::move(other.derivative); return *this; } - static constexpr size_t size() { return typetraits::TypeInfo::packetWidth; } - - // template - // LIBRAPID_ALWAYS_INLINE void store(P *ptr) const { - // // Load the data into batches. - // auto casted = reinterpret_cast(ptr); - // - // // Compute interleaved values. - // std::array interleaved; - // for (std::size_t i = 0; i < packetWidth; ++i) { - // interleaved[2 * i] = value.get(i); - // interleaved[2 * i + 1] = derivative.get(i); - // } - // - // // Store the interleaved values back to memory. - // std::copy(interleaved.begin(), interleaved.end(), casted); - // } - - // template - // LIBRAPID_ALWAYS_INLINE void load(const P *ptr) { - // // auto casted = reinterpret_cast(ptr); - // // Vc::deinterleave(&value, &derivative, casted, Vc::Aligned); - // - // // Load the data into batches. - // auto casted = reinterpret_cast(ptr); - // - // // Compute interleaved values. - // std::array interleaved; - // std::copy(casted, casted + 2 * packetWidth, interleaved.begin()); - // - // // Store the interleaved values back to memory. 
- // for (std::size_t i = 0; i < packetWidth; ++i) { - // value.set(i, interleaved[2 * i]); - // derivative.set(i, interleaved[2 * i + 1]); - // } - // } - - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const Dual &other) { + static constexpr auto size() -> size_t { return typetraits::TypeInfo::packetWidth; } + + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+=(const Dual &other) -> Dual & { value += other.value; derivative += other.derivative; return *this; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator-=(const Dual &other) { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-=(const Dual &other) -> Dual & { value -= other.value; derivative -= other.derivative; return *this; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator*=(const Dual &other) { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*=(const Dual &other) -> Dual & { value *= other.value; derivative = derivative * other.value + value * other.derivative; return *this; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator/=(const Dual &other) { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/=(const Dual &other) -> Dual & { value /= other.value; derivative = (derivative * other.value - value * other.derivative) / (other.value * other.value); return *this; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator+=(const T &other) { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator+=(const T &other) -> Dual & { value += other; return *this; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator-=(const T &other) { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator-=(const T &other) -> Dual & { value -= other; return *this; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator*=(const T &other) { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator*=(const T &other) -> Dual & { value *= other; derivative *= other; return *this; } - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Dual operator/=(const T 
&other) { + LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto operator/=(const T &other) -> Dual & { value /= other; derivative /= other; return *this; } #if !defined(LIBRAPID_IN_JITIFY) - std::string str(const std::string &format = "{}") const { - return fmt::format( - "Dual({}, {})", fmt::format(format, value), fmt::format(format, derivative)); + template + void str(const fmt::formatter &format, Ctx &ctx) const { + fmt::format_to(ctx.out(), "Dual("); + format.format(value, ctx); + fmt::format_to(ctx.out(), ", "); + format.format(derivative, ctx); + fmt::format_to(ctx.out(), ")"); } #endif // !defined(LIBRAPID_IN_JITIFY) }; @@ -471,7 +440,28 @@ namespace jitify::reflection::detail { #endif // LIBRAPID_HAS_CUDA #ifdef FMT_API -LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::Dual) + +template +struct fmt::formatter, Char> { +private: + using Type = librapid::Dual; + using Scalar = typename Type::Scalar; + using Base = fmt::formatter; + Base m_base; + +public: + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + return m_base.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { + val.str(m_base, ctx); + return ctx.out(); + } +}; + #endif // FMT_API #endif // LIBRAPID_AUTODIFF_DUAL \ No newline at end of file diff --git a/librapid/include/librapid/cuda/cudaStorage.hpp b/librapid/include/librapid/cuda/cudaStorage.hpp index 70146bb5..72c2713a 100644 --- a/librapid/include/librapid/cuda/cudaStorage.hpp +++ b/librapid/include/librapid/cuda/cudaStorage.hpp @@ -120,8 +120,9 @@ namespace librapid { return static_cast(get()); } - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const { - return fmt::format(format, get()); + template + void str(const fmt::formatter &format, Ctx &ctx) const { + format.format(get(), ctx); } private: @@ -564,7 +565,27 @@ namespace librapid { } // namespace librapid # if defined(FMT_API) -LIBRAPID_SIMPLE_IO_IMPL(typename T, 
librapid::detail::CudaRef) +// LIBRAPID_SIMPLE_IO_IMPL(typename T, librapid::detail::CudaRef) + +template +struct fmt::formatter, Char> { +private: + using Base = fmt::formatter; + Base m_base; + +public: + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + return m_base.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const librapid::detail::CudaRef &val, FormatContext &ctx) const + -> decltype(ctx.out()) { + val.str(m_base, ctx); + return ctx.out(); + } +}; # endif // FM_API #else // Trait implementations diff --git a/librapid/include/librapid/math/complex.hpp b/librapid/include/librapid/math/complex.hpp index 49268b1b..e3d0d680 100644 --- a/librapid/include/librapid/math/complex.hpp +++ b/librapid/include/librapid/math/complex.hpp @@ -682,20 +682,20 @@ namespace librapid { return Complex(static_cast(m_val[RE]), static_cast(m_val[IM])); } - /// \brief Complex number to string - /// - /// Create a std::string representation of a complex number, formatting each component with - /// the format string - /// - /// \param format Format string - /// \return std::string - LIBRAPID_NODISCARD auto str(const std::string &format = "{}") const -> std::string { - if (!::librapid::signBit(m_val[IM])) - return "(" + fmt::format(format, m_val[RE]) + "+" + fmt::format(format, m_val[IM]) + - "j)"; - else - return "(" + fmt::format(format, m_val[RE]) + "-" + - fmt::format(format, -m_val[IM]) + "j)"; + template + void str(const fmt::formatter &format, Ctx &ctx) const { + // Complex numbers are printed as (a +- bi) + + fmt::format_to(ctx.out(), "("); + format.format(m_val[RE], ctx); + if (m_val[IM] < 0) { + fmt::format_to(ctx.out(), "-"); + format.format(-m_val[IM], ctx); + } else { + fmt::format_to(ctx.out(), "+"); + format.format(m_val[IM], ctx); + } + fmt::format_to(ctx.out(), "i)"); } protected: @@ -2079,7 +2079,26 @@ namespace librapid { // Support FMT printing #ifdef FMT_API -LIBRAPID_SIMPLE_IO_IMPL(typename Scalar, librapid::Complex) +template 
+struct fmt::formatter, Char> { +private: + using Type = librapid::Complex; + using Scalar = typename Type::Scalar; + using Base = fmt::formatter; + Base m_base; + +public: + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + return m_base.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const Type &val, FormatContext &ctx) const -> decltype(ctx.out()) { + val.str(m_base, ctx); + return ctx.out(); + } +}; #endif // FMT_API #ifdef USE_X86_X64_INTRINSICS diff --git a/librapid/include/librapid/math/multiprec.hpp b/librapid/include/librapid/math/multiprec.hpp index e51bf95e..fe6b14bd 100644 --- a/librapid/include/librapid/math/multiprec.hpp +++ b/librapid/include/librapid/math/multiprec.hpp @@ -654,136 +654,117 @@ namespace librapid { // Provide {fmt} printing capabilities # ifdef FMT_API -template<> -struct fmt::formatter { - detail::dynamic_format_specs specs_; +template +struct fmt::formatter { + detail::dynamic_format_specs specs_; template constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; + auto type = ::fmt::detail::type_constant::value; auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); return end; } template - inline auto format(const mpz_class &num, FormatContext &ctx) { - try { - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision < 0 ? 10 : specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } + inline auto format(const librapid::mpz &num, FormatContext &ctx) const { + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision < 0 ? 
10 : specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), "{}", ss.str()); } }; -template<> -struct fmt::formatter { - detail::dynamic_format_specs specs_; +template +struct fmt::formatter { + detail::dynamic_format_specs specs_; template constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; + auto type = ::fmt::detail::type_constant::value; auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); return end; } template - inline auto format(const mpf_class &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } + inline auto format(const librapid::mpf &num, FormatContext &ctx) const { + if (specs_.precision < 1) return fmt::format_to(ctx.out(), "{}", librapid::str(num)); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), "{}", ss.str()); } }; -template -struct fmt::formatter<__gmp_expr> { - detail::dynamic_format_specs specs_; +template +struct fmt::formatter<__gmp_expr, Char> { + detail::dynamic_format_specs specs_; template constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; + auto type = ::fmt::detail::type_constant::value; auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); return end; } template - inline auto format(const __gmp_expr &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) 
{ - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } + inline auto format(const __gmp_expr &num, FormatContext &ctx) const { + if (specs_.precision < 1) + return fmt::format_to(ctx.out(), "{}", librapid::str(librapid::mpf(num))); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << librapid::mpf(num); + return fmt::format_to(ctx.out(), "{}", ss.str()); } }; -template<> -struct fmt::formatter { - detail::dynamic_format_specs specs_; +template +struct fmt::formatter { + detail::dynamic_format_specs specs_; template - constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; - auto end = ::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + FMT_CONSTEXPR auto parse(ParseContext &ctx) { + auto type = detail::type_constant::value; + auto end = detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); return end; } template - inline auto format(const mpq_class &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } + auto format(const librapid::mpq &num, FormatContext &ctx) const { + if (specs_.precision < 1) return fmt::format_to(ctx.out(), "{}", librapid::str(num)); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << librapid::mpf(num); + return fmt::format_to(ctx.out(), "{}", ss.str()); } }; -template<> -struct fmt::formatter { - detail::dynamic_format_specs specs_; +template +struct fmt::formatter { + detail::dynamic_format_specs specs_; template - constexpr auto parse(ParseContext &ctx) { - auto type = ::fmt::detail::type_constant::value; - auto end = 
::fmt::detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + FMT_CONSTEXPR auto parse(ParseContext &ctx) { + auto type = detail::type_constant::value; + auto end = detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); return end; } template - inline auto format(const librapid::mpfr &num, FormatContext &ctx) { - try { - if (specs_.precision < 1) return fmt::format_to(ctx.out(), librapid::str(num)); - - std::stringstream ss; - ss << std::fixed; - ss.precision(specs_.precision); - ss << num; - return fmt::format_to(ctx.out(), ss.str()); - } catch (std::exception &e) { - return fmt::format_to(ctx.out(), fmt::format("Format Error: {}", e.what())); - } + auto format(const librapid::mpfr &num, FormatContext &ctx) const { + if (specs_.precision < 1) return fmt::format_to(ctx.out(), "{}", librapid::str(num)); + + std::stringstream ss; + ss << std::fixed; + ss.precision(specs_.precision); + ss << num; + return fmt::format_to(ctx.out(), "{}", ss.str()); } }; # endif // FMT_API diff --git a/librapid/include/librapid/opencl/openclStorage.hpp b/librapid/include/librapid/opencl/openclStorage.hpp index 7305766e..b334c8f4 100644 --- a/librapid/include/librapid/opencl/openclStorage.hpp +++ b/librapid/include/librapid/opencl/openclStorage.hpp @@ -101,8 +101,9 @@ namespace librapid { return static_cast(get()); } - LIBRAPID_NODISCARD std::string str(const std::string &format = "{}") const { - return fmt::format(format, get()); + template + void str(const fmt::formatter &format, Ctx &ctx) { + format.format(get(), ctx); } private: @@ -425,6 +426,26 @@ namespace librapid { } } // namespace librapid +template +struct fmt::formatter, Char> { +private: + using Base = fmt::formatter; + Base m_base; + +public: + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + return m_base.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const librapid::detail::OpenCLRef &val, FormatContext &ctx) const + -> decltype(ctx.out()) { + 
val.str(m_base, ctx); + return ctx.out(); + } +}; + #endif // LIBRAPID_HAS_OPENCL #endif // LIBRAPID_ARRAY_OPENCL_STORAGE_HPP \ No newline at end of file diff --git a/librapid/include/librapid/utils/time.hpp b/librapid/include/librapid/utils/time.hpp index 240cb5a7..cf9294e2 100644 --- a/librapid/include/librapid/utils/time.hpp +++ b/librapid/include/librapid/utils/time.hpp @@ -139,14 +139,19 @@ namespace librapid { } /// Print the current elapsed time of the timer - LIBRAPID_NODISCARD std::string str(const std::string &format = "{:.3f}") const { + template + void str(const fmt::formatter &formatter, Ctx &ctx) const { double tmpEnd = m_end; if (tmpEnd < 0) tmpEnd = now(); - return fmt::format( - "{}Elapsed: {} | Average: {}", - (m_name.empty() ? "" : m_name + ": "), - formatTime(tmpEnd - m_start, format), - formatTime((tmpEnd - m_start) / (double)m_iters, format)); + // return fmt::format( + // "{}Elapsed: {} | Average: {}", + // (m_name.empty() ? "" : m_name + ": "), + // formatTime(tmpEnd - m_start, format), + // formatTime((tmpEnd - m_start) / (double)m_iters, format)); + fmt::format_to(ctx.out(), "{}Elapsed: ", m_name.empty() ? 
"" : m_name + ": "); + formatter.format(tmpEnd - m_start, ctx); + fmt::format_to(ctx.out(), " | Average: "); + formatter.format((tmpEnd - m_start) / (double)m_iters, ctx); } private: @@ -159,6 +164,23 @@ namespace librapid { }; } // namespace librapid -LIBRAPID_SIMPLE_IO_IMPL_NO_TEMPLATE(librapid::Timer); +template +struct fmt::formatter { +public: + using Base = fmt::formatter; + Base m_base; + + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + return m_base.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const librapid::Timer &val, FormatContext &ctx) const + -> decltype(ctx.out()) { + val.str(m_base, ctx); + return ctx.out(); + } +}; #endif // LIBRAPID_UTILS_TIME_HPP \ No newline at end of file diff --git a/test/test-array.cpp b/test/test-array.cpp index cb64e078..c5d2a42d 100644 --- a/test/test-array.cpp +++ b/test/test-array.cpp @@ -76,27 +76,29 @@ using CUDA = lrc::backend::CUDA; * other dimensions */ \ auto testI = \ lrc::Array::fromData(InitList({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); \ - REQUIRE(testI.str() == fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ - SCALAR(1), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6), \ - SCALAR(7), \ - SCALAR(8))); \ + REQUIRE(fmt::format("{}", testI) == \ + fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ + SCALAR(1), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6), \ + SCALAR(7), \ + SCALAR(8))); \ \ auto testJ = \ lrc::Array::fromData(Vec({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); \ - REQUIRE(testJ.str() == fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ - SCALAR(1), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6), \ - SCALAR(7), \ - SCALAR(8))); \ + REQUIRE(fmt::format("{}", testJ) == \ + fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ + SCALAR(1), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6), \ + SCALAR(7), \ + 
SCALAR(8))); \ } #define TEST_INDEXING(SCALAR, BACKEND) \ @@ -108,14 +110,14 @@ using CUDA = lrc::backend::CUDA; std::string index2 = fmt::format("[{} {} {}]", SCALAR(7), SCALAR(8), SCALAR(9)); \ std::string index3 = fmt::format("[{} {} {}]", SCALAR(10), SCALAR(11), SCALAR(12)); \ std::string index4 = fmt::format("[{} {} {}]", SCALAR(13), SCALAR(14), SCALAR(15)); \ - REQUIRE(testA[0].str() == index0); \ - REQUIRE(testA[1].str() == index1); \ - REQUIRE(testA[2].str() == index2); \ - REQUIRE(testA[3].str() == index3); \ - REQUIRE(testA[4].str() == index4); \ - REQUIRE(testA[0][0].str() == fmt::format("{}", SCALAR(1))); \ - REQUIRE(testA[1][1].str() == fmt::format("{}", SCALAR(5))); \ - REQUIRE(testA[2][2].str() == fmt::format("{}", SCALAR(9))); \ + REQUIRE(fmt::format("{}", testA[0]) == index0); \ + REQUIRE(fmt::format("{}", testA[1]) == index1); \ + REQUIRE(fmt::format("{}", testA[2]) == index2); \ + REQUIRE(fmt::format("{}", testA[3]) == index3); \ + REQUIRE(fmt::format("{}", testA[4]) == index4); \ + REQUIRE(fmt::format("{}", testA[0][0]) == fmt::format("{}", SCALAR(1))); \ + REQUIRE(fmt::format("{}", testA[1][1]) == fmt::format("{}", SCALAR(5))); \ + REQUIRE(fmt::format("{}", testA[2][2]) == fmt::format("{}", SCALAR(9))); \ \ testA[1][2] = 123; \ \ @@ -143,29 +145,29 @@ using CUDA = lrc::backend::CUDA; lrc::Array testA(lrc::Array::ShapeType({2, 3})); \ testA << 1, 2, 3, 4, 5, 6; \ \ - REQUIRE(testA.str() == fmt::format("[[{} {} {}]\n [{} {} {}]]", \ - SCALAR(1), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6))); \ + REQUIRE(fmt::format("{}", testA) == fmt::format("[[{} {} {}]\n [{} {} {}]]", \ + SCALAR(1), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6))); \ \ lrc::Array testB(lrc::Array::ShapeType({2, 3})); \ testB << 10, 2, 3, 4, 5, 6; \ \ - REQUIRE(testB.str() == fmt::format("[[{} {} {}]\n [ {} {} {}]]", \ - SCALAR(10), \ - SCALAR(2), \ - SCALAR(3), \ - SCALAR(4), \ - SCALAR(5), \ - SCALAR(6))); \ + 
REQUIRE(fmt::format("{}", testB) == fmt::format("[[{} {} {}]\n [{} {} {}]]", \ + SCALAR(10), \ + SCALAR(2), \ + SCALAR(3), \ + SCALAR(4), \ + SCALAR(5), \ + SCALAR(6))); \ \ lrc::Array testC(lrc::Array::ShapeType({2, 2, 2})); \ testC << 100, 2, 3, 4, 5, 6, 7, 8; \ - REQUIRE(testC.str() == \ - fmt::format("[[[{} {}]\n [ {} {}]]\n\n [[ {} {}]\n [ {} {}]]]", \ + REQUIRE(fmt::format("{}", testC) == \ + fmt::format("[[[{} {}]\n [{} {}]]\n\n [[{} {}]\n [{} {}]]]", \ SCALAR(100), \ SCALAR(2), \ SCALAR(3), \ diff --git a/test/test-complex.cpp b/test/test-complex.cpp index d69eb7b5..d3f67e7e 100644 --- a/test/test-complex.cpp +++ b/test/test-complex.cpp @@ -99,10 +99,10 @@ static double tolerance = 1e-5; REQUIRE(lrc::Complex(z1) == lrc::Complex(1, 2)); \ REQUIRE(lrc::Complex(z2) == lrc::Complex(3, 4)); \ \ - REQUIRE(z1.str() == fmt::format("({}+{}j)", z1.real(), z1.imag())); \ - REQUIRE(z2.str() == fmt::format("({}+{}j)", z2.real(), z2.imag())); \ - REQUIRE((-z1).str() == fmt::format("(-{}-{}j)", z1.real(), z1.imag())); \ - REQUIRE((-z2).str() == fmt::format("(-{}-{}j)", z2.real(), z2.imag())); \ + REQUIRE(fmt::format("{}", z1) == fmt::format("({}+{}i)", z1.real(), z1.imag())); \ + REQUIRE(fmt::format("{}", z2) == fmt::format("({}+{}i)", z2.real(), z2.imag())); \ + REQUIRE(fmt::format("{}", (-z1)) == fmt::format("(-{}-{}i)", z1.real(), z1.imag())); \ + REQUIRE(fmt::format("{}", (-z2)) == fmt::format("(-{}-{}i)", z2.real(), z2.imag())); \ } \ \ SECTION("Out-of-Place Arithmetic") { \ diff --git a/test/test-sizetype.cpp b/test/test-sizetype.cpp index eb6730c1..2d4f5334 100644 --- a/test/test-sizetype.cpp +++ b/test/test-sizetype.cpp @@ -7,13 +7,13 @@ namespace lrc = librapid; TEST_CASE("Test Storage", "[storage]") { lrc::Shape shape1({1, 2, 3, 4}); - REQUIRE(shape1.str() == "(1, 2, 3, 4)"); + REQUIRE(fmt::format("{}", shape1) == "Shape(1, 2, 3, 4)"); lrc::Shape zero = lrc::Shape::zeros(3); - REQUIRE(zero.str() == "(0, 0, 0)"); + REQUIRE(fmt::format("{}", zero) == "Shape(0, 0, 
0)"); lrc::Shape ones = lrc::Shape::ones(3); - REQUIRE(ones.str() == "(1, 1, 1)"); + REQUIRE(fmt::format("{}", ones) == "Shape(1, 1, 1)"); REQUIRE(shape1.ndim() == 4); REQUIRE(shape1[0] == 1); From 6729a3930257831707942ac7a6fd194ea5fd3f9a Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 22:19:13 -0700 Subject: [PATCH 21/29] Require C++23 in tests --- test/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0f259734..b5148928 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,5 +1,7 @@ include("warnings") +set(CMAKE_CXX_STANDARD 23) + function(make_test name) add_executable(test-${name} test-${name}.cpp) target_link_libraries(test-${name} PRIVATE librapid) From 9e72bebd9d5288f351bec3a6e958278c7ae34cc7 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 22:27:22 -0700 Subject: [PATCH 22/29] C++23? --- test/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b5148928..cf381cd6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,9 +1,13 @@ include("warnings") -set(CMAKE_CXX_STANDARD 23) +# set(CMAKE_CXX_STANDARD 23) function(make_test name) add_executable(test-${name} test-${name}.cpp) + + # Set C++ standard to 23 + target_compile_features(test-${name} PRIVATE cxx_std_23) + target_link_libraries(test-${name} PRIVATE librapid) disable_all_warnings(test-${name}) # Disable warnings for tests since they test unrealistic scenarios From 956a34957820f09a3fd1c90166fe53ff072a1104 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 22:36:37 -0700 Subject: [PATCH 23/29] C++23 again? 
--- test/CMakeLists.txt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index cf381cd6..f7b8745e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,16 +1,13 @@ include("warnings") -# set(CMAKE_CXX_STANDARD 23) - function(make_test name) add_executable(test-${name} test-${name}.cpp) - - # Set C++ standard to 23 - target_compile_features(test-${name} PRIVATE cxx_std_23) - target_link_libraries(test-${name} PRIVATE librapid) disable_all_warnings(test-${name}) # Disable warnings for tests since they test unrealistic scenarios + target_compile_features(test-${name} PRIVATE cxx_std_23) + target_compile_options(librapid PUBLIC cxx_std_23) + message(STATUS "[ LIBRAPID ] Adding test ${name}") add_test(NAME ${name} COMMAND test-${name} -s --skip-benchmarks) From 03e045f9eb5ccb18c3462893e95479774b31e55b Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 23:00:30 -0700 Subject: [PATCH 24/29] Why does this not set the C++ version? 
--- test/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f7b8745e..db659b97 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -5,8 +5,10 @@ function(make_test name) target_link_libraries(test-${name} PRIVATE librapid) disable_all_warnings(test-${name}) # Disable warnings for tests since they test unrealistic scenarios - target_compile_features(test-${name} PRIVATE cxx_std_23) + target_compile_features(test-${name} PUBLIC cxx_std_23) + set_target_properties(test-${name} PROPERTIES CXX_STANDARD 23) target_compile_options(librapid PUBLIC cxx_std_23) + set_target_properties(librapid PROPERTIES CXX_STANDARD 23) message(STATUS "[ LIBRAPID ] Adding test ${name}") add_test(NAME ${name} From 3a020abb0e892dbac74c80d0f52e995cb1dbe7c8 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 23:07:36 -0700 Subject: [PATCH 25/29] Might have fixed it --- CMakeLists.txt | 5 +++++ test/CMakeLists.txt | 5 ----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1907147a..005722d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,6 +3,11 @@ project(librapid) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") cmake_policy(SET CMP0077 NEW) +if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + message(STATUS "[ LIBRAPID ] LibRapid is a top-level project. 
Using C++23") + set(CMAKE_CXX_STANDARD 23) +endif () + # LibRapid requires C++20 or later if (CMAKE_CXX_STANDARD LESS 20) message(FATAL_ERROR "LibRapid requires C++20 or later") diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index db659b97..0f259734 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -5,11 +5,6 @@ function(make_test name) target_link_libraries(test-${name} PRIVATE librapid) disable_all_warnings(test-${name}) # Disable warnings for tests since they test unrealistic scenarios - target_compile_features(test-${name} PUBLIC cxx_std_23) - set_target_properties(test-${name} PROPERTIES CXX_STANDARD 23) - target_compile_options(librapid PUBLIC cxx_std_23) - set_target_properties(librapid PROPERTIES CXX_STANDARD 23) - message(STATUS "[ LIBRAPID ] Adding test ${name}") add_test(NAME ${name} COMMAND test-${name} -s --skip-benchmarks) From 9c4c1e8825f81b5cebd3688b91fa20ca373d44c2 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 23:23:41 -0700 Subject: [PATCH 26/29] Template name change --- librapid/include/librapid/autodiff/dual.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/librapid/include/librapid/autodiff/dual.hpp b/librapid/include/librapid/autodiff/dual.hpp index 15287c3d..15b50484 100644 --- a/librapid/include/librapid/autodiff/dual.hpp +++ b/librapid/include/librapid/autodiff/dual.hpp @@ -102,8 +102,8 @@ namespace librapid { } #if !defined(LIBRAPID_IN_JITIFY) - template - void str(const fmt::formatter &format, Ctx &ctx) const { + template + void str(const fmt::formatter &format, Ctx &ctx) const { fmt::format_to(ctx.out(), "Dual("); format.format(value, ctx); fmt::format_to(ctx.out(), ", "); From adc4751638457be31a258afff48175b11cd78e4c Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 23:44:57 -0700 Subject: [PATCH 27/29] Another template name change --- librapid/include/librapid/array/sizetype.hpp | 8 +++---- librapid/include/librapid/utils/time.hpp | 25 
++++++++++++++------ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/librapid/include/librapid/array/sizetype.hpp b/librapid/include/librapid/array/sizetype.hpp index e3a22b3b..2c7bed88 100644 --- a/librapid/include/librapid/array/sizetype.hpp +++ b/librapid/include/librapid/array/sizetype.hpp @@ -135,9 +135,8 @@ namespace librapid { /// \return Number of elements LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE T size() const; - template - void str(const fmt::formatter &format, - Ctx &ctx) const; + template + void str(const fmt::formatter &format, Ctx &ctx) const; protected: T m_dims; @@ -292,8 +291,7 @@ namespace librapid { template template - void Shape::str(const fmt::formatter &format, - Ctx &ctx) const { + void Shape::str(const fmt::formatter &format, Ctx &ctx) const { fmt::format_to(ctx.out(), "Shape("); for (size_t i = 0; i < m_dims; ++i) { format.format(m_data[i], ctx); diff --git a/librapid/include/librapid/utils/time.hpp b/librapid/include/librapid/utils/time.hpp index cf9294e2..70ef99df 100644 --- a/librapid/include/librapid/utils/time.hpp +++ b/librapid/include/librapid/utils/time.hpp @@ -45,8 +45,15 @@ namespace librapid { while (now() - start < time - sleepOffset) {} } + namespace detail { + struct FormattedTime { + double time; + std::string unit; + }; + } // namespace detail + template - std::string formatTime(double time, const std::string &format = "{:.3f}") { + detail::FormattedTime formatTime(double time, const std::string &format = "{:.3f}") { double ns = time * scale; int numUnits = 8; @@ -68,10 +75,10 @@ namespace librapid { static double divisor[] = {1000, 1000, 1000, 60, 60, 24, 365, 1e300}; for (int i = 0; i < numUnits; ++i) { // if (ns < divisor[i]) return std::operator+(fmt::format(format, ns), prefix[i]); - if (ns < divisor[i]) return fmt::vformat(format, fmt::make_format_args(ns)) + prefix[i]; + if (ns < divisor[i]) return {ns, prefix[i]}; ns /= divisor[i]; } - return fmt::format("{}ns", time * ns); + return {ns, 
prefix[numUnits - 1]}; } /// A timer class that can be used to measure a multitude of things. @@ -148,10 +155,14 @@ namespace librapid { // (m_name.empty() ? "" : m_name + ": "), // formatTime(tmpEnd - m_start, format), // formatTime((tmpEnd - m_start) / (double)m_iters, format)); - fmt::format_to(ctx.out(), "{}Elapsed: ", m_name.empty() ? "" : m_name + ": "); - formatter.format(tmpEnd - m_start, ctx); - fmt::format_to(ctx.out(), " | Average: "); - formatter.format((tmpEnd - m_start) / (double)m_iters, ctx); + + auto [elapsed, elapsedUnit] = formatTime(tmpEnd - m_start); + auto [average, averageUnit] = formatTime((tmpEnd - m_start) / (double)m_iters); + fmt::format_to(ctx.out(), "{}Elapsed: ", m_name); + formatter.format(elapsed, ctx); + fmt::format_to(ctx.out(), "{} | Average: ", elapsedUnit); + formatter.format(average, ctx); + fmt::format_to(ctx.out(), "{}", averageUnit); } private: From 8d0922f4a87a5b1766c1b28cf8ff7bcf1364fe27 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sat, 12 Aug 2023 23:51:23 -0700 Subject: [PATCH 28/29] oops --- librapid/include/librapid/array/sizetype.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/librapid/include/librapid/array/sizetype.hpp b/librapid/include/librapid/array/sizetype.hpp index 2c7bed88..e78e3a63 100644 --- a/librapid/include/librapid/array/sizetype.hpp +++ b/librapid/include/librapid/array/sizetype.hpp @@ -289,9 +289,9 @@ namespace librapid { return res; } - template - template - void Shape::str(const fmt::formatter &format, Ctx &ctx) const { + template + template + void Shape::str(const fmt::formatter &format, Ctx &ctx) const { fmt::format_to(ctx.out(), "Shape("); for (size_t i = 0; i < m_dims; ++i) { format.format(m_data[i], ctx); From 2358af8774ba144af7759b2d7f69936acdf843f2 Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Sun, 13 Aug 2023 00:05:23 -0700 Subject: [PATCH 29/29] Implement fmt formatting for half precision types --- librapid/include/librapid/math/half.hpp | 29 
+++++++++++++++++++------ 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/librapid/include/librapid/math/half.hpp b/librapid/include/librapid/math/half.hpp index c5f836f3..57d4e577 100644 --- a/librapid/include/librapid/math/half.hpp +++ b/librapid/include/librapid/math/half.hpp @@ -516,8 +516,8 @@ namespace librapid { LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t data() const noexcept; LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE detail::float16_t &data() noexcept; - LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE std::string - str(const std::string &format = "{}") const; + template + void str(const fmt::formatter &formatter, Ctx &ctx) const; // static half infinity; // static half max; @@ -634,10 +634,9 @@ namespace librapid { return m_value; } - std::string half::str(const std::string &format) const { - // return fmt::vformat(format, fmt::make_wformat_args(detail::halfToFloat(m_value.m_bits))); - - return std::vformat(format, std::make_format_args(detail::halfToFloat(m_value.m_bits))); + template + void half::str(const fmt::formatter &formatter, Ctx &ctx) const { + formatter.format(static_cast(*this), ctx); } LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE half operator+(const half &lhs, @@ -759,6 +758,22 @@ namespace librapid { } // namespace typetraits } // namespace librapid -LIBRAPID_SIMPLE_IO_IMPL_NO_TEMPLATE(librapid::half); +template +struct fmt::formatter { +public: + using Base = fmt::formatter; + Base m_base; + + template + FMT_CONSTEXPR auto parse(ParseContext &ctx) -> const char * { + return m_base.parse(ctx); + } + + template + FMT_CONSTEXPR auto format(const librapid::half &h, FormatContext &ctx) -> decltype(ctx.out()) { + h.str(m_base, ctx); + return ctx.out(); + } +}; #endif // LIBRAPID_MATH_HALF_HPP \ No newline at end of file