Skip to content

Commit

Permalink
Optimize backward brace match with SSE2/AVX2, issue #911.
Browse files Browse the repository at this point in the history
  • Loading branch information
zufuliu committed Nov 26, 2024
1 parent a632725 commit b462383
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 36 deletions.
102 changes: 96 additions & 6 deletions scintilla/src/Document.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -2962,10 +2962,10 @@ Sci::Position Document::BraceMatch(Sci::Position position, Sci::Position /*maxRe
int depth = 1;
if (chBrace <= asciiBackwardSafeChar && IsValidIndex(position + 64*direction, length)) {
#if NP2_USE_AVX2
const SplitView cbView = cb.AllView();
const __m256i mmBrace = mm256_set1_epi8(chBrace);
const __m256i mmSeek = mm256_set1_epi8(chSeek);
if (direction >= 0) {
const SplitView cbView = cb.AllView();
const __m256i mmBrace = mm256_set1_epi8(chBrace);
const __m256i mmSeek = mm256_set1_epi8(chSeek);
const Sci::Position maxPos = length - 2*sizeof(__m256i);
const Sci::Position segmentEndPos = std::min<Sci::Position>(maxPos, cbView.length1 - 1);
do {
Expand Down Expand Up @@ -3010,12 +3010,57 @@ Sci::Position Document::BraceMatch(Sci::Position position, Sci::Position /*maxRe
}
} while (position <= maxPos);
}
else {
constexpr Sci::Position minPos = 2*sizeof(__m256i) - 1;
const Sci::Position segmentEndPos = std::max<Sci::Position>(minPos, cbView.length1);
do {
const Sci::Position segmentLength = cbView.length1;
const bool scanFirst = IsValidIndex(position, segmentLength);
const Sci::Position endPos = scanFirst ? minPos : segmentEndPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m256i *ptr = reinterpret_cast<const __m256i *>(segment + position + 1);
Sci::Position index = position;
uint64_t mask = 0;
do {
const __m256i chunk1 = _mm256_loadu_si256(ptr - 1);
const __m256i chunk2 = _mm256_loadu_si256(ptr - 2);
mask = mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk2, mmBrace), _mm256_cmpeq_epi8(chunk2, mmSeek)));
mask |= static_cast<uint64_t>(mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk1, mmBrace), _mm256_cmpeq_epi8(chunk1, mmSeek)))) << sizeof(__m256i);
if (mask != 0) {
index = position;
position -= 2*sizeof(__m256i);
break;
}
ptr -= 2;
position -= 2*sizeof(__m256i);
} while (position >= endPos);
if (index >= segmentLength && position < segmentLength) {
position = segmentLength - 1;
const uint32_t offset = 63 ^ static_cast<uint32_t>(index - segmentLength);
mask = (mask >> offset) << offset;
}
while (mask) {
const uint64_t leading = np2::clz(mask);
index -= leading;
mask <<= leading;
if (index > GetEndStyled() || StyleIndexAt(index) == styBrace) {
const unsigned char chAtPos = segment[index];
depth += (chAtPos == chBrace) ? 1 : -1;
if (depth == 0) {
return index;
}
}
index--;
mask <<= 1;
}
} while (position >= minPos);
}
// end NP2_USE_AVX2
#elif NP2_USE_SSE2
const SplitView cbView = cb.AllView();
const __m128i mmBrace = _mm_set1_epi8(chBrace);
const __m128i mmSeek = _mm_set1_epi8(chSeek);
if (direction >= 0) {
const SplitView cbView = cb.AllView();
const __m128i mmBrace = _mm_set1_epi8(chBrace);
const __m128i mmSeek = _mm_set1_epi8(chSeek);
const Sci::Position maxPos = length - 2*sizeof(__m128i);
const Sci::Position segmentEndPos = std::min<Sci::Position>(maxPos, cbView.length1 - 1);
do {
Expand Down Expand Up @@ -3060,6 +3105,51 @@ Sci::Position Document::BraceMatch(Sci::Position position, Sci::Position /*maxRe
}
} while (position <= maxPos);
}
else {
constexpr Sci::Position minPos = 2*sizeof(__m128i) - 1;
const Sci::Position segmentEndPos = std::max<Sci::Position>(minPos, cbView.length1);
do {
const Sci::Position segmentLength = cbView.length1;
const bool scanFirst = IsValidIndex(position, segmentLength);
const Sci::Position endPos = scanFirst ? minPos : segmentEndPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m128i *ptr = reinterpret_cast<const __m128i *>(segment + position + 1);
Sci::Position index = position;
uint32_t mask = 0;
do {
const __m128i chunk1 = _mm_loadu_si128(ptr - 1);
const __m128i chunk2 = _mm_loadu_si128(ptr - 2);
mask = mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk2, mmBrace), _mm_cmpeq_epi8(chunk2, mmSeek)));
mask |= mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk1, mmBrace), _mm_cmpeq_epi8(chunk1, mmSeek))) << sizeof(__m128i);
if (mask != 0) {
index = position;
position -= 2*sizeof(__m128i);
break;
}
ptr -= 2;
position -= 2*sizeof(__m128i);
} while (position >= endPos);
if (index >= segmentLength && position < segmentLength) {
position = segmentLength - 1;
const uint32_t offset = 31 ^ static_cast<uint32_t>(index - segmentLength);
mask = (mask >> offset) << offset;
}
while (mask) {
const uint32_t leading = np2::clz(mask);
index -= leading;
mask <<= leading;
if (index > GetEndStyled() || StyleIndexAt(index) == styBrace) {
const unsigned char chAtPos = segment[index];
depth += (chAtPos == chBrace) ? 1 : -1;
if (depth == 0) {
return index;
}
}
index--;
mask <<= 1;
}
} while (position >= minPos);
}
// end NP2_USE_SSE2
#endif
}
Expand Down
120 changes: 90 additions & 30 deletions scintilla/test/BraceMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,11 @@
#include <cstring>
#include <cstdio>
#include "../include/VectorISA.h"
#include "TestUtils.h"

// cl /EHsc /std:c++20 /DNDEBUG /O2 /FAcs /GS- /GR- /Gv /W4 /arch:AVX2 BraceMatchTest.cpp
// clang-cl /EHsc /std:c++20 /DNDEBUG /O2 /FA /GS- /GR- /Gv /W4 -march=x86-64-v3 BraceMatchTest.cpp
// clang-cl /EHsc /std:c++20 /DNDEBUG /O2 /FA /GS- /GR- /Gv /W4 -march=x86-64-v3 -fsanitize=address BraceMatchTest.cpp
// g++ -S -std=gnu++20 -DNDEBUG -O3 -fno-rtti -Wall -Wextra -march=x86-64-v3 BraceMatchTest.cpp
template <typename T>
constexpr T min(T x, T y) noexcept {
return (x < y) ? x : y;
}
constexpr bool IsValidIndex(size_t index, size_t length) noexcept {
return index < length;
}
struct SplitView {
const char *segment1 = nullptr;
size_t length1 = 0;
const char *segment2 = nullptr;
size_t length = 0;

char CharAt(size_t position) const noexcept {
if (position < length1) {
return segment1[position];
}
if (position < length) {
return segment2[position];
}
return '\0';
}
};
constexpr char chBrace = '{';
constexpr char chSeek = '}';
constexpr uint32_t maxLength = 512;
Expand Down Expand Up @@ -133,6 +111,88 @@ void FindAllBraceForward(const SplitView &cbView, ptrdiff_t position, const ptrd

void FindAllBraceBackward(const SplitView &cbView, ptrdiff_t position, uint32_t (&result)[maxLength]) noexcept {
unsigned j = 0;
#if NP2_USE_AVX2
const __m256i mmBrace = _mm256_set1_epi8(chBrace);
const __m256i mmSeek = _mm256_set1_epi8(chSeek);
constexpr ptrdiff_t minPos = 2*sizeof(__m256i) - 1;
const ptrdiff_t segmentLength = cbView.length1;
const ptrdiff_t segmentEndPos = max(minPos, segmentLength);
while (position >= minPos) {
const bool scanFirst = IsValidIndex(position, segmentLength);
const ptrdiff_t endPos = scanFirst ? minPos : segmentEndPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m256i *ptr = reinterpret_cast<const __m256i *>(segment + position + 1);
ptrdiff_t index = position;
uint64_t mask = 0;
do {
const __m256i chunk1 = _mm256_loadu_si256(ptr - 1);
const __m256i chunk2 = _mm256_loadu_si256(ptr - 2);
mask = mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk2, mmBrace), _mm256_cmpeq_epi8(chunk2, mmSeek)));
mask |= static_cast<uint64_t>(mm256_movemask_epi8(_mm256_or_si256(_mm256_cmpeq_epi8(chunk1, mmBrace), _mm256_cmpeq_epi8(chunk1, mmSeek)))) << sizeof(__m256i);
if (mask != 0) {
index = position;
position -= 2*sizeof(__m256i);
break;
}
ptr -= 2;
position -= 2*sizeof(__m256i);
} while (position >= endPos);
if (index >= segmentLength && position < segmentLength) {
position = segmentLength - 1;
const uint32_t offset = 63 ^ static_cast<uint32_t>(index - segmentLength);
mask = (mask >> offset) << offset;
}
while (mask) {
const uint64_t leading = np2::clz(mask);
index -= leading;
mask <<= leading;
result[j++] = static_cast<uint32_t>(index + 1);
index--;
mask <<= 1;
}
}

#elif NP2_USE_SSE2
const __m128i mmBrace = _mm_set1_epi8(chBrace);
const __m128i mmSeek = _mm_set1_epi8(chSeek);
constexpr ptrdiff_t minPos = 2*sizeof(__m128i) - 1;
const ptrdiff_t segmentLength = cbView.length1;
const ptrdiff_t segmentEndPos = max(minPos, segmentLength);
while (position >= minPos) {
const bool scanFirst = IsValidIndex(position, segmentLength);
const ptrdiff_t endPos = scanFirst ? minPos : segmentEndPos;
const char * const segment = scanFirst ? cbView.segment1 : cbView.segment2;
const __m128i *ptr = reinterpret_cast<const __m128i *>(segment + position + 1);
ptrdiff_t index = position;
uint32_t mask = 0;
do {
const __m128i chunk1 = _mm_loadu_si128(ptr - 1);
const __m128i chunk2 = _mm_loadu_si128(ptr - 2);
mask = mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk2, mmBrace), _mm_cmpeq_epi8(chunk2, mmSeek)));
mask |= mm_movemask_epi8(_mm_or_si128(_mm_cmpeq_epi8(chunk1, mmBrace), _mm_cmpeq_epi8(chunk1, mmSeek))) << sizeof(__m128i);
if (mask != 0) {
index = position;
position -= 2*sizeof(__m128i);
break;
}
ptr -= 2;
position -= 2*sizeof(__m128i);
} while (position >= endPos);
if (index >= segmentLength && position < segmentLength) {
position = segmentLength - 1;
const uint32_t offset = 31 ^ static_cast<uint32_t>(index - segmentLength);
mask = (mask >> offset) << offset;
}
while (mask) {
const uint32_t leading = np2::clz(mask);
index -= leading;
mask <<= leading;
result[j++] = static_cast<uint32_t>(index + 1);
index--;
mask <<= 1;
}
}
#endif

while (position >= 0) {
const char chAtPos = cbView.CharAt(position);
Expand Down Expand Up @@ -185,7 +245,8 @@ int __cdecl main(int argc, char *argv[]) {
argc = atoi(argv[1]);
}

srand(static_cast<unsigned int>(reinterpret_cast<uintptr_t>(argv)));
LCGRandom random(static_cast<unsigned int>(reinterpret_cast<uintptr_t>(argv)));
//PCG32Random random(reinterpret_cast<uintptr_t>(argv), reinterpret_cast<uintptr_t>(argv[0]));
constexpr uint32_t padding = 32;
char buffer[padding + maxLength + padding + 1]{};
memset(buffer, chBrace, padding);
Expand Down Expand Up @@ -214,17 +275,16 @@ int __cdecl main(int argc, char *argv[]) {

for (int j = 0; j < argc; j++) {
for (uint32_t i = 0; i < maxLength; i += 4) {
const uint32_t value = rand();
const uint32_t value = random.Next();
buffer[i + padding + 0] = "0{12[3(45)6]78}9"[value & 15];
buffer[i + padding + 1] = "0{12[3(45)6]78}9"[(value >> 4) & 15];
buffer[i + padding + 2] = "0{12[3(45)6]78}9"[(value >> 8) & 15];
buffer[i + padding + 3] = "0{12[3(45)6]78}9"[(value >> 12) & 15];
}

const uint32_t value = rand();
const uint32_t gapPosition = value & (maxLength/2 - 1);
const uint32_t gapLength = (value >> 16) & (maxLength/2 - 1);
uint32_t position = rand() & (maxLength - 1);
const uint32_t gapPosition = random.Next() & (maxLength/2 - 1);
const uint32_t gapLength = random.Next() & (maxLength/2 - 1);
uint32_t position = random.Next() & (maxLength - 1);
const bool hasGap = gapPosition != 0 && gapLength != 0;
const uint32_t length = maxLength - (hasGap ? gapLength : 0);
if (position >= length) {
Expand Down
73 changes: 73 additions & 0 deletions scintilla/test/TestUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include <cstdint>
#include <cstdlib>

// Returns the lesser of two values taken by copy.
// When x is not strictly less than y (including ties), y is returned.
template <typename T>
constexpr T min(T x, T y) noexcept {
	if (x < y) {
		return x;
	}
	return y;
}

// Returns the greater of two values taken by copy.
// When x is not strictly greater than y (including ties), y is returned.
template <typename T>
constexpr T max(T x, T y) noexcept {
	if (x > y) {
		return x;
	}
	return y;
}

// Reports whether index lies inside the half-open range [0, length).
// Both parameters are unsigned, so a negative signed value converted at the
// call site becomes a very large number and is reported as out of range.
constexpr bool IsValidIndex(size_t index, size_t length) noexcept {
	return !(index >= length);
}

// Read-only view over text stored as two segments.
// Positions below length1 are read from segment1; positions from length1 up to
// (but excluding) length are read from segment2. Note that segment2 is indexed
// with the absolute position, not with an offset relative to length1.
struct SplitView {
	const char *segment1 = nullptr;
	size_t length1 = 0;
	const char *segment2 = nullptr;
	size_t length = 0;

	// Returns the character at position, or '\0' when position is outside both segments.
	char CharAt(size_t position) const noexcept {
		const bool inFirst = position < length1;
		if (!inFirst && position >= length) {
			return '\0';
		}
		const char * const segment = inFirst ? segment1 : segment2;
		return segment[position];
	}
};

// Adapter exposing the C runtime generator through a Next() method.
// The constructor seeds via srand(); Next() draws via rand(). The object holds
// no state of its own, which is why Next() can be const.
struct CRTRandom {
	CRTRandom(uint32_t seed) noexcept {
		std::srand(seed);
	}
	uint32_t Next() const noexcept {
		const int value = std::rand();
		return static_cast<uint32_t>(value);
	}
};

// 32-bit linear congruential generator with its state kept in the object.
// Constants are those of the sample rand() in the POSIX specification:
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/rand.html
// (an alternative pair noted for msvc is 214013 / 2531011).
struct LCGRandom {
	uint32_t state;
	LCGRandom(uint32_t seed) noexcept: state{seed} {}
	// Advances the state and returns its upper half masked with RAND_MAX.
	// RAND_MAX is implementation-defined, so the number of returned bits
	// depends on the C library in use.
	uint32_t Next() noexcept {
		constexpr uint32_t multiplier = 1103515245u;
		constexpr uint32_t increment = 12345u;
		state = state*multiplier + increment;
		const uint32_t high = state >> 16;
		return high & RAND_MAX;
	}
};

// https://www.pcg-random.org/download.html
// PCG32 generator: 64-bit LCG state with an XSH RR output permutation.
// The constructor stores seed directly as the state and forces the stream
// increment to be odd (seq | 1).
struct PCG32Random {
	uint64_t state;
	uint64_t inc;
	PCG32Random(uint64_t seed, uint64_t seq) noexcept: state{seed}, inc{seq | 1} {}
	// Returns 32 bits derived from the current state, then steps the state.
	uint32_t Next() noexcept {
		const uint64_t previous = state;
		// Step the underlying LCG.
		state = previous * UINT64_C(6364136223846793005) + inc;
		// Output is computed from the pre-step state: xorshift the high bits,
		// then rotate right by the top five bits of that state.
		const uint64_t mixed = (previous >> 18) ^ previous;
		const uint32_t xorshifted = static_cast<uint32_t>(mixed >> 27);
		const uint32_t rot = static_cast<uint32_t>(previous >> 59);
		const uint32_t lowPart = xorshifted >> rot;
		const uint32_t highPart = xorshifted << ((32u - rot) & 31u);
		return lowPart | highPart;
	}
};

// https://prng.di.unimi.it/

// https://lemire.me/blog/2018/07/02/predicting-the-truncated-xorshift32-random-number-generator/

0 comments on commit b462383

Please sign in to comment.