Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize select_bit #52

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/marisa/grimoire/intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,9 @@
#endif // MARISA_WORD_SIZE == 64
#endif // _MSC_VER

#if defined(__aarch64__)
#define MARISA_AARCH64
#include <arm_neon.h>
#endif

#endif // MARISA_GRIMOIRE_INTRIN_H_
135 changes: 101 additions & 34 deletions lib/marisa/grimoire/vector/bit-vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,33 @@ const UInt64 MASK_0F = 0x0F0F0F0F0F0F0F0FULL;
const UInt64 MASK_33 = 0x3333333333333333ULL;
const UInt64 MASK_55 = 0x5555555555555555ULL;
#endif // !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3)
#if !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)
const UInt64 MASK_80 = 0x8080808080808080ULL;
#endif // !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT)

// Pre-computed lookup table trick from Gog, Simon and Matthias Petri.
// "Optimized succinct data structures for massive data." Software:
// Practice and Experience 44 (2014): 1287 - 1314.
// PREFIX_SUM_OVERFLOW[i] = (0x7F - i) * MASK_01.
const UInt64 PREFIX_SUM_OVERFLOW[64] = {
0x7F * MASK_01, 0x7E * MASK_01, 0x7D * MASK_01, 0x7C * MASK_01,
0x7B * MASK_01, 0x7A * MASK_01, 0x79 * MASK_01, 0x78 * MASK_01,
0x77 * MASK_01, 0x76 * MASK_01, 0x75 * MASK_01, 0x74 * MASK_01,
0x73 * MASK_01, 0x72 * MASK_01, 0x71 * MASK_01, 0x70 * MASK_01,

0x6F * MASK_01, 0x6E * MASK_01, 0x6D * MASK_01, 0x6C * MASK_01,
0x6B * MASK_01, 0x6A * MASK_01, 0x69 * MASK_01, 0x68 * MASK_01,
0x67 * MASK_01, 0x66 * MASK_01, 0x65 * MASK_01, 0x64 * MASK_01,
0x63 * MASK_01, 0x62 * MASK_01, 0x61 * MASK_01, 0x60 * MASK_01,

0x5F * MASK_01, 0x5E * MASK_01, 0x5D * MASK_01, 0x5C * MASK_01,
0x5B * MASK_01, 0x5A * MASK_01, 0x59 * MASK_01, 0x58 * MASK_01,
0x57 * MASK_01, 0x56 * MASK_01, 0x55 * MASK_01, 0x54 * MASK_01,
0x53 * MASK_01, 0x52 * MASK_01, 0x51 * MASK_01, 0x50 * MASK_01,

0x4F * MASK_01, 0x4E * MASK_01, 0x4D * MASK_01, 0x4C * MASK_01,
0x4B * MASK_01, 0x4A * MASK_01, 0x49 * MASK_01, 0x48 * MASK_01,
0x47 * MASK_01, 0x46 * MASK_01, 0x45 * MASK_01, 0x44 * MASK_01,
0x43 * MASK_01, 0x42 * MASK_01, 0x41 * MASK_01, 0x40 * MASK_01
};

std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
UInt64 counts;
Expand All @@ -196,11 +220,16 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {

counts = static_cast<UInt64>(_mm_cvtsi128_si64(
_mm_add_epi8(lower_counts, upper_counts)));
#else // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
#elif defined(MARISA_AARCH64)
// Byte-wise popcount using CNT (plus a lot of conversion noise).
// This actually only requires NEON, not AArch64, but we are already
// in a 64-bit `#ifdef`.
counts = vget_lane_u64(vreinterpret_u64_u8(vcnt_u8(vcreate_u8(unit))), 0);
#else // defined(MARISA_AARCH64)
counts = unit - ((unit >> 1) & MASK_55);
counts = (counts & MASK_33) + ((counts >> 2) & MASK_33);
counts = (counts + (counts >> 4)) & MASK_0F;
#endif // defined(MARISA_X64) && defined(MARISA_USE_SSSE3)
#endif // defined(MARISA_AARCH64)
counts *= MASK_01;
}

Expand All @@ -213,12 +242,17 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
skip = (UInt8)PopCount::count(static_cast<UInt64>(_mm_cvtsi128_si64(x)));
}
#else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
const UInt64 x = (counts | MASK_80) - ((i + 1) * MASK_01);
const UInt64 x = (counts + PREFIX_SUM_OVERFLOW[i]) & MASK_80;
// We masked with `MASK_80`, so the first bit set is the high bit in the
// byte, therefore `num_trailing_zeros == 8 * byte_nr + 7` and the byte
// number is the number of trailing zeros divided by 8. We just shift off
// the low 7 bits, so `CTZ` gives us the `skip` value we want for the
// number of bits of `counts` to shift.
#ifdef _MSC_VER
unsigned long skip;
::_BitScanForward64(&skip, (x & MASK_80) >> 7);
::_BitScanForward64(&skip, x >> 7);
#else // _MSC_VER
const int skip = ::__builtin_ctzll((x & MASK_80) >> 7);
const int skip = ::__builtin_ctzll(x >> 7);
#endif // _MSC_VER
#endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)

Expand All @@ -230,7 +264,8 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) {
}
#else // MARISA_WORD_SIZE == 64
#ifdef MARISA_USE_SSE2
const UInt8 POPCNT_TABLE[256] = {
// Popcount of the byte times eight.
const UInt8 POPCNT_X8_TABLE[256] = {
0, 8, 8, 16, 8, 16, 16, 24, 8, 16, 16, 24, 16, 24, 24, 32,
8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40,
Expand Down Expand Up @@ -315,7 +350,10 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id,
{
__m128i x = _mm_set1_epi8((UInt8)(i + 1));
x = _mm_cmpgt_epi8(x, accumulated_counts);
skip = POPCNT_TABLE[_mm_movemask_epi8(x)];
// Since we use `_mm_movemask_epi8`, to move the top bit of every byte,
// popcount times eight gives the original popcount of `x` before the
// movemask. (`_mm_cmpgt_epi8` sets all bits in a byte to 0 or 1.)
skip = POPCNT_X8_TABLE[_mm_movemask_epi8(x)];
}

UInt8 byte;
Expand All @@ -340,33 +378,62 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id,
return bit_id + SELECT_TABLE[i][byte];
}
#else // MARISA_USE_SSE2
const UInt8 POPCNT_TABLE[256] = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
};

std::size_t select_bit(std::size_t i, std::size_t bit_id,
UInt32 unit_lo, UInt32 unit_hi) {
UInt32 unit = unit_lo;
PopCount count(unit);
if (i >= count.lo32()) {
bit_id += 32;
i -= count.lo32();
unit = unit_hi;
count = PopCount(unit);
}

if (i < count.lo16()) {
if (i >= count.lo8()) {
bit_id += 8;
unit >>= 8;
i -= count.lo8();
}
} else if (i < count.lo24()) {
bit_id += 16;
unit >>= 16;
i -= count.lo16();
} else {
bit_id += 24;
unit >>= 24;
i -= count.lo24();
}
return bit_id + SELECT_TABLE[i][unit & 0xFF];
UInt32 next_byte = unit_lo & 0xFF;
UInt32 byte_popcount = POPCNT_TABLE[next_byte];
// Assuming the desired bit is in a random byte, branches are not
// taken 7/8 of the time, so this is branch-predictor friendly,
// unlike binary search.
if (i < byte_popcount) return bit_id + SELECT_TABLE[i][next_byte];
i -= byte_popcount;
next_byte = (unit_lo >> 8) & 0xFF;
byte_popcount = POPCNT_TABLE[next_byte];
if (i < byte_popcount) return bit_id + 8 + SELECT_TABLE[i][next_byte];
i -= byte_popcount;
next_byte = (unit_lo >> 16) & 0xFF;
byte_popcount = POPCNT_TABLE[next_byte];
if (i < byte_popcount) return bit_id + 16 + SELECT_TABLE[i][next_byte];
i -= byte_popcount;
next_byte = unit_lo >> 24;
byte_popcount = POPCNT_TABLE[next_byte];
if (i < byte_popcount) return bit_id + 24 + SELECT_TABLE[i][next_byte];
i -= byte_popcount;

next_byte = unit_hi & 0xFF;
byte_popcount = POPCNT_TABLE[next_byte];
if (i < byte_popcount) return bit_id + 32 + SELECT_TABLE[i][next_byte];
i -= byte_popcount;
next_byte = (unit_hi >> 8) & 0xFF;
byte_popcount = POPCNT_TABLE[next_byte];
if (i < byte_popcount) return bit_id + 40 + SELECT_TABLE[i][next_byte];
i -= byte_popcount;
next_byte = (unit_hi >> 16) & 0xFF;
byte_popcount = POPCNT_TABLE[next_byte];
if (i < byte_popcount) return bit_id + 48 + SELECT_TABLE[i][next_byte];
i -= byte_popcount;
next_byte = unit_hi >> 24;
// Assume `i < POPCNT_TABLE[next_byte]`.
return bit_id + 56 + SELECT_TABLE[i][next_byte];
}
#endif // MARISA_USE_SSE2

Expand Down
7 changes: 5 additions & 2 deletions lib/marisa/grimoire/vector/pop-count.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@ class PopCount {
#else // _MSC_VER
return static_cast<std::size_t>(_mm_popcnt_u64(x));
#endif // _MSC_VER
#else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
#elif defined(MARISA_AARCH64)
// Byte-wise popcount followed by horizontal add.
return vaddv_u8(vcnt_u8(vcreate_u8(x)));
#else // defined(MARISA_AARCH64)
return PopCount(x).lo64();
#endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT)
#endif // defined(MARISA_AARCH64)
}

private:
Expand Down