Skip to content

Commit bcd8d0b

Browse files
committed
Make tests compile. Will get them passing in a later commit
1 parent aae9f32 commit bcd8d0b

19 files changed

+2686
-1755
lines changed

cmake/ArchDetect2.cmake

+29-2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ function(check_simd_capability FLAG_GNU FLAG_MSVC NAME TEST_SOURCE VAR)
5353
endif ()
5454
endfunction()
5555

56+
check_simd_capability("-mmmx" "" "MMX" "
57+
#include <mmintrin.h>
58+
int main() {
59+
__m64 a = _mm_set_pi32(-1, 2);
60+
__m64 result = _mm_abs_pi32(a);
61+
return 0;
62+
}" SIMD_MMX)
63+
5664
# Check SSE2 (not a valid flag for MSVC)
5765
check_simd_capability("-msse2" "" "SSE2" "
5866
#include <emmintrin.h>
@@ -99,6 +107,16 @@ int main() {
99107
return 0;
100108
}" SIMD_SSE4_2)
101109

110+
check_simd_capability("-mfma" "/arch:FMA" "FMA3" "
111+
#include <immintrin.h>
112+
int main() {
113+
__m256 a = _mm256_set_ps(-1.0f, 2.0f, -3.0f, 4.0f, -1.0f, 2.0f, -3.0f, 4.0f);
114+
__m256 b = _mm256_set_ps(1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f);
115+
__m256 c = _mm256_set_ps(1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f);
116+
__m256 result = _mm256_fmadd_ps(a, b, c);
117+
return 0;
118+
}" SIMD_FMA3)
119+
102120
# Check AVX
103121
check_simd_capability("-mavx" "/arch:AVX" "AVX" "
104122
#include <immintrin.h>
@@ -171,6 +189,15 @@ int main() {
171189
return 0;
172190
}" SIMD_AVX512PF)
173191

192+
check_simd_capability("-march=armv6" "" "ARMv6" "
193+
#include <arm_neon.h>
194+
int main() {
195+
int32x2_t a = vdup_n_s32(1);
196+
int32x2_t b = vdup_n_s32(2);
197+
int32x2_t result = vadd_s32(a, b);
198+
return 0;
199+
}" SIMD_ARMv6)
200+
174201
# ARM
175202
check_simd_capability("-march=armv7-a" "" "ARMv7" "
176203
#include <arm_neon.h>
@@ -238,6 +265,6 @@ int main() {
238265

239266
if (LIBRAPID_ARCH_FOUND)
240267
message(STATUS "[ LIBRAPID ] Architecture Flags: ${LIBRAPID_ARCH_FLAGS}")
241-
else()
268+
else ()
242269
message(STATUS "[ LIBRAPID ] Architecture Flags Not Found")
243-
endif()
270+
endif ()

examples/example-cuda.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ auto main() -> int {
3131
fmt::print("Vector: \n{}\n", vector);
3232
fmt::print("Matrix dot Vector^T:\n{}\n", lrc::dot(cudaArray, lrc::transpose(vector)));
3333
#else
34-
fmt::print("OpenCL not enabled in this build of librapid\n");
34+
fmt::print("CUDA not enabled in this build of librapid\n");
3535
fmt::print("Check the documentation for more information on enabling OpenCL\n");
3636
fmt::print("https://librapid.readthedocs.io/en/latest/cmakeIntegration.html#librapid-use-cuda\n");
3737
#endif // LIBRAPID_HAS_CUDA

librapid/include/librapid/array/assignOps.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ namespace librapid {
274274
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto
275275
openCLTupleEvaluatorImpl(const detail::Function<descriptor, Functor, Args...> &function) {
276276
array::ArrayContainer<
277-
decltype(function.shape()),
277+
typename std::decay_t<decltype(function.shape())>,
278278
OpenCLStorage<typename detail::Function<descriptor, Functor, Args...>::Scalar>>
279279
result(function.shape());
280280
assign(result, function);
@@ -367,7 +367,7 @@ namespace librapid {
367367
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto
368368
cudaTupleEvaluatorImpl(const detail::Function<descriptor, Functor, Args...> &function) {
369369
array::ArrayContainer<
370-
decltype(function.shape()),
370+
typename std::decay_t<decltype(function.shape())>,
371371
CudaStorage<typename detail::Function<descriptor, Functor, Args...>::Scalar>>
372372
result(function.shape());
373373
assign(result, function);

librapid/include/librapid/array/function.hpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ namespace librapid {
155155

156156
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto size() const -> size_t;
157157

158+
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto ndim() const -> size_t;
159+
158160
/// Return the shape of the Function's result
159161
/// \return The shape of the Function's result
160162
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto shape() const -> const ShapeType &;
@@ -229,6 +231,11 @@ namespace librapid {
229231
return m_size;
230232
}
231233

234+
template<typename desc, typename Functor, typename... Args>
235+
LIBRAPID_ALWAYS_INLINE auto Function<desc, Functor, Args...>::ndim() const -> size_t {
236+
return m_shape.ndim();
237+
}
238+
232239
template<typename desc, typename Functor, typename... Args>
233240
LIBRAPID_ALWAYS_INLINE auto &Function<desc, Functor, Args...>::args() const {
234241
return m_args;
@@ -237,7 +244,7 @@ namespace librapid {
237244
template<typename desc, typename Functor, typename... Args>
238245
LIBRAPID_ALWAYS_INLINE auto
239246
Function<desc, Functor, Args...>::operator[](int64_t index) const {
240-
return array::GeneralArrayView(*this)[index];
247+
return createGeneralArrayView(*this)[index];
241248
}
242249

243250
template<typename desc, typename Functor, typename... Args>

librapid/include/librapid/array/generalArrayView.hpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ namespace librapid {
77
struct TypeInfo<array::GeneralArrayView<T, S>> {
88
static constexpr detail::LibRapidType type = detail::LibRapidType::GeneralArrayView;
99
using Scalar = typename TypeInfo<std::decay_t<T>>::Scalar;
10-
using Backend = typename TypeInfo<std::decay_t<T>>::Backend;
10+
using Backend = typename TypeInfo<std::decay_t<T>>::Backend;
11+
using ArrayViewType = std::decay_t<T>;
12+
using ShapeType = typename TypeInfo<ArrayViewType>::ShapeType;
13+
using StorageType = typename TypeInfo<ArrayViewType>::StorageType;
1114
static constexpr bool allowVectorisation = false;
1215
};
1316

@@ -21,7 +24,7 @@ namespace librapid {
2124
}
2225

2326
template<typename ShapeType, typename T>
24-
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto createGeneralArrayView(T &&array) {
27+
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE auto createGeneralArrayViewShapeModifier(T &&array) {
2528
return array::GeneralArrayView<T, ShapeType>(std::forward<T>(array));
2629
}
2730

@@ -219,7 +222,7 @@ namespace librapid {
219222
"Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}",
220223
index,
221224
m_shape[0]);
222-
auto view = createGeneralArrayView<Shape>(m_ref);
225+
auto view = createGeneralArrayViewShapeModifier<Shape>(m_ref);
223226
const auto stride = Stride(m_shape);
224227
view.setShape(m_shape.subshape(1, ndim()));
225228
if (ndim() == 1)
@@ -238,7 +241,7 @@ namespace librapid {
238241
"Index {} out of bounds in ArrayContainer::operator[] with leading dimension={}",
239242
index,
240243
m_shape[0]);
241-
auto view = createGeneralArrayView<Shape>(m_ref);
244+
auto view = createGeneralArrayViewShapeModifier<Shape>(m_ref);
242245
const auto stride = Stride(m_shape);
243246
view.setShape(m_shape.subshape(1, ndim()));
244247
if (ndim() == 1)

librapid/include/librapid/array/linalg/transpose.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ namespace librapid {
77
struct TypeInfo<array::Transpose<T>> {
88
static constexpr detail::LibRapidType type = detail::LibRapidType::ArrayFunction;
99
using Scalar = typename TypeInfo<std::decay_t<T>>::Scalar;
10-
using Backend = typename TypeInfo<std::decay_t<T>>::Backend;
10+
using Backend = typename TypeInfo<std::decay_t<T>>::Backend;
11+
using ShapeType = typename TypeInfo<std::decay_t<T>>::ShapeType;
12+
using StorageType = typename TypeInfo<std::decay_t<T>>::StorageType;
1113
static constexpr bool allowVectorisation = false;
1214
};
1315

librapid/include/librapid/array/shape.hpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -807,17 +807,25 @@ namespace librapid {
807807
using Type = VectorShape;
808808
};
809809

810-
template<typename First, typename Second, typename... Rest>
811-
struct ShapeTypeHelper {
812-
using FirstResult = typename ShapeTypeHelperImpl<First, Second>::Type;
813-
using Type = typename ShapeTypeHelper<FirstResult, Rest...>::Type;
810+
template<typename... Args>
811+
struct ShapeTypeHelper;
812+
813+
template<typename First>
814+
struct ShapeTypeHelper<First> {
815+
using Type = First;
814816
};
815817

816818
template<typename First, typename Second>
817819
struct ShapeTypeHelper<First, Second> {
818820
using Type = typename ShapeTypeHelperImpl<First, Second>::Type;
819821
};
820822

823+
template<typename First, typename Second, typename... Rest>
824+
struct ShapeTypeHelper<First, Second, Rest...> {
825+
using FirstResult = typename ShapeTypeHelperImpl<First, Second>::Type;
826+
using Type = typename ShapeTypeHelper<FirstResult, Rest...>::Type;
827+
};
828+
821829
template<typename T>
822830
struct SubscriptShapeType {
823831
using Type = Shape;

librapid/include/librapid/array/storage.hpp

+17-12
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ namespace librapid {
148148
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer begin() noexcept;
149149
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE Pointer end() noexcept;
150150

151-
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator begin() const noexcept;
152-
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator end() const noexcept;
151+
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstPointer begin() const noexcept;
152+
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstPointer end() const noexcept;
153153

154154
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cbegin() const noexcept;
155155
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ConstIterator cend() const noexcept;
@@ -391,17 +391,22 @@ namespace librapid {
391391
return ptr_;
392392
}
393393

394-
template<typename T>
395-
void fastCopy(T *__restrict dst, const T *__restrict src, size_t size) {
394+
template<typename T, typename V>
395+
void fastCopy(T *__restrict dst, const V *__restrict src, size_t size) {
396396
LIBRAPID_ASSUME(size > 0);
397397
LIBRAPID_ASSUME(dst != nullptr);
398398
LIBRAPID_ASSUME(src != nullptr);
399399

400-
if (typetraits::TriviallyDefaultConstructible<T>::value) {
401-
// Use a slightly faster memcpy if the type is trivially default constructible
402-
std::uninitialized_copy(src, src + size, dst);
400+
if constexpr (std::is_same_v<T, V>) {
401+
if constexpr (typetraits::TriviallyDefaultConstructible<T>::value) {
402+
// Use a slightly faster memcpy if the type is trivially default constructible
403+
std::uninitialized_copy(src, src + size, dst);
404+
} else {
405+
// Otherwise, use the standard copy algorithm
406+
std::copy(src, src + size, dst);
407+
}
403408
} else {
404-
// Otherwise, use the standard copy algorithm
409+
// Cannot use memcpy if the types are different
405410
std::copy(src, src + size, dst);
406411
}
407412
}
@@ -441,14 +446,14 @@ namespace librapid {
441446
template<typename V>
442447
Storage<T>::Storage(const std::initializer_list<V> &list) :
443448
m_begin(nullptr), m_size(0), m_ownsData(true) {
444-
initData(list.begin(), list.end());
449+
initData(static_cast<const V *>(list.begin()), static_cast<const V *>(list.end()));
445450
}
446451

447452
template<typename T>
448453
template<typename V>
449454
Storage<T>::Storage(const std::vector<V> &vector) :
450455
m_begin(nullptr), m_size(0), m_ownsData(true) {
451-
initData(vector.begin(), vector.end());
456+
initData(static_cast<const V *>(vector.data()), vector.size());
452457
}
453458

454459
template<typename T>
@@ -620,12 +625,12 @@ namespace librapid {
620625
}
621626

622627
template<typename T>
623-
auto Storage<T>::begin() const noexcept -> ConstIterator {
628+
auto Storage<T>::begin() const noexcept -> ConstPointer {
624629
return m_begin;
625630
}
626631

627632
template<typename T>
628-
auto Storage<T>::end() const noexcept -> ConstIterator {
633+
auto Storage<T>::end() const noexcept -> ConstPointer {
629634
return m_begin + m_size;
630635
}
631636

0 commit comments

Comments
 (0)