Skip to content

Commit 03fdf8e

Browse files
committed
Fix matrix transpose (use unaligned loads)
1 parent ab4acb7 commit 03fdf8e

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

librapid/include/librapid/array/linalg/transpose.hpp

+21-20
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,15 @@ namespace librapid {
7373

7474
__m256 alphaVec = _mm256_set1_ps(alpha);
7575

76-
_mm256_store_ps(&out[0 * cols], _mm256_mul_ps(r0, alphaVec));
77-
_mm256_store_ps(&out[1 * cols], _mm256_mul_ps(r1, alphaVec));
78-
_mm256_store_ps(&out[2 * cols], _mm256_mul_ps(r2, alphaVec));
79-
_mm256_store_ps(&out[3 * cols], _mm256_mul_ps(r3, alphaVec));
80-
_mm256_store_ps(&out[4 * cols], _mm256_mul_ps(r4, alphaVec));
81-
_mm256_store_ps(&out[5 * cols], _mm256_mul_ps(r5, alphaVec));
82-
_mm256_store_ps(&out[6 * cols], _mm256_mul_ps(r6, alphaVec));
83-
_mm256_store_ps(&out[7 * cols], _mm256_mul_ps(r7, alphaVec));
76+
// Must store unaligned, since the indices are not guaranteed to be aligned
77+
_mm256_storeu_ps(&out[0 * cols], _mm256_mul_ps(r0, alphaVec));
78+
_mm256_storeu_ps(&out[1 * cols], _mm256_mul_ps(r1, alphaVec));
79+
_mm256_storeu_ps(&out[2 * cols], _mm256_mul_ps(r2, alphaVec));
80+
_mm256_storeu_ps(&out[3 * cols], _mm256_mul_ps(r3, alphaVec));
81+
_mm256_storeu_ps(&out[4 * cols], _mm256_mul_ps(r4, alphaVec));
82+
_mm256_storeu_ps(&out[5 * cols], _mm256_mul_ps(r5, alphaVec));
83+
_mm256_storeu_ps(&out[6 * cols], _mm256_mul_ps(r6, alphaVec));
84+
_mm256_storeu_ps(&out[7 * cols], _mm256_mul_ps(r7, alphaVec));
8485
}
8586

8687
template<typename Alpha>
@@ -134,17 +135,17 @@ namespace librapid {
134135
int64_t cols) {
135136
__m128 tmp3, tmp2, tmp1, tmp0;
136137

137-
tmp0 = _mm_shuffle_ps(_mm_load_ps(in + 0 * cols), _mm_load_ps(in + 1 * cols), 0x44);
138-
tmp2 = _mm_shuffle_ps(_mm_load_ps(in + 0 * cols), _mm_load_ps(in + 1 * cols), 0xEE);
139-
tmp1 = _mm_shuffle_ps(_mm_load_ps(in + 2 * cols), _mm_load_ps(in + 3 * cols), 0x44);
140-
tmp3 = _mm_shuffle_ps(_mm_load_ps(in + 2 * cols), _mm_load_ps(in + 3 * cols), 0xEE);
138+
tmp0 = _mm_shuffle_ps(_mm_loadu_ps(in + 0 * cols), _mm_loadu_ps(in + 1 * cols), 0x44);
139+
tmp2 = _mm_shuffle_ps(_mm_loadu_ps(in + 0 * cols), _mm_loadu_ps(in + 1 * cols), 0xEE);
140+
tmp1 = _mm_shuffle_ps(_mm_loadu_ps(in + 2 * cols), _mm_loadu_ps(in + 3 * cols), 0x44);
141+
tmp3 = _mm_shuffle_ps(_mm_loadu_ps(in + 2 * cols), _mm_loadu_ps(in + 3 * cols), 0xEE);
141142

142143
__m128 alphaVec = _mm_set1_ps(alpha);
143144

144-
_mm_store_ps(out + 0 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0x88), alphaVec));
145-
_mm_store_ps(out + 1 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0xDD), alphaVec));
146-
_mm_store_ps(out + 2 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0x88), alphaVec));
147-
_mm_store_ps(out + 3 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0xDD), alphaVec));
145+
_mm_storeu_ps(out + 0 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0x88), alphaVec));
146+
_mm_storeu_ps(out + 1 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp0, tmp1, 0xDD), alphaVec));
147+
_mm_storeu_ps(out + 2 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0x88), alphaVec));
148+
_mm_storeu_ps(out + 3 * cols, _mm_mul_ps(_mm_shuffle_ps(tmp2, tmp3, 0xDD), alphaVec));
148149
}
149150

150151
template<typename Alpha>
@@ -154,17 +155,17 @@ namespace librapid {
154155
__m128d tmp0, tmp1;
155156

156157
// Load the values from input matrix
157-
tmp0 = _mm_load_pd(in + 0 * cols);
158-
tmp1 = _mm_load_pd(in + 1 * cols);
158+
tmp0 = _mm_loadu_pd(in + 0 * cols);
159+
tmp1 = _mm_loadu_pd(in + 1 * cols);
159160

160161
// Transpose the 2x2 matrix
161162
__m128d tmp0Unpck = _mm_unpacklo_pd(tmp0, tmp1);
162163
__m128d tmp1Unpck = _mm_unpackhi_pd(tmp0, tmp1);
163164

164165
// Store the transposed values in the output matrix
165166
__m128d alphaVec = _mm_set1_pd(alpha);
166-
_mm_store_pd(out + 0 * cols, _mm_mul_pd(tmp0Unpck, alphaVec));
167-
_mm_store_pd(out + 1 * cols, _mm_mul_pd(tmp1Unpck, alphaVec));
167+
_mm_storeu_pd(out + 0 * cols, _mm_mul_pd(tmp0Unpck, alphaVec));
168+
_mm_storeu_pd(out + 1 * cols, _mm_mul_pd(tmp1Unpck, alphaVec));
168169
}
169170

170171
# endif // LIBRAPID_MSVC

0 commit comments

Comments
 (0)