Skip to content

Commit

Permalink
Teach the SIMD metadata group match to defer masking (#4595)
Browse files Browse the repository at this point in the history
When using a byte-encoding for matched group metadata we need to mask
down to a single bit in each matching byte to make the iteration of a
range of match indices work. In most cases, this mask can be folded into
the overall match computation, but for Arm Neon, there is avoidable
overhead from this. Instead, we can defer the mask until starting to
iterate. Doing more than one iteration is relative rare so this doesn't
accumulate much waste and makes common paths a bit faster.

For the M1 this makes the SIMD match path about 2-4% faster. This isn't
enough to catch the portable match code path on the M1 though.

For some Neoverse cores the difference here is more significant (>10%
improvement) and it makes the SIMD and scalar code paths have comparable
latency. Still not clear which is better as the latency is comparable
and beyond latency the factors are very hard to analyze -- port pressure
on different parts of the CPU, etc.

Leaving the selected code path as portable since that's so much better
on the M1, and I'm hoping to avoid different code paths for different
Arm CPUs for a while.

---------

Co-authored-by: Danila Kutenin <[email protected]>
  • Loading branch information
chandlerc and danilak-G authored Dec 20, 2024
1 parent e85125c commit 3ae968a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 58 deletions.
169 changes: 115 additions & 54 deletions common/raw_hashtable_metadata_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstddef>
#include <cstring>
#include <iterator>
#include <type_traits>

#include "common/check.h"
#include "common/ostream.h"
Expand Down Expand Up @@ -55,11 +56,12 @@ constexpr ssize_t MaxGroupSize = 16;
//
// Some bits of the underlying value may be known-zero, which can optimize
// various operations. These can be represented as a `ZeroMask`.
template <typename BitsInputT, bool ByteEncoding, BitsInputT ZeroMask = 0>
template <typename BitsInputT, bool ByteEncodingInput, BitsInputT ZeroMask = 0>
class BitIndex
: public Printable<BitIndex<BitsInputT, ByteEncoding, ZeroMask>> {
: public Printable<BitIndex<BitsInputT, ByteEncodingInput, ZeroMask>> {
public:
using BitsT = BitsInputT;
static constexpr bool ByteEncoding = ByteEncodingInput;

BitIndex() = default;
explicit BitIndex(BitsT bits) : bits_(bits) {}
Expand Down Expand Up @@ -175,10 +177,13 @@ class BitIndex
// with byte-encoded bit indices, exactly the high bit and no other bit of each
// matching byte must be set. This is a stricter constraint than what `BitIndex`
// alone would impose on any one of the matches.
template <typename BitIndexT>
class BitIndexRange : public Printable<BitIndexRange<BitIndexT>> {
template <typename BitIndexT, BitIndexT::BitsT ByteEncodingMask = 0>
class BitIndexRange
: public Printable<BitIndexRange<BitIndexT, ByteEncodingMask>> {
public:
using BitsT = BitIndexT::BitsT;
static_assert(BitIndexT::ByteEncoding || ByteEncodingMask == 0,
"Non-byte encoding must not have a byte encoding mask.");

class Iterator
: public llvm::iterator_facade_base<Iterator, std::forward_iterator_tag,
Expand Down Expand Up @@ -208,6 +213,13 @@ class BitIndexRange : public Printable<BitIndexRange<BitIndexT>> {
auto operator++() -> Iterator& {
CARBON_DCHECK(bits_ != 0, "Must not increment past the end!");
__builtin_assume(bits_ != 0);

if constexpr (ByteEncodingMask != 0) {
// Apply an increment mask to the bits first. This is used with the byte
// encoding when the mask isn't needed until we begin incrementing.
bits_ &= ByteEncodingMask;
}

// Clears the least significant set bit, effectively stepping to the next
// match.
bits_ &= (bits_ - 1);
Expand All @@ -229,7 +241,28 @@ class BitIndexRange : public Printable<BitIndexRange<BitIndexT>> {
auto end() const -> Iterator { return Iterator(); }

friend auto operator==(BitIndexRange lhs, BitIndexRange rhs) -> bool {
return lhs.bits_ == rhs.bits_;
if constexpr (ByteEncodingMask == 0) {
// If there is no encoding mask, we can just compare the bits directly.
return lhs.bits_ == rhs.bits_;
} else {
// Otherwise, compare the initial bit indices and the masked bits.
return BitIndexT(lhs.bits_) == BitIndexT(rhs.bits_) &&
(lhs.bits_ & ByteEncodingMask) == (rhs.bits_ & ByteEncodingMask);
}
}

// Define heterogeneous equality between a masked (the current type) and
// unmasked range. Requires a non-zero mask to avoid a redundant definition
// with the homogeneous equality.
friend auto operator==(BitIndexRange lhs, BitIndexRange<BitIndexT, 0> rhs)
-> bool
requires(ByteEncodingMask != 0)
{
// For mixed masked / unmasked comparison, we make sure the initial indices
// are the same and that the masked side (LHS) is the same after masking as
// the unmasked side (RHS).
return BitIndexT(lhs.bits_) == BitIndexT(rhs.bits_) &&
(lhs.bits_ & ByteEncodingMask) == rhs.bits_;
}

auto Print(llvm::raw_ostream& out) const -> void {
Expand All @@ -240,6 +273,10 @@ class BitIndexRange : public Printable<BitIndexRange<BitIndexT>> {
explicit operator BitIndexT() const { return BitIndexT(bits_); }

private:
template <typename FriendBitIndexT,
FriendBitIndexT::BitsT FriendByteEncodingMask>
friend class BitIndexRange;

BitsT bits_ = 0;
};

Expand Down Expand Up @@ -304,6 +341,16 @@ class MetadataGroup : public Printable<MetadataGroup> {

static constexpr uint8_t PresentMask = 0b1000'0000;

// Whether to use a SIMD implementation. Even when we *support* a SIMD
// implementation, we do not always have to use it in the event that it is
// less efficient than the portable version.
static constexpr bool UseSIMD =
#if CARBON_X86_SIMD_SUPPORT
true;
#else
false;
#endif

// Some architectures make it much more efficient to build the match indices
// in a byte-encoded form rather than a bit-encoded form. This encoding
// changes verification and other aspects of our algorithms.
Expand All @@ -327,11 +374,31 @@ class MetadataGroup : public Printable<MetadataGroup> {
// the larger metadata array.
static constexpr bool FastByteClear = Size == 8;

// Most and least significant bits set.
static constexpr uint64_t MSBs = 0x8080'8080'8080'8080ULL;
static constexpr uint64_t LSBs = 0x0101'0101'0101'0101ULL;

using MatchIndex =
BitIndex<std::conditional_t<ByteEncoding, uint64_t, uint32_t>,
ByteEncoding,
/*ZeroMask=*/ByteEncoding ? 0 : (~0U << Size)>;
using MatchRange = BitIndexRange<MatchIndex>;

// Only one kind of portable matched range is needed.
using PortableMatchRange = BitIndexRange<MatchIndex>;

// We use specialized match range types for SIMD implementations to allow
// deferring the masking operation where useful. When that optimization
// doesn't apply, these will be the same type.
using SIMDMatchRange =
BitIndexRange<MatchIndex, /*ByteEncodingMask=*/ByteEncoding ? MSBs : 0>;
using SIMDMatchPresentRange = BitIndexRange<MatchIndex>;

// The public API range types can be either the portable or SIMD variations,
// selected here.
using MatchRange =
std::conditional_t<UseSIMD, SIMDMatchRange, PortableMatchRange>;
using MatchPresentRange =
std::conditional_t<UseSIMD, SIMDMatchPresentRange, PortableMatchRange>;

union {
uint8_t metadata_bytes[Size];
Expand Down Expand Up @@ -390,7 +457,7 @@ class MetadataGroup : public Printable<MetadataGroup> {

// Find all of the present bytes of metadata in this group. A range over all
// of the byte indices which are present is returned.
auto MatchPresent() const -> MatchRange;
auto MatchPresent() const -> MatchPresentRange;

// Find the first byte of the metadata group that is empty and return that
// index. There is no order or position required for which of the bytes of
Expand All @@ -412,16 +479,6 @@ class MetadataGroup : public Printable<MetadataGroup> {
friend class BenchmarkPortableMetadataGroup;
friend class BenchmarkSIMDMetadataGroup;

// Whether to use a SIMD implementation. Even when we *support* a SIMD
// implementation, we do not always have to use it in the event that it is
// less efficient than the portable version.
static constexpr bool UseSIMD =
#if CARBON_X86_SIMD_SUPPORT
true;
#else
false;
#endif

// All SIMD variants that we have an implementation for should be enabled for
// debugging. This lets us maintain a SIMD implementation even if it is not
// used due to performance reasons, and easily re-enable it if the performance
Expand All @@ -433,12 +490,19 @@ class MetadataGroup : public Printable<MetadataGroup> {
false;
#endif

// Most and least significant bits set.
static constexpr uint64_t MSBs = 0x8080'8080'8080'8080ULL;
static constexpr uint64_t LSBs = 0x0101'0101'0101'0101ULL;

using MatchBitsT = MatchIndex::BitsT;

// A helper function to allow deducing the return type from the selected arm
// of a `constexpr` ternary.
template <bool Condition, typename LeftT, typename RightT>
static auto ConstexprTernary(LeftT lhs, RightT rhs) {
if constexpr (Condition) {
return lhs;
} else {
return rhs;
}
}

static auto CompareEqual(MetadataGroup lhs, MetadataGroup rhs) -> bool;

// Functions for validating the returned matches agree with what is predicted
Expand All @@ -451,9 +515,9 @@ class MetadataGroup : public Printable<MetadataGroup> {
auto VerifyIndexBits(
MatchBitsT index_bits,
llvm::function_ref<auto(uint8_t byte)->bool> byte_match) const -> bool;
// `VerifyRangeBits` is for functions that return `MatchRange`, and so it
// validates all the bytes of `range_bits`.
auto VerifyRangeBits(
// `VerifyPortableRangeBits` is for functions that return `MatchRange`, and so
// it validates all the bytes of `range_bits`.
auto VerifyPortableRangeBits(
MatchBitsT range_bits,
llvm::function_ref<auto(uint8_t byte)->bool> byte_match) const -> bool;

Expand All @@ -472,8 +536,8 @@ class MetadataGroup : public Printable<MetadataGroup> {

auto PortableClearDeleted() -> void;

auto PortableMatch(uint8_t tag) const -> MatchRange;
auto PortableMatchPresent() const -> MatchRange;
auto PortableMatch(uint8_t tag) const -> PortableMatchRange;
auto PortableMatchPresent() const -> PortableMatchRange;

auto PortableMatchEmpty() const -> MatchIndex;
auto PortableMatchDeleted() const -> MatchIndex;
Expand All @@ -494,8 +558,8 @@ class MetadataGroup : public Printable<MetadataGroup> {

auto SIMDClearDeleted() -> void;

auto SIMDMatch(uint8_t tag) const -> MatchRange;
auto SIMDMatchPresent() const -> MatchRange;
auto SIMDMatch(uint8_t tag) const -> SIMDMatchRange;
auto SIMDMatchPresent() const -> SIMDMatchPresentRange;

auto SIMDMatchEmpty() const -> MatchIndex;
auto SIMDMatchDeleted() const -> MatchIndex;
Expand All @@ -505,7 +569,7 @@ class MetadataGroup : public Printable<MetadataGroup> {
#if CARBON_X86_SIMD_SUPPORT
// A common routine for x86 SIMD matching that can be used for matching
// present, empty, and deleted bytes with equal efficiency.
auto X86SIMDMatch(uint8_t match_byte) const -> MatchRange;
auto X86SIMDMatch(uint8_t match_byte) const -> SIMDMatchRange;
#endif
};

Expand Down Expand Up @@ -570,8 +634,8 @@ inline auto MetadataGroup::Match(uint8_t tag) const -> MatchRange {
// a present byte.
CARBON_DCHECK((tag & PresentMask) == 0, "{0:x}", tag);

MatchRange portable_result;
MatchRange simd_result;
PortableMatchRange portable_result;
SIMDMatchRange simd_result;
if constexpr (!UseSIMD || DebugSIMD) {
portable_result = PortableMatch(tag);
}
Expand All @@ -581,12 +645,13 @@ inline auto MetadataGroup::Match(uint8_t tag) const -> MatchRange {
"SIMD result '{0}' doesn't match portable result '{1}'",
simd_result, portable_result);
}
return UseSIMD ? simd_result : portable_result;
// Return whichever result we're using.
return ConstexprTernary<UseSIMD>(simd_result, portable_result);
}

inline auto MetadataGroup::MatchPresent() const -> MatchRange {
MatchRange portable_result;
MatchRange simd_result;
inline auto MetadataGroup::MatchPresent() const -> MatchPresentRange {
PortableMatchRange portable_result;
SIMDMatchPresentRange simd_result;
if constexpr (!UseSIMD || DebugSIMD) {
portable_result = PortableMatchPresent();
}
Expand All @@ -596,7 +661,8 @@ inline auto MetadataGroup::MatchPresent() const -> MatchRange {
"SIMD result '{0}' doesn't match portable result '{1}'",
simd_result, portable_result);
}
return UseSIMD ? simd_result : portable_result;
// Return whichever result we're using.
return ConstexprTernary<UseSIMD>(simd_result, portable_result);
}

inline auto MetadataGroup::MatchEmpty() const -> MatchIndex {
Expand Down Expand Up @@ -679,7 +745,7 @@ inline auto MetadataGroup::VerifyIndexBits(
return true;
}

inline auto MetadataGroup::VerifyRangeBits(
inline auto MetadataGroup::VerifyPortableRangeBits(
MatchBitsT range_bits,
llvm::function_ref<auto(uint8_t byte)->bool> byte_match) const -> bool {
for (ssize_t byte_index : llvm::seq<ssize_t>(0, Size)) {
Expand Down Expand Up @@ -814,7 +880,7 @@ inline auto MetadataGroup::PortableMatch(uint8_t tag) const -> MatchRange {

// At this point, `match_bits` has the high bit set for bytes where the
// original group byte equals `tag` plus the high bit.
CARBON_DCHECK(VerifyRangeBits(
CARBON_DCHECK(VerifyPortableRangeBits(
match_bits, [&](uint8_t byte) { return byte == (tag | PresentMask); }));
return MatchRange(match_bits);
}
Expand All @@ -841,7 +907,7 @@ inline auto MetadataGroup::PortableMatchPresent() const -> MatchRange {
// represents a present slot.
uint64_t match_bits = metadata_ints[0] & MSBs;

CARBON_DCHECK(VerifyRangeBits(
CARBON_DCHECK(VerifyPortableRangeBits(
match_bits, [&](uint8_t byte) { return (byte & PresentMask) != 0; }));
return MatchRange(match_bits);
}
Expand Down Expand Up @@ -966,20 +1032,17 @@ inline auto MetadataGroup::SIMDClearDeleted() -> void {
#endif
}

inline auto MetadataGroup::SIMDMatch(uint8_t tag) const -> MatchRange {
MatchRange result;
inline auto MetadataGroup::SIMDMatch(uint8_t tag) const -> SIMDMatchRange {
SIMDMatchRange result;
#if CARBON_NEON_SIMD_SUPPORT
// Broadcast byte we want to match to every byte in the vector.
auto match_byte_vec = vdup_n_u8(tag | PresentMask);
// Result bytes have all bits set for the bytes that match, so we have to
// clear everything but MSBs next.
auto match_byte_cmp_vec = vceq_u8(metadata_vec, match_byte_vec);
uint64_t match_bits = vreinterpret_u64_u8(match_byte_cmp_vec)[0];
// The matched range is likely to be tested for zero by the caller, and that
// test can often be folded into masking the bits with `MSBs` when we do that
// mask in the scalar domain rather than the SIMD domain. So we do the mask
// here rather than above prior to extracting the match bits.
result = MatchRange(match_bits & MSBs);
// Note that the range will lazily mask to the MSBs as part of incrementing.
result = SIMDMatchRange(match_bits);
#elif CARBON_X86_SIMD_SUPPORT
result = X86SIMDMatch(tag | PresentMask);
#else
Expand All @@ -989,20 +1052,18 @@ inline auto MetadataGroup::SIMDMatch(uint8_t tag) const -> MatchRange {
return result;
}

inline auto MetadataGroup::SIMDMatchPresent() const -> MatchRange {
MatchRange result;
inline auto MetadataGroup::SIMDMatchPresent() const -> SIMDMatchPresentRange {
SIMDMatchPresentRange result;
#if CARBON_NEON_SIMD_SUPPORT
// Just extract the metadata directly.
uint64_t match_bits = vreinterpret_u64_u8(metadata_vec)[0];
// The matched range is likely to be tested for zero by the caller, and that
// test can often be folded into masking the bits with `MSBs` when we do that
// mask in the scalar domain rather than the SIMD domain. So we do the mask
// here rather than above prior to extracting the match bits.
result = MatchRange(match_bits & MSBs);
// Even though the Neon SIMD range will do its own masking, we have to mask
// here so that `empty` is correct.
result = SIMDMatchPresentRange(match_bits & MSBs);
#elif CARBON_X86_SIMD_SUPPORT
// We arranged the byte vector so that present bytes have the high bit set,
// which this instruction extracts.
result = MatchRange(_mm_movemask_epi8(metadata_vec));
result = SIMDMatchPresentRange(_mm_movemask_epi8(metadata_vec));
#else
static_assert(!UseSIMD && !DebugSIMD, "Unimplemented SIMD operation");
#endif
Expand Down
12 changes: 8 additions & 4 deletions common/raw_hashtable_metadata_group_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ class BenchmarkPortableMetadataGroup : public MetadataGroup {

auto ClearDeleted() -> void { PortableClearDeleted(); }

auto Match(uint8_t present_byte) const -> MatchRange {
auto Match(uint8_t present_byte) const -> PortableMatchRange {
return PortableMatch(present_byte);
}
auto MatchPresent() const -> MatchRange { return PortableMatchPresent(); }
auto MatchPresent() const -> PortableMatchRange {
return PortableMatchPresent();
}

auto MatchEmpty() const -> MatchIndex { return PortableMatchEmpty(); }
auto MatchDeleted() const -> MatchIndex { return PortableMatchDeleted(); }
Expand All @@ -53,10 +55,12 @@ class BenchmarkSIMDMetadataGroup : public MetadataGroup {

auto ClearDeleted() -> void { SIMDClearDeleted(); }

auto Match(uint8_t present_byte) const -> MatchRange {
auto Match(uint8_t present_byte) const -> SIMDMatchRange {
return SIMDMatch(present_byte);
}
auto MatchPresent() const -> MatchRange { return SIMDMatchPresent(); }
auto MatchPresent() const -> SIMDMatchPresentRange {
return SIMDMatchPresent();
}

auto MatchEmpty() const -> MatchIndex { return SIMDMatchEmpty(); }
auto MatchDeleted() const -> MatchIndex { return SIMDMatchDeleted(); }
Expand Down

0 comments on commit 3ae968a

Please sign in to comment.