@@ -92,7 +92,7 @@ namespace librapid {
92
92
using Scalar = typename StorageType::Scalar;
93
93
using Packet = typename typetraits::TypeInfo<Scalar>::Packet;
94
94
using Backend = typename typetraits::TypeInfo<ArrayContainer>::Backend;
95
- using Iterator = detail::ArrayIterator<GeneralArrayView<ArrayContainer>>;
95
+ using Iterator = detail::ArrayIterator<GeneralArrayView<ArrayContainer, ShapeType >>;
96
96
97
97
using DirectSubscriptType = typename detail::SubscriptType<StorageType>::Direct;
98
98
using DirectRefSubscriptType = typename detail::SubscriptType<StorageType>::Ref;
@@ -132,12 +132,16 @@ namespace librapid {
132
132
133
133
// / Constructs an array container from a shape
134
134
// / \param shape The shape of the array container
135
- LIBRAPID_ALWAYS_INLINE explicit ArrayContainer (const ShapeType &shape);
135
+ LIBRAPID_ALWAYS_INLINE explicit ArrayContainer (const Shape &shape);
136
+ LIBRAPID_ALWAYS_INLINE explicit ArrayContainer (const MatrixShape &shape);
137
+ LIBRAPID_ALWAYS_INLINE explicit ArrayContainer (const VectorShape &shape);
136
138
137
139
// / Create an array container from a shape and a scalar value. The scalar value
138
140
// / represents the value the memory is initialized with. \param shape The shape of the
139
141
// / array container \param value The value to initialize the memory with
140
- LIBRAPID_ALWAYS_INLINE ArrayContainer (const ShapeType &shape, const Scalar &value);
142
+ LIBRAPID_ALWAYS_INLINE ArrayContainer (const Shape &shape, const Scalar &value);
143
+ LIBRAPID_ALWAYS_INLINE ArrayContainer (const MatrixShape &shape, const Scalar &value);
144
+ LIBRAPID_ALWAYS_INLINE ArrayContainer (const VectorShape &shape, const Scalar &value);
141
145
142
146
// / Allows for a fixed-size array to be constructed with a fill value
143
147
// / \param value The value to fill the array with
@@ -369,7 +373,7 @@ namespace librapid {
369
373
370
374
template <typename ShapeType_, typename StorageType_>
371
375
LIBRAPID_ALWAYS_INLINE
372
- ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const ShapeType &shape) :
376
+ ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const Shape &shape) :
373
377
m_shape (shape),
374
378
m_size(shape.size()), m_storage(m_size) {
375
379
static_assert (!typetraits::IsFixedStorage<StorageType_>::value,
@@ -380,7 +384,67 @@ namespace librapid {
380
384
381
385
template <typename ShapeType_, typename StorageType_>
382
386
LIBRAPID_ALWAYS_INLINE
383
- ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const ShapeType &shape,
387
+ ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const MatrixShape &shape) :
388
+ m_shape (shape),
389
+ m_size(shape.size()), m_storage(m_size) {
390
+ static_assert (!typetraits::IsFixedStorage<StorageType_>::value,
391
+ " For a compile-time-defined shape, "
392
+ " the storage type must be "
393
+ " a FixedStorage object" );
394
+ }
395
+
396
+ template <typename ShapeType_, typename StorageType_>
397
+ LIBRAPID_ALWAYS_INLINE
398
+ ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const VectorShape &shape) :
399
+ m_shape (shape),
400
+ m_size(shape.size()), m_storage(m_size) {
401
+ static_assert (!typetraits::IsFixedStorage<StorageType_>::value,
402
+ " For a compile-time-defined shape, "
403
+ " the storage type must be "
404
+ " a FixedStorage object" );
405
+ }
406
+
407
+ template <typename ShapeType_, typename StorageType_>
408
+ LIBRAPID_ALWAYS_INLINE
409
+ ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const Shape &shape,
410
+ const Scalar &value) :
411
+ m_shape (shape),
412
+ m_size(shape.size()), m_storage(m_size, value) {
413
+ static_assert (typetraits::IsStorage<StorageType_>::value ||
414
+ typetraits::IsOpenCLStorage<StorageType_>::value ||
415
+ typetraits::IsCudaStorage<StorageType_>::value,
416
+ " For a runtime-defined shape, "
417
+ " the storage type must be "
418
+ " either a Storage or a "
419
+ " CudaStorage object" );
420
+ static_assert (!typetraits::IsFixedStorage<StorageType_>::value,
421
+ " For a compile-time-defined shape, "
422
+ " the storage type must be "
423
+ " a FixedStorage object" );
424
+ }
425
+
426
+ template <typename ShapeType_, typename StorageType_>
427
+ LIBRAPID_ALWAYS_INLINE
428
+ ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const MatrixShape &shape,
429
+ const Scalar &value) :
430
+ m_shape (shape),
431
+ m_size(shape.size()), m_storage(m_size, value) {
432
+ static_assert (typetraits::IsStorage<StorageType_>::value ||
433
+ typetraits::IsOpenCLStorage<StorageType_>::value ||
434
+ typetraits::IsCudaStorage<StorageType_>::value,
435
+ " For a runtime-defined shape, "
436
+ " the storage type must be "
437
+ " either a Storage or a "
438
+ " CudaStorage object" );
439
+ static_assert (!typetraits::IsFixedStorage<StorageType_>::value,
440
+ " For a compile-time-defined shape, "
441
+ " the storage type must be "
442
+ " a FixedStorage object" );
443
+ }
444
+
445
+ template <typename ShapeType_, typename StorageType_>
446
+ LIBRAPID_ALWAYS_INLINE
447
+ ArrayContainer<ShapeType_, StorageType_>::ArrayContainer(const VectorShape &shape,
384
448
const Scalar &value) :
385
449
m_shape (shape),
386
450
m_size(shape.size()), m_storage(m_size, value) {
@@ -525,44 +589,42 @@ namespace librapid {
525
589
index ,
526
590
m_shape[0 ]);
527
591
528
- if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
529
- #if defined(LIBRAPID_HAS_OPENCL)
530
- ArrayContainer res;
531
- res.m_shape = m_shape.subshape (1 , ndim ());
532
- auto subSize = res.shape ().size ();
533
- int64_t storageSize = sizeof (typename StorageType_::Scalar);
534
- cl_buffer_region region {index * subSize * storageSize, subSize * storageSize};
535
- res.m_storage =
536
- StorageType_ (m_storage.data ().createSubBuffer (
537
- StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, ®ion),
538
- subSize,
539
- false );
540
- return res;
541
- #else
542
- LIBRAPID_ERROR (" OpenCL support not enabled" );
543
- #endif // LIBRAPID_HAS_OPENCL
544
- } else if constexpr (typetraits::IsCudaStorage<StorageType_>::value) {
545
- #if defined(LIBRAPID_HAS_CUDA)
546
- ArrayContainer res;
547
- res.m_shape = m_shape.subshape (1 , ndim ());
548
- auto subSize = res.shape ().size ();
549
- Scalar *begin = m_storage.begin ().get () + index * subSize;
550
- res.m_storage = StorageType_ (begin, subSize, false );
551
- return res;
552
- #else
553
- LIBRAPID_ERROR (" CUDA support not enabled" );
554
- #endif // LIBRAPID_HAS_CUDA
555
- } else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
556
- return GeneralArrayView (*this )[index ];
557
- } else {
558
- ArrayContainer res;
559
- res.m_shape = m_shape.subshape (1 , ndim ());
560
- auto subSize = res.shape ().size ();
561
- Scalar *begin = m_storage.begin () + index * subSize;
562
- Scalar *end = begin + subSize;
563
- res.m_storage = StorageType_ (begin, end, false );
564
- return res;
565
- }
592
+ return createGeneralArrayView (*this )[index ];
593
+
594
+ // if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
595
+ // #if defined(LIBRAPID_HAS_OPENCL)
596
+ // ArrayContainer res;
597
+ // res.m_shape = m_shape.subshape(1, ndim());
598
+ // auto subSize = res.shape().size();
599
+ // int64_t storageSize = sizeof(typename StorageType_::Scalar);
600
+ // cl_buffer_region region {index * subSize * storageSize, subSize *
601
+ // storageSize}; res.m_storage =
602
+ // StorageType_(m_storage.data().createSubBuffer(
603
+ // StorageType_::bufferFlags,
604
+ // CL_BUFFER_CREATE_TYPE_REGION, ®ion), subSize,
605
+ // false); return res; #else LIBRAPID_ERROR("OpenCL support
606
+ // not enabled"); #endif // LIBRAPID_HAS_OPENCL } else if constexpr
607
+ // (typetraits::IsCudaStorage<StorageType_>::value) { #if defined(LIBRAPID_HAS_CUDA)
608
+ // ArrayContainer res;
609
+ // res.m_shape = m_shape.subshape(1, ndim());
610
+ // auto subSize = res.shape().size();
611
+ // Scalar *begin = m_storage.begin().get() + index * subSize;
612
+ // res.m_storage = StorageType_(begin, subSize, false);
613
+ // return res;
614
+ // #else
615
+ // LIBRAPID_ERROR("CUDA support not enabled");
616
+ // #endif // LIBRAPID_HAS_CUDA
617
+ // } else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
618
+ // return GeneralArrayView(*this)[index];
619
+ // } else {
620
+ // ArrayContainer res;
621
+ // res.m_shape = m_shape.subshape(1, ndim());
622
+ // auto subSize = res.shape().size();
623
+ // Scalar *begin = m_storage.begin() + index * subSize;
624
+ // Scalar *end = begin + subSize;
625
+ // res.m_storage = StorageType_(begin, end, false);
626
+ // return res;
627
+ // }
566
628
}
567
629
568
630
template <typename ShapeType_, typename StorageType_>
@@ -574,44 +636,42 @@ namespace librapid {
574
636
index ,
575
637
m_shape[0 ]);
576
638
577
- if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
578
- #if defined(LIBRAPID_HAS_OPENCL)
579
- ArrayContainer res;
580
- res.m_shape = m_shape.subshape (1 , ndim ());
581
- auto subSize = res.shape ().size ();
582
- int64_t storageSize = sizeof (typename StorageType_::Scalar);
583
- cl_buffer_region region {index * subSize * storageSize, subSize * storageSize};
584
- res.m_storage =
585
- StorageType_ (m_storage.data ().createSubBuffer (
586
- StorageType_::bufferFlags, CL_BUFFER_CREATE_TYPE_REGION, ®ion),
587
- subSize,
588
- false );
589
- return res;
590
- #else
591
- LIBRAPID_ERROR (" OpenCL support not enabled" );
592
- #endif // LIBRAPID_HAS_OPENCL
593
- } else if constexpr (typetraits::IsCudaStorage<StorageType_>::value) {
594
- #if defined(LIBRAPID_HAS_CUDA)
595
- ArrayContainer res;
596
- res.m_shape = m_shape.subshape (1 , ndim ());
597
- auto subSize = res.shape ().size ();
598
- Scalar *begin = m_storage.begin ().get () + index * subSize;
599
- res.m_storage = StorageType_ (begin, subSize, false );
600
- return res;
601
- #else
602
- LIBRAPID_ERROR (" CUDA support not enabled" );
603
- #endif // LIBRAPID_HAS_CUDA
604
- } else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
605
- return GeneralArrayView (*this )[index ];
606
- } else {
607
- ArrayContainer res;
608
- res.m_shape = m_shape.subshape (1 , ndim ());
609
- auto subSize = res.shape ().size ();
610
- Scalar *begin = m_storage.begin () + index * subSize;
611
- Scalar *end = begin + subSize;
612
- res.m_storage = StorageType_ (begin, end, false );
613
- return res;
614
- }
639
+ return createGeneralArrayView (*this )[index ];
640
+
641
+ // if constexpr (typetraits::IsOpenCLStorage<StorageType_>::value) {
642
+ // #if defined(LIBRAPID_HAS_OPENCL)
643
+ // ArrayContainer res;
644
+ // res.m_shape = m_shape.subshape(1, ndim());
645
+ // auto subSize = res.shape().size();
646
+ // int64_t storageSize = sizeof(typename StorageType_::Scalar);
647
+ // cl_buffer_region region {index * subSize * storageSize, subSize *
648
+ // storageSize}; res.m_storage =
649
+ // StorageType_(m_storage.data().createSubBuffer(
650
+ // StorageType_::bufferFlags,
651
+ // CL_BUFFER_CREATE_TYPE_REGION, ®ion), subSize,
652
+ // false); return res; #else LIBRAPID_ERROR("OpenCL support
653
+ // not enabled"); #endif // LIBRAPID_HAS_OPENCL } else if constexpr
654
+ // (typetraits::IsCudaStorage<StorageType_>::value) { #if defined(LIBRAPID_HAS_CUDA)
655
+ // ArrayContainer res;
656
+ // res.m_shape = m_shape.subshape(1, ndim());
657
+ // auto subSize = res.shape().size();
658
+ // Scalar *begin = m_storage.begin().get() + index * subSize;
659
+ // res.m_storage = StorageType_(begin, subSize, false);
660
+ // return res;
661
+ // #else
662
+ // LIBRAPID_ERROR("CUDA support not enabled");
663
+ // #endif // LIBRAPID_HAS_CUDA
664
+ // } else if constexpr (typetraits::IsFixedStorage<StorageType_>::value) {
665
+ // return GeneralArrayView(*this)[index];
666
+ // } else {
667
+ // ArrayContainer res;
668
+ // res.m_shape = m_shape.subshape(1, ndim());
669
+ // auto subSize = res.shape().size();
670
+ // Scalar *begin = m_storage.begin() + index * subSize;
671
+ // Scalar *end = begin + subSize;
672
+ // res.m_storage = StorageType_(begin, end, false);
673
+ // return res;
674
+ // }
615
675
}
616
676
617
677
template <typename ShapeType_, typename StorageType_>
@@ -854,8 +914,8 @@ namespace librapid {
854
914
static constexpr bool val = false ;
855
915
};
856
916
857
- template <typename T>
858
- struct IsArrayType <ArrayRef<T>> {
917
+ template <typename T, typename V >
918
+ struct IsArrayType <ArrayRef<T, V >> {
859
919
static constexpr bool val = true ;
860
920
};
861
921
@@ -864,8 +924,8 @@ namespace librapid {
864
924
static constexpr bool val = true ;
865
925
};
866
926
867
- template <typename T>
868
- struct IsArrayType <array::GeneralArrayView<T>> {
927
+ template <typename T, typename S >
928
+ struct IsArrayType <array::GeneralArrayView<T, S >> {
869
929
static constexpr bool val = true ;
870
930
};
871
931
0 commit comments