Skip to content

Commit

Permalink
Improve NEON transpose performance
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed May 27, 2024
1 parent d970f35 commit 47fa169
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 28 deletions.
2 changes: 2 additions & 0 deletions librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,8 @@ namespace librapid {
}
} // namespace array

// template<typename ShapeType_, typename StorageType_>

namespace detail {
template<typename T>
struct IsArrayType {
Expand Down
61 changes: 56 additions & 5 deletions librapid/include/librapid/array/linalg/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,59 @@ namespace librapid {
_mm_storeu_pd(out + 1 * cols, _mm_mul_pd(tmp1Unpck, alphaVec));
}

# endif // LIBRAPID_MSVC
#endif // LIBRAPID_NATIVE_ARCH
# elif defined(LIBRAPID_NEON)
# define LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE 2
# define LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE 4

template<typename Alpha>
LIBRAPID_ALWAYS_INLINE void transposeFloatKernel(float *__restrict out,
float *__restrict in, Alpha alpha,
int64_t cols) {
float32x4_t r0, r1, r2, r3;
float32x4_t t0, t1, t2, t3;

r0 = vld1q_f32(&in[0 * cols]);
r1 = vld1q_f32(&in[1 * cols]);
r2 = vld1q_f32(&in[2 * cols]);
r3 = vld1q_f32(&in[3 * cols]);

t0 = vzip1q_f32(r0, r1);
t1 = vzip2q_f32(r0, r1);
t2 = vzip1q_f32(r2, r3);
t3 = vzip2q_f32(r2, r3);

r0 = vcombine_f32(vget_low_f32(t0), vget_low_f32(t2));
r1 = vcombine_f32(vget_high_f32(t0), vget_high_f32(t2));
r2 = vcombine_f32(vget_low_f32(t1), vget_low_f32(t3));
r3 = vcombine_f32(vget_high_f32(t1), vget_high_f32(t3));

float32x4_t alphaVec = vdupq_n_f32(alpha);

vst1q_f32(&out[0 * cols], vmulq_f32(r0, alphaVec));
vst1q_f32(&out[1 * cols], vmulq_f32(r1, alphaVec));
vst1q_f32(&out[2 * cols], vmulq_f32(r2, alphaVec));
vst1q_f32(&out[3 * cols], vmulq_f32(r3, alphaVec));
}

template<typename Alpha>
LIBRAPID_ALWAYS_INLINE void transposeDoubleKernel(double *__restrict out,
double *__restrict in, Alpha alpha,
int64_t cols) {
float64x2_t r0, r1;

r0 = vld1q_f64(&in[0 * cols]);
r1 = vld1q_f64(&in[1 * cols]);

float64x2_t t0 = vzip1q_f64(r0, r1);
float64x2_t t1 = vzip2q_f64(r0, r1);

float64x2_t alphaVec = vdupq_n_f64(alpha);

vst1q_f64(&out[0 * cols], vmulq_f64(t0, alphaVec));
vst1q_f64(&out[1 * cols], vmulq_f64(t1, alphaVec));
}
# endif
#endif // LIBRAPID_NATIVE_ARCH

// Ensure the kernel size is always defined, even if the above code doesn't define it
#ifndef LIBRAPID_F32_TRANSPOSE_KERNEL_SIZE
Expand Down Expand Up @@ -310,7 +361,7 @@ namespace librapid {
}
}
}
#endif // LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE > 0
#endif // LIBRAPID_F64_TRANSPOSE_KERNEL_SIZE > 0
} // namespace cpu

#if defined(LIBRAPID_HAS_OPENCL)
Expand Down Expand Up @@ -431,8 +482,8 @@ namespace librapid {
rows));
}
} // namespace cuda
#endif // LIBRAPID_HAS_CUDA
} // namespace detail
#endif // LIBRAPID_HAS_CUDA
} // namespace detail

namespace array {
template<typename TransposeType>
Expand Down
23 changes: 12 additions & 11 deletions librapid/include/librapid/core/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,17 @@
#endif

// Instruction sets
#define ARCH_AVX512_2 10
#define ARCH_AVX512 9
#define ARCH_AVX2 8
#define ARCH_AVX 7
#define ARCH_SSE4_2 6
#define ARCH_SSE4_1 5
#define ARCH_SSSE3 4
#define ARCH_SSE3 3
#define ARCH_SSE2 2
#define ARCH_SSE 1
#define ARCH_AVX512_2 11
#define ARCH_AVX512 10
#define ARCH_AVX2 9
#define ARCH_AVX 8
#define ARCH_SSE4_2 7
#define ARCH_SSE4_1 6
#define ARCH_SSSE3 5
#define ARCH_SSE3 4
#define ARCH_SSE2 3
#define ARCH_SSE 2
#define ARCH_NEON 1
#define ARCH_NONE 0

// Instruction set detection
Expand Down Expand Up @@ -485,4 +486,4 @@ namespace librapid::backend {
// Code to be run *before* main()
#include "preMain.hpp"

#endif // LIBRAPID_CORE_CONFIG_HPP
#endif // LIBRAPID_CORE_CONFIG_HPP
23 changes: 20 additions & 3 deletions librapid/include/librapid/core/log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,26 @@ namespace librapid::assert {
(int)signature.length(),
(int)strlen("ASSERTION FAILED"));

std::string formatted = fmt::format(
// std::string formatted = fmt::format(
// "[{0:-^{6}}]\n[File {1:>{7}}]\n[Function "
// "{2:>{8}}]\n[Line {3:>{9}}]\n[Condition "
// "{4:>{10}}]\n{5}\n",
// "ASSERTION FAILED",
// filename,
// signature,
// line,
// conditionString,
// formattedMessage,
// maxLen + 14,
// maxLen + 9,
// maxLen + 5,
// maxLen + 9,
// maxLen + 4);

// fmt::print(fmt::fg(fmt::color::red), formatted);
// fmt::vprint(fmt::fg(fmt::color::red), formatted);

fmt::print(fmt::fg(fmt::color::red),
"[{0:-^{6}}]\n[File {1:>{7}}]\n[Function "
"{2:>{8}}]\n[Line {3:>{9}}]\n[Condition "
"{4:>{10}}]\n{5}\n",
Expand All @@ -35,8 +54,6 @@ namespace librapid::assert {
maxLen + 5,
maxLen + 9,
maxLen + 4);

fmt::print(fmt::fg(fmt::color::red), formatted);
}

throw RaiseType(formattedMessage);
Expand Down
3 changes: 3 additions & 0 deletions librapid/src/global.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ namespace librapid {
// OpenBLAS threading
#if defined(LIBRAPID_BLAS_OPENBLAS)
openblas_set_num_threads((int)numThreads);

#if defined(_OPENMP)
omp_set_num_threads((int)numThreads);
#endif // _OPENMP
goto_set_num_threads((int)numThreads);

setOpenBLASThreadsEnv((int)numThreads);
Expand Down
9 changes: 0 additions & 9 deletions librapid/src/openclConfigure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,6 @@ __kernel void testAddition(__global const float *a, __global const float *b, __g
cl::Program program(context, sources);
err = program.build();

// if (err != CL_SUCCESS) {
// auto format = fmt::fg(fmt::color::red) | fmt::emphasis::bold;
// fmt::print(format,
// "Error compiling test program: {}\n",
// program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(device));
// fmt::print(format, "Error Code [{}]: {}\n", err, opencl::getOpenCLErrorString(err));
// return false;
// }

// Check the build status
cl_build_status buildStatus = program.getBuildInfo<CL_PROGRAM_BUILD_STATUS>(device);

Expand Down

0 comments on commit 47fa169

Please sign in to comment.