Adding wrappers for __riscv_vget* and __riscv_vset* for non-tuple types #2345

lsrcz opened this issue Oct 7, 2024 · 9 comments

@lsrcz
Contributor

lsrcz commented Oct 7, 2024

Since v0.11, the RVV intrinsics provide __riscv_vset_v_*_* and __riscv_vget_v_*_* not only for tuple types but also for vector groups (LMUL > 1), for example:

```cpp
vint16m4_t __riscv_vset_v_i16m1_i16m4(vint16m4_t dest, size_t index, vint16m1_t value);
// __riscv_vset_v_i16m1_i16m4(dest, 2, value) copies the register `value` to the third register in the dest vector group
```

They are usually translated to whole-register move instructions (e.g., vmv1r.v). These moves are efficient on most microarchitectures and could potentially be eliminated entirely by the register allocator as compilers become more advanced.
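The following scalar model in portable C++ makes the semantics concrete without RISC-V hardware. It is an illustrative assumption, not how the intrinsics are implemented. A vint16m4_t group is treated as four registers of 8 int16 lanes each (assuming VLEN = 128 bits). The hypothetical ModelVset and ModelVget functions replace or return one whole register of the group.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Scalar model only (assumption: VLEN = 128 bits, so one register holds
// 8 int16 lanes). This is not how the intrinsics are implemented.
constexpr std::size_t kLanesPerReg = 8;

using Reg = std::array<std::int16_t, kLanesPerReg>;  // stands in for vint16m1_t
using Group4 = std::array<Reg, 4>;                   // stands in for vint16m4_t

// Models __riscv_vset_v_i16m1_i16m4(dest, index, value): register `index`
// of the group is replaced by `value`; the other registers are unchanged.
Group4 ModelVset(Group4 dest, std::size_t index, const Reg& value) {
  dest[index] = value;
  return dest;
}

// Models __riscv_vget_v_i16m4_i16m1(src, index): returns register `index`.
Reg ModelVget(const Group4& src, std::size_t index) { return src[index]; }
```

Because whole registers are copied, no lane ever crosses a register boundary. This is why a single whole-register move suffices in hardware.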

These operations are useful for implementing concat operators such as ConcatUpperLower when LMUL is not fractional. For example, the current ConcatUpperLower is implemented as follows:

```cpp
template <class D, class V>
HWY_API V ConcatUpperLower(D d, const V hi, const V lo) {
  const size_t half = Lanes(d) / 2;
  const V hi_down = detail::SlideDown(hi, half);
  return detail::SlideUp(lo, hi_down, half);
}
```

For V = vuint8m2_t, each of the two slide operations takes 4 cycles on the X280. With vget and vset, we can instead write:

```cpp
vuint8m2_t ConcatUpperLower(const vuint8m2_t hi, const vuint8m2_t lo) {
  auto v0 = __riscv_vget_v_u8m2_u8m1(lo, 0);
  return __riscv_vset_v_u8m1_u8m2(hi, 0, v0);
}
```

Clang (trunk) compiles this to the following, which takes 2 cycles:

```asm
ConcatUpperLower:
  vmv1r.v v8, v10
  ret
```
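A scalar model can check that the two formulations agree for full vectors. It uses hypothetical helper names and assumes VLEN = 128 bits, so a full vuint8m2_t has 32 lanes. The check compares the slide version against copying the lower register of lo over the lower register of hi. In this model, lanes with no slide source are zeroed; that detail does not affect the result here.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Scalar model (assumption: VLEN = 128 bits): a register holds 16 u8 lanes,
// so a full vuint8m2_t has 32 lanes. Helper names are hypothetical.
constexpr std::size_t kRegLanes = 16;
constexpr std::size_t kLanes = 2 * kRegLanes;
using VecM2 = std::array<std::uint8_t, kLanes>;

// Lane i of the result is v[i + amount]; lanes with no source are zero here.
VecM2 ModelSlideDown(const VecM2& v, std::size_t amount) {
  VecM2 r{};
  for (std::size_t i = 0; i + amount < kLanes; ++i) r[i] = v[i + amount];
  return r;
}

// Lanes below `amount` keep dst; lane i >= amount receives src[i - amount].
VecM2 ModelSlideUp(const VecM2& dst, const VecM2& src, std::size_t amount) {
  VecM2 r = dst;
  for (std::size_t i = amount; i < kLanes; ++i) r[i] = src[i - amount];
  return r;
}

// Current approach: two slides, as in the Highway code above.
VecM2 ConcatUpperLowerSlides(const VecM2& hi, const VecM2& lo) {
  const std::size_t half = kLanes / 2;
  return ModelSlideUp(lo, ModelSlideDown(hi, half), half);
}

// Proposed approach: vget(lo, 0) then vset(hi, 0, ...), i.e. copy the lower
// register of lo over the lower register of hi (one vmv1r.v).
VecM2 ConcatUpperLowerGetSet(VecM2 hi, const VecM2& lo) {
  for (std::size_t i = 0; i < kRegLanes; ++i) hi[i] = lo[i];
  return hi;
}
```

The equivalence relies on half == one register's worth of lanes. That holds only for full vectors with LMUL >= 2.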

However, I'm not sure how to deal with all the macros needed to add these operations to Highway. Any ideas or instructions?

@jan-wassenberg
Member

Ooh, thanks for pointing that out! I agree this would be a great improvement.
If I understand correctly, we would always be using the m4_m2 or m2_m4 variants, so halving/doubling. This is very similar to HWY_RVV_TRUNC and HWY_RVV_EXT. We can basically copy those macros, with minor updates.

Think of HWY_RVV_TRUNC as a 'callback', called automatically with all the (BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP) combinations. Its job is to expand to a wrapper function whose name is the expansion of NAME and which calls the intrinsic whose base name is OP. Note that SEWD/SEWH are double/half of SEW, and similarly LMULD/LMULH for LMUL.

HWY_RVV_FOREACH takes care of calling it and specifying NAME and OP. The last argument says which combination of LMULs we want. For ops that halve LMUL, we cannot use mf8 because it is already the smallest, hence _TRUNC; there is also _EXT for ops that double LMUL.

_VIRT works around the fact that the intrinsics don't provide mf4 for u32 elements, because they aren't made for the A profile that guarantees 128-bit vectors, which we depend on. _VIRT is implemented around line 220 and adds one extra callback for half of the normal size, but using the smallest available intrinsic LMUL. Because this smallest-LMUL type is the same as that of another overload, we require a HWY_RVV_D argument to make the overloads distinct. This works because the LMUL scale (-3 for mf8) is part of the Simd<> template, which is part of the function signature.
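A stripped-down sketch of this 'callback' (X-macro) pattern may help. The hypothetical TOY_FOREACH and TOY_WRAP macros stand in for HWY_RVV_FOREACH and HWY_RVV_TRUNC, with only BASE/SEW/NAME/OP as parameters. The real macros also pass the LMUL variants, SHIFT and MLEN, and select LMULs via their last argument.

```cpp
#include <cassert>
#include <stdint.h>

// Toy version of the pattern (not Highway's macros): TOY_FOREACH invokes
// CALLBACK once per (BASE, SEW) combination and forwards NAME and OP.
#define TOY_FOREACH(CALLBACK, NAME, OP) \
  CALLBACK(int, 16, NAME, OP)           \
  CALLBACK(int, 32, NAME, OP)           \
  CALLBACK(uint, 8, NAME, OP)

// The callback expands to one wrapper overload named NAME that calls OP.
// BASE##SEW##_t pastes e.g. `int`, `16`, `_t` into `int16_t`.
#define TOY_WRAP(BASE, SEW, NAME, OP) \
  inline BASE##SEW##_t NAME(BASE##SEW##_t v) { return OP(v); }

// Stands in for the underlying intrinsic.
template <typename T>
T TimesTwo(T v) {
  return static_cast<T>(v * 2);
}

// Generates Doubled(int16_t), Doubled(int32_t) and Doubled(uint8_t).
TOY_FOREACH(TOY_WRAP, Doubled, TimesTwo)
#undef TOY_WRAP
```

Each generated function is an ordinary overload, so callers just write Doubled(v) and overload resolution picks the right one.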

So to implement Get/Set we would mostly copy lmul_ext/trunc, but with an extra size_t template argument that indicates the index. This would be added to the C++ function that HWY_RVV_TRUNC/EXT, or rather a new HWY_RVV_GET_VEC, expands to.

Does that make sense?
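Here is a reduced sketch of what a HWY_RVV_GET_VEC-style callback with an extra size_t template argument for the index could expand to. Everything here is a hypothetical stand-in so the pattern compiles on any host: toy_u8m2_t, toy_vget_u8m2_u8m1 and TOY_RVV_GET_VEC. Real code would call __riscv_vget_v_u8m2_u8m1 on vuint8m2_t.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins so the pattern compiles on any host (assumption:
// 16 u8 lanes per register). Real code would use vuint8m2_t/vuint8m1_t.
struct toy_u8m1_t { std::array<std::uint8_t, 16> lanes; };
struct toy_u8m2_t { std::array<std::uint8_t, 32> lanes; };

// Stand-in for __riscv_vget_v_u8m2_u8m1: returns register `index` of the group.
inline toy_u8m1_t toy_vget_u8m2_u8m1(const toy_u8m2_t& v, std::size_t index) {
  toy_u8m1_t r{};
  for (std::size_t i = 0; i < 16; ++i) r.lanes[i] = v.lanes[index * 16 + i];
  return r;
}

// Reduced callback in the spirit of the suggested HWY_RVV_GET_VEC: the index
// is a template argument, and OP is pasted into the intrinsic name.
#define TOY_RVV_GET_VEC(CHAR, SEW, LMUL, LMULH, NAME, OP)                      \
  template <std::size_t kIndex>                                                \
  inline toy_##CHAR##SEW##LMULH##_t NAME(const toy_##CHAR##SEW##LMUL##_t& v) { \
    static_assert(kIndex < 2, "halving LMUL leaves two halves");              \
    return toy_v##OP##_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH(v, kIndex);      \
  }

TOY_RVV_GET_VEC(u, 8, m2, m1, Get, get)
#undef TOY_RVV_GET_VEC
```

Usage: Get<0>(v) and Get<1>(v) return the lower and upper register. An out-of-range index is rejected at compile time by the static_assert.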

@lsrcz
Contributor Author

lsrcz commented Oct 16, 2024

Thanks! Let me have a look at it and see whether I can add the GET_VEC based on TRUNC and EXT.

@lsrcz
Contributor Author

lsrcz commented Oct 16, 2024

Regarding _VIRT: RISC-V only provides vget for non-fractional LMULs. For fractional LMUL, the best way to extract the upper half is vslidedown, as the concat operators already do.

What I have so far is something like the following, though I would need to replace the SlideDown operator with intrinsics. Here, new _GET and _GET_VIRT categories are defined for the non-fractional and fractional types, respectively.

```cpp
// Halves LMUL. (Use LMUL arg for the source so we can use _TRUNC.)
#define HWY_RVV_GET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
                    MLEN, NAME, OP)                                         \
  template <size_t kIndex>                                                  \
  HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) {  \
    return __riscv_v##OP##_v_##CHAR##SEW##LMUL##_##CHAR##SEW##LMULH(        \
        v, kIndex); /* no AVL */                                            \
  }
HWY_RVV_FOREACH(HWY_RVV_GET, Get, get, _GET)
#undef HWY_RVV_GET

#define HWY_RVV_GET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH,  \
                         SHIFT, MLEN, NAME, OP)                            \
  template <size_t kIndex>                                                 \
  HWY_API HWY_RVV_V(BASE, SEW, LMULH) NAME(HWY_RVV_V(BASE, SEW, LMUL) v) { \
    if constexpr (kIndex == 0) {                                           \
      return Trunc(v);                                                     \
    } else {                                                               \
      static_assert(kIndex == 1);                                          \
      return Trunc(                                                        \
          SlideDown(v, Lanes(DFromV<HWY_RVV_V(BASE, SEW, LMUL)>{}) / 2));  \
    }                                                                      \
  }
HWY_RVV_FOREACH(HWY_RVV_GET_VIRT, Get, _, _GET_VIRT)
#undef HWY_RVV_GET_VIRT
```

What do you think of this approach? Thanks!
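The fractional branch of this proposal can be checked with a scalar model: index 0 is Trunc, and index 1 is SlideDown by half the lanes followed by Trunc. The types and helpers below are hypothetical and assume a u16 mf2 vector with 4 lanes (VLEN = 128).

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Scalar model (assumption: VLEN = 128 bits, so u16 mf2 has 4 lanes and
// mf4 has 2). Both halves live in the same physical register, so there is
// no whole register to copy; the upper half must be slid down instead.
using MF2 = std::array<std::uint16_t, 4>;
using MF4 = std::array<std::uint16_t, 2>;

// Trunc keeps the lower lanes; no data moves.
MF4 ModelTrunc(const MF2& v) { return MF4{v[0], v[1]}; }

// Lane i receives v[i + amount]; lanes with no source are zero in this model.
MF2 ModelSlideDown(const MF2& v, std::size_t amount) {
  MF2 r{};
  for (std::size_t i = 0; i + amount < v.size(); ++i) r[i] = v[i + amount];
  return r;
}

// Mirrors the two branches of the proposed HWY_RVV_GET_VIRT (C++17).
template <std::size_t kIndex>
MF4 ModelGetFractional(const MF2& v) {
  static_assert(kIndex < 2, "only two halves");
  if constexpr (kIndex == 0) {
    return ModelTrunc(v);
  } else {
    return ModelTrunc(ModelSlideDown(v, v.size() / 2));
  }
}
```

Here the if constexpr is fine because kIndex is a template parameter of ModelGetFractional itself.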

@jan-wassenberg
Member

A suggestion: rather than go through RVV_V and then DFromV, we can construct the D directly via HWY_RVV_D(BASE, SEW, N, SHIFT). And given you want half the lanes, I think we can just say SHIFT-1. Is it then possible to use SlideDown, or why would it be necessary to convert that to intrinsics?

Also, I'm surprised you'd have to create a new _GET category. Because this is functionally the same as Trunc, just with an index, it should be possible to use the existing _TRUNC category, no?
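The idea of building the half-size descriptor with SHIFT-1 can be illustrated with a toy descriptor. It assumes a fixed VLEN of 128 bits; real RVV code determines VLEN at runtime. ToySimd and ToyHalf are hypothetical. They only mirror the idea that Simd<T, N, kPow2> encodes log2(LMUL) plus a lane cap N. Lowering kPow2 by one halves the lane count whenever N is not the limiting cap.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Toy descriptor (hypothetical; mirrors the idea of Simd<T, N, kPow2>).
// Assumes a fixed VLEN of 128 bits; real RVV code queries it at runtime.
constexpr std::size_t kToyVLEN = 128;

template <typename T, std::size_t N, int kPow2>
struct ToySimd {
  // kPow2 is log2(LMUL): 1 for m2, 0 for m1, -1 for mf2.
  static constexpr std::size_t Lanes() {
    constexpr std::size_t kPerReg = kToyVLEN / (8 * sizeof(T));
    std::size_t full = 0;
    if constexpr (kPow2 >= 0) {
      full = kPerReg << kPow2;
    } else {
      full = kPerReg >> (-kPow2);
    }
    return full < N ? full : N;  // N caps the lane count, similar to avl
  }
};

// Half-size descriptor: same T and N, LMUL exponent reduced by one
// (the SHIFT-1 idea).
template <class D>
struct ToyHalf;
template <typename T, std::size_t N, int kPow2>
struct ToyHalf<ToySimd<T, N, kPow2>> {
  using type = ToySimd<T, N, kPow2 - 1>;
};
```

For example, ToySimd<uint16_t, 1024, 1> (u16 m2) has 16 lanes, and its half has 8.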

@lsrcz
Contributor Author

lsrcz commented Oct 23, 2024

> A suggestion: rather than go through RVV_V and then DFromV, we can construct the D directly via HWY_RVV_D(BASE, SEW, N, SHIFT). And given you want half the lanes, I think we can just say SHIFT-1

That makes sense, thanks for the suggestion.

> Is it then possible to use SlideDown, or why would it be necessary to convert that to intrinsics?

Yes, using SlideDown directly would work well here, and there’s no need to convert to intrinsics for this case.

> Also, I'm surprised you'd have to create a new _GET category. Because this is functionally the same as Trunc, just with an index, it should be possible to use the existing _TRUNC category, no?

The reason for creating a new _GET category was that vget is only provided for non-fractional types, while vlmul_trunc supports both non-fractional and fractional types. But you are right: we can still dispatch at compile time with if constexpr within the _TRUNC category.

I will work on a pull request with these suggestions. Thanks!

@lsrcz
Contributor Author

lsrcz commented Oct 23, 2024

Hi, I still have questions regarding the HWY_RVV_D operator. Here, for the N parameter, should I use HWY_LANES(HWY_RVV_T(BASE, SEW))?

Also, I realized that if constexpr won't work here. Outside a template, the discarded branch is still fully checked, so undeclared identifiers in it are still errors. I therefore think the separate _GET and _GET_VIRT macros are still needed.

@jan-wassenberg
Member

> Hi, I still have questions regarding the HWY_RVV_D operator. Here, for the N parameter

This is typically taken from a template argument. The idea is to allow users to specify a (power of two) cap on the max number of lanes, similar to avl. For concreteness, we can add a size_t N argument after kIndex.

Hm, bummer about if constexpr. It can actually survive some compile errors; the trick is that it works like a template. The disabled branch is still parsed, but it is not instantiated (as in a template). If we can make the code depend on one of the template parameters, it could even work.
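Here is a minimal C++17 illustration of this point about if constexpr, using hypothetical WholeVec/FractionalVec types. Inside a template, a discarded branch whose expression depends on the template parameter is not instantiated. It may therefore call a member that the instantiated type lacks.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-ins: only WholeVec has a "native vget".
struct WholeVec {
  int NativeGet() const { return 20; }
};
struct FractionalVec {
  int SlideDownHalf() const { return 10; }
};

template <class V>
int UpperHalfOf(const V& v) {
  if constexpr (std::is_same_v<V, WholeVec>) {
    // Discarded for FractionalVec: the call depends on V, so it is never
    // instantiated there, even though FractionalVec has no NativeGet().
    return v.NativeGet();
  } else {
    return v.SlideDownHalf();
  }
}

// By contrast, if the expression in a discarded branch does not depend on a
// template parameter (e.g. a call to a non-existent intrinsic name pasted in
// by a macro), it is checked when the code is defined and still fails.
```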

@lsrcz
Contributor Author

lsrcz commented Oct 25, 2024

For these Get and Set operators, the intrinsics do not take an avl argument and do not work on partial vectors, so I used the lane count of full vectors here. If we want to support partial vectors, we probably have to use the slide operators to access the upper parts, even for non-fractional LMULs.

It might make sense to have Get and Set always operate on full, non-fractional vectors via the vget and vset intrinsics, and not provide virtual overloads based on slide and mv in the macros.

Then we can

  1. implement LowerHalf with Trunc, and
  2. implement UpperHalf using Get when working with full vectors with non-fractional LMUL, and
  3. implement UpperHalf with SlideDown for other cases, and
  4. provide SetLowerHalf and SetUpperHalf in the detail namespace, which dispatch to slide/vmv/set based on whether we are working with full vectors and whether we have non-fractional LMUL.
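The selection logic for LowerHalf and UpperHalf in this plan could look roughly like the sketch below. The enum and function names are hypothetical. kPow2 is log2(LMUL) of the source vector, and kFull means the vector uses all VLMAX lanes (no smaller cap N). The sketch assumes the source must have LMUL >= 2 for vget, because vget cannot return a fractional part of a single register.

```cpp
#include <cassert>

// Hypothetical model of which primitive the proposed helpers would pick.
enum class Impl { kTrunc, kVget, kSlideDown };

// LowerHalf: always Trunc (vlmul_trunc); the lower half already starts at lane 0.
constexpr Impl LowerHalfImpl() { return Impl::kTrunc; }

// UpperHalf: vget returns whole registers, so it applies only when the source
// spans at least two registers (kPow2 >= 1) and is a full vector, which puts
// the midpoint exactly on a register boundary. Otherwise slide down.
template <int kPow2, bool kFull>
constexpr Impl UpperHalfImpl() {
  return (kPow2 >= 1 && kFull) ? Impl::kVget : Impl::kSlideDown;
}
```

Because both functions are constexpr, a real implementation could branch on them with if constexpr inside a template that takes the descriptor D.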

Client code, for example the implementation of the Concat* operators, would then use these wrappers exclusively rather than the primitive Get and Set operators. This should make the code a bit cleaner, since we would not need to worry about the number of lanes or if constexpr in the macros. All other operators implemented with LowerHalf and UpperHalf would also benefit from this optimization.

About if constexpr, thanks for the explanation. As I understand it, the problem here is that the HWY_RVV_FOREACH macros explicitly generate every overload instead of using a template parameter for the vector type. The if constexpr therefore cannot depend on a template parameter. In any case, if we adopt the proposal above, we won't need to worry about dispatching in the macros.

@jan-wassenberg
Member

I agree that the if constexpr only works if we have a template argument. That could be arranged by adding a D argument, but it sounds like this is anyway not necessary.

Your plan (full vectors only for Get/Set) sounds good to me 👍
