Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
Pencilcaseman committed Sep 8, 2023
1 parent 62d0cf1 commit b737125
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions librapid/include/librapid/array/linalg/arrayMultiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,14 @@ namespace librapid {
/// \return Class of the array multiplication
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE MatmulClass matmulClass() const;

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType calculateShape() const;

/// \brief Determine the shape of the result
/// \return Shape of the result
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const;

LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE size_t size() const;

/// \brief Determine the number of dimensions of the result
/// \return Number of dimensions of the result
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const;
Expand Down Expand Up @@ -173,6 +177,9 @@ namespace librapid {
ScalarA m_alpha; // Scaling factor for A
TypeB m_b; // Second array
ScalarB m_beta; // Scaling factor for B

ShapeType m_shape;
size_t m_size;
};

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
Expand All @@ -182,23 +189,26 @@ namespace librapid {
TypeB &&b, Beta beta) :
m_transA(transA),
m_transB(transB), m_a(std::forward<TypeA>(a)), m_alpha(static_cast<ScalarA>(alpha)),
m_b(std::forward<TypeB>(b)), m_beta(static_cast<ScalarB>(beta)) {}
m_b(std::forward<TypeB>(b)), m_beta(static_cast<ScalarB>(beta)),
m_shape(calculateShape()), m_size(m_shape.size()) {}

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
Beta>::ArrayMultiply(TypeA &&a, TypeB &&b) :
m_transA(false),
m_transB(false), m_a(std::forward<TypeA>(a)), m_alpha(1),
m_b(std::forward<TypeB>(b)), m_beta(0) {}
m_b(std::forward<TypeB>(b)), m_beta(0), m_shape(calculateShape()),
m_size(m_shape.size()) {}

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
Beta>::ArrayMultiply(bool transA, bool transB, TypeA &&a, TypeB &&b) :
m_transA(transA),
m_transB(transB), m_a(std::forward<TypeA>(a)), m_alpha(1),
m_b(std::forward<TypeB>(b)), m_beta(0) {}
m_b(std::forward<TypeB>(b)), m_beta(0), m_shape(calculateShape()),
m_size(m_shape.size()) {}

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
Expand Down Expand Up @@ -268,8 +278,8 @@ namespace librapid {

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::shape()
const -> ShapeType {
auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
Beta>::calculateShape() const -> ShapeType {
const auto &shapeA = m_a.shape();
const auto &shapeB = m_b.shape();
MatmulClass matmulClass = this->matmulClass();
Expand All @@ -294,6 +304,21 @@ namespace librapid {
return {1};
}

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::shape()
const -> ShapeType {
return m_shape;
}

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
auto
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::size() const
-> size_t {
return m_size;
}

template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
typename StorageTypeB, typename Alpha, typename Beta>
auto
Expand Down

0 comments on commit b737125

Please sign in to comment.