Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve batched serial trsm implementation and testing #2432

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from

Conversation

yasahi-hpc
Copy link
Contributor

@yasahi-hpc yasahi-hpc commented Nov 18, 2024

This PR aims at improving the implementation and testing of serial Trsm.

  • Moving implementation details into Impl namespace.
  • Add all the specializations of the combinations of Left/Right, Upper/Lower, Non-Trans/Trans/ConjTrans, and Unit/Non-Unit
  • Disallow to use this function if X is a rank 1 View. Use Trsv for this case
  • Covering all the unit-tests for all the combinations of Left/Right, Upper/Lower, Non-Trans/Trans/ConjTrans, and Unit/Non-Unit

As a TO DO task, we need to add a ConjTrans implementation of blocked version which requires a little more investigation

Edited 25/Nov

  • Allow again to use this function if X is a rank 1 View

@cwpearson cwpearson added the AT2-CI-APPROVAL Approve CI to run at SNL label Nov 25, 2024
@cwpearson
Copy link
Contributor

cwpearson commented Nov 25, 2024

@yasahi-hpc I'm using your PR to test a change to our CI infrastructure (related to this AT2-CI-APPROVAL label I added). Sorry if there is some noise.

@yasahi-hpc
Copy link
Contributor Author

@yasahi-hpc I'm using your PR to test a change to our CI infrastructure (related to this AT2-CI-APPROVAL label I added). Sorry if there is some noise.

Hi @cwpearson
Please feel free to try
It would be rather appreciated if I could get some GPU runs

@cwpearson
Copy link
Contributor

@yasahi-hpc looks like a legit failure in github-AT2 / spr / PR_SPR_ONEAPI202310_OPENMP_LEFT_MKLBLAS_MKLLAPACK_REL.

@yasahi-hpc
Copy link
Contributor Author

@yasahi-hpc looks like a legit failure in github-AT2 / spr / PR_SPR_ONEAPI202310_OPENMP_LEFT_MKLBLAS_MKLLAPACK_REL.

Sorry for the mistake. It should be fine now. For some reason, I need to relax the tolerance for Intel CPU build.

Copy link
Contributor

@lucbv lucbv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some clean-ups required to simplify the code and future maintenance.

#pragma unroll
#endif
for (int j = 0; j < jend; ++j)
B2[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
B2[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1];
B2[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1] : a21[i * as0] * b1t[j * bs1]);

With this you can remove a lot of the code duplication introduce for the conjugate type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After checking Kokkos::ArithTraits, since we define it to be a no-op for non-complex floating points, you can just call on both complex and non-complex numbers, it will do the right thing.

@@ -83,8 +112,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::i
template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1,
const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
const bool use_unit_diag, const bool /* do_conj */, const int m, const int n, const ScalarType alpha,

It looks more like it is never used... in that case just omit the variable name in the signature.

if (!use_unit_diag) {
const ValueType alpha11 = A[p * as0 + p * as1];
if (!use_unit_diag) {
const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);
const ValueType alpha11 = (do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]);

Simplify the code.

const ValueType alpha11 = A[p * as0 + p * as1];
if (do_conj) {
if (!use_unit_diag) {
const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above


#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= a01[i * as0] * b1t[j * bs1];
for (int j = 0; j < jend; ++j)
B0[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a01[i * as0]) * b1t[j * bs1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

@@ -189,8 +240,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i
template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
const int as0, const int as1,
const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

@cwpearson cwpearson added AT2-CI-APPROVAL Approve CI to run at SNL and removed AT2-CI-APPROVAL Approve CI to run at SNL labels Dec 3, 2024
@yasahi-hpc yasahi-hpc requested a review from lucbv December 3, 2024 19:09
@cwpearson cwpearson added AT2-CI-APPROVAL Approve CI to run at SNL and removed AT2-CI-APPROVAL Approve CI to run at SNL labels Dec 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
AT2-CI-APPROVAL Approve CI to run at SNL
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants