Skip to content

Commit

Permalink
Remove ref-counted arrays. Slowed things down
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Aug 16, 2023
1 parent 5d79348 commit ebda6ce
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 226 deletions.
3 changes: 2 additions & 1 deletion librapid/include/librapid/array/arrayContainer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,8 @@ namespace librapid {
auto subSize = res.shape().size();
Scalar *begin = m_storage.begin() + index * subSize;
Scalar *end = begin + subSize;
res.m_storage.set(StorageType_(begin, end, false));
// res.m_storage.set(StorageType_(begin, end, false));
res.m_storage = StorageType_(begin, end); // TODO: Replace with optimised array view
return res;
}
}
Expand Down
7 changes: 5 additions & 2 deletions librapid/include/librapid/array/arrayFromData.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ namespace librapid {
// return res; \
// }

// TODO: This recalculates the size of the shape at every iteration. Create a new function
// to allow this to be done on request

#define HIGHER_DIMENSIONAL_FROM_DATA(TYPE) \
template<typename Scalar, typename Backend> \
auto array::ArrayContainer<Scalar, Backend>::fromData(const TYPE &data) -> ArrayContainer { \
Expand All @@ -96,8 +99,8 @@ namespace librapid {
LIBRAPID_ASSERT(tmp[i].shape().operator==(zeroShape), \
"Arrays must have consistent shapes"); \
auto newShape = ShapeType::zeros(zeroShape.ndim() + 1); \
newShape[0] = data.size(); \
for (size_t i = 0; i < zeroShape.ndim(); ++i) { newShape[i + 1] = zeroShape[i]; } \
newShape.setAt(0, data.size()); \
for (size_t i = 0; i < zeroShape.ndim(); ++i) { newShape.setAt(i + 1, zeroShape[i]); } \
auto res = Array<Scalar, Backend>(newShape); \
for (int64_t i = 0; i < data.size(); ++i) res[i] = tmp[i]; \
return res; \
Expand Down
6 changes: 3 additions & 3 deletions librapid/include/librapid/array/arrayView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,11 @@ namespace librapid {
auto ArrayView<T>::scalar(int64_t index) const -> auto {
if (ndim() == 0) return m_ref.scalar(m_offset);

ShapeType tmp = ShapeType::zeros(ndim());
tmp[ndim() - 1] = index % m_shape[ndim() - 1];
ShapeType tmp = ShapeType::zeros(ndim());
tmp.setAt(ndim() - 1, index % m_shape[ndim() - 1]);
for (int64_t i = ndim() - 2; i >= 0; --i) {
index /= m_shape[i + 1];
tmp[i] = index % m_shape[i];
tmp.setAt(i, index % m_shape[i]);
}
int64_t offset = 0;
for (int64_t i = 0; i < ndim(); ++i) { offset += tmp[i] * m_stride[i]; }
Expand Down
9 changes: 4 additions & 5 deletions librapid/include/librapid/array/function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ namespace librapid {
/// Constructs a Function from a functor and arguments.
/// \param functor The functor to use.
/// \param args The arguments to use.
LIBRAPID_ALWAYS_INLINE explicit Function(const Functor &functor, const Args &...args);
LIBRAPID_ALWAYS_INLINE explicit Function(Functor &&functor, Args &&...args);

/// Constructs a Function from another function.
/// \param other The Function to copy.
Expand Down Expand Up @@ -202,8 +202,8 @@ namespace librapid {
};

template<typename desc, typename Functor, typename... Args>
Function<desc, Functor, Args...>::Function(const Functor &functor, const Args &...args) :
m_functor(functor), m_args(args...) {}
Function<desc, Functor, Args...>::Function(Functor &&functor, Args &&...args) :
m_functor(std::forward<Functor>(functor)), m_args(std::forward<Args>(args)...) {}

template<typename desc, typename Functor, typename... Args>
auto Function<desc, Functor, Args...>::shape() const {
Expand All @@ -228,8 +228,7 @@ namespace librapid {
}

template<typename desc, typename Functor, typename... Args>
typename Function<desc, Functor, Args...>::Packet
Function<desc, Functor, Args...>::packet(size_t index) const {
auto Function<desc, Functor, Args...>::packet(size_t index) const -> Packet {
return packetImpl(std::make_index_sequence<sizeof...(Args)>(), index);
}

Expand Down
21 changes: 11 additions & 10 deletions librapid/include/librapid/array/linalg/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ namespace librapid {
/// Create a Transpose object from an array/operation
/// \param array The array to copy
/// \param axes The transposition axes
Transpose(const TransposeType &array, const ShapeType &axes, Scalar alpha = Scalar(1.0));
Transpose(ArrayType &&array, const ShapeType &axes,
Scalar alpha = Scalar(1.0));

/// Copy a Transpose object
Transpose(const Transpose &other) = default;
Expand Down Expand Up @@ -517,15 +518,15 @@ namespace librapid {
Scalar m_alpha;
};

template<typename T>
Transpose<T>::Transpose(const T &array, const ShapeType &axes, Scalar alpha) :
template<typename TransposeType>
Transpose<TransposeType>::Transpose(ArrayType &&array, const ShapeType &axes, Scalar alpha) :
m_array(array), m_inputShape(array.shape()), m_axes(axes), m_alpha(alpha) {
LIBRAPID_ASSERT(m_inputShape.ndim() == m_axes.ndim(),
"Shape and axes must have the same number of dimensions");

m_outputShape = m_inputShape;
for (size_t i = 0; i < m_inputShape.ndim(); i++) {
m_outputShape[i] = m_inputShape[m_axes[i]];
m_outputShape.setAt(i, m_inputShape[m_axes[i]]);
}
}

Expand Down Expand Up @@ -619,32 +620,32 @@ namespace librapid {

template<typename T>
auto Transpose<T>::eval() const {
using NonConstArrayType = std::remove_const_t<ArrayType>;
using NonConstArrayType = std::remove_const_t<BaseType>;
NonConstArrayType res(m_outputShape);
applyTo(res);
return res;
}

template<typename TransposeType>
template<typename T, typename Char, typename Ctx>
void Transpose<TransposeType>::str(const fmt::formatter<T, Char> &format, char bracket, char separator,
Ctx &ctx) const {
void Transpose<TransposeType>::str(const fmt::formatter<T, Char> &format, char bracket,
char separator, Ctx &ctx) const {
eval().str(format, bracket, separator, ctx);
}
}; // namespace array

template<typename T, typename ShapeType = Shape<size_t, 32>,
typename std::enable_if_t<
typetraits::TypeInfo<T>::type == detail::LibRapidType::ArrayContainer, int> = 0>
typetraits::TypeInfo<std::decay_t<T>>::type == detail::LibRapidType::ArrayContainer, int> = 0>
auto transpose(T &&array, const ShapeType &axes = ShapeType()) {
// If axes is empty, transpose the array in reverse order
ShapeType newAxes = axes;
if (axes.ndim() == 0) {
newAxes = ShapeType::zeros(array.ndim());
for (size_t i = 0; i < array.ndim(); i++) { newAxes[i] = array.ndim() - i - 1; }
for (size_t i = 0; i < array.ndim(); i++) { newAxes.setAt(i, array.ndim() - i - 1); }
}

return array::Transpose(array, newAxes);
return array::Transpose<T>(std::forward<T>(array), newAxes, 1);
}

template<typename T, typename ShapeType = Shape<size_t, 32>,
Expand Down
Loading

0 comments on commit ebda6ce

Please sign in to comment.