Skip to content

Commit b737125

Browse files
committed
c
1 parent 62d0cf1 commit b737125

File tree

1 file changed

+30
-5
lines changed

1 file changed

+30
-5
lines changed

librapid/include/librapid/array/linalg/arrayMultiply.hpp

+30-5
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,14 @@ namespace librapid {
106106
/// \return Class of the array multiplication
107107
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE MatmulClass matmulClass() const;
108108

109+
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType calculateShape() const;
110+
109111
/// \brief Determine the shape of the result
110112
/// \return Shape of the result
111113
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE ShapeType shape() const;
112114

115+
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE size_t size() const;
116+
113117
/// \brief Determine the number of dimensions of the result
114118
/// \return Number of dimensions of the result
115119
LIBRAPID_NODISCARD LIBRAPID_ALWAYS_INLINE int64_t ndim() const;
@@ -173,6 +177,9 @@ namespace librapid {
173177
ScalarA m_alpha; // Scaling factor for A
174178
TypeB m_b; // Second array
175179
ScalarB m_beta; // Scaling factor for B
180+
181+
ShapeType m_shape;
182+
size_t m_size;
176183
};
177184

178185
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
@@ -182,23 +189,26 @@ namespace librapid {
182189
TypeB &&b, Beta beta) :
183190
m_transA(transA),
184191
m_transB(transB), m_a(std::forward<TypeA>(a)), m_alpha(static_cast<ScalarA>(alpha)),
185-
m_b(std::forward<TypeB>(b)), m_beta(static_cast<ScalarB>(beta)) {}
192+
m_b(std::forward<TypeB>(b)), m_beta(static_cast<ScalarB>(beta)),
193+
m_shape(calculateShape()), m_size(m_shape.size()) {}
186194

187195
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
188196
typename StorageTypeB, typename Alpha, typename Beta>
189197
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
190198
Beta>::ArrayMultiply(TypeA &&a, TypeB &&b) :
191199
m_transA(false),
192200
m_transB(false), m_a(std::forward<TypeA>(a)), m_alpha(1),
193-
m_b(std::forward<TypeB>(b)), m_beta(0) {}
201+
m_b(std::forward<TypeB>(b)), m_beta(0), m_shape(calculateShape()),
202+
m_size(m_shape.size()) {}
194203

195204
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
196205
typename StorageTypeB, typename Alpha, typename Beta>
197206
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
198207
Beta>::ArrayMultiply(bool transA, bool transB, TypeA &&a, TypeB &&b) :
199208
m_transA(transA),
200209
m_transB(transB), m_a(std::forward<TypeA>(a)), m_alpha(1),
201-
m_b(std::forward<TypeB>(b)), m_beta(0) {}
210+
m_b(std::forward<TypeB>(b)), m_beta(0), m_shape(calculateShape()),
211+
m_size(m_shape.size()) {}
202212

203213
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
204214
typename StorageTypeB, typename Alpha, typename Beta>
@@ -268,8 +278,8 @@ namespace librapid {
268278

269279
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
270280
typename StorageTypeB, typename Alpha, typename Beta>
271-
auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::shape()
272-
const -> ShapeType {
281+
auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha,
282+
Beta>::calculateShape() const -> ShapeType {
273283
const auto &shapeA = m_a.shape();
274284
const auto &shapeB = m_b.shape();
275285
MatmulClass matmulClass = this->matmulClass();
@@ -294,6 +304,21 @@ namespace librapid {
294304
return {1};
295305
}
296306

307+
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
308+
typename StorageTypeB, typename Alpha, typename Beta>
309+
auto ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::shape()
310+
const -> ShapeType {
311+
return m_shape;
312+
}
313+
314+
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
315+
typename StorageTypeB, typename Alpha, typename Beta>
316+
auto
317+
ArrayMultiply<ShapeTypeA, StorageTypeA, ShapeTypeB, StorageTypeB, Alpha, Beta>::size() const
318+
-> size_t {
319+
return m_size;
320+
}
321+
297322
template<typename ShapeTypeA, typename StorageTypeA, typename ShapeTypeB,
298323
typename StorageTypeB, typename Alpha, typename Beta>
299324
auto

0 commit comments

Comments
 (0)