feat: avx2 (#2)
* avx2.
b4rtaz authored Jan 23, 2024
1 parent 7eb77ca commit f2137af
Showing 4 changed files with 90 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,3 +9,4 @@ __pycache__
 quants-test
 transformer-tasks-test
 main
+run.sh
2 changes: 1 addition & 1 deletion Makefile
@@ -1,5 +1,5 @@
 CXX = g++
-CXXFLAGS = -std=c++11 -Werror -O3
+CXXFLAGS = -std=c++11 -Werror -O3 -march=native -mtune=native
 
 utils: src/utils.cpp
 	$(CXX) $(CXXFLAGS) -c src/utils.cpp -o utils.o
20 changes: 13 additions & 7 deletions README.md
@@ -19,9 +19,15 @@ This project was initiated based on the [llama2.c](https://github.com/karpathy/l
 * This project is a proof of concept, it's not optimized for production usage.
 * You can run Distributed Llama only on 1, 2, 4... 2^n devices.
 * The project supports only the inference mode, the chat mode is not supported.
-* Optimized for:
-  * ✅ ARM CPUs
-  * ❌ x86_64 CPUs (Q40xF32 mode works but is slow)
+* Optimized for (weights format × buffer format):
+  * ARM CPUs
+    * ✅ F32 × F32
+    * ❌ F16 × F16
+    * ✅ Q40 × Q80
+  * x86_64 AVX2 CPUs
+    * ❌ F32 × F32
+    * ❌ F16 × F16
+    * ⚠️ Q40 × Q80 (partial optimization)
 
 **Supported models**
 * Llama 2 7B
@@ -134,7 +140,7 @@ sudo nice -n -20 ./main worker --port 9998
 ```
 10. Run root node on the root device:
 ```sh
-sudo nice -n -20 ./main inference --model ../dllama_llama-2-13b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 10.0.0.2:9998
+sudo nice -n -20 ./main inference --model ../dllama_llama-2-7b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 10.0.0.2:9998
 ```
 
 To add more worker nodes, just add more addresses to the `--workers` argument.
@@ -145,9 +151,9 @@ To add more worker nodes, just add more addresses to the `--workers` argument.
 
 [Share your results](https://github.com/b4rtaz/distributed-llama/discussions)!
 
-## 💻 How to Run on Debian x86_64
+## 💻 How to Run on MacOS or Linux
 
-x86_64 CPUs are not optimized yet but still you can observe a significant speedup when you run Distributed Llama on multiple devices.
+You need to have an x86_64 AVX2 CPU or an ARM CPU. Different devices may have different CPUs.
 
 1. Install Git and G++:
 ```sh
@@ -177,7 +183,7 @@ sudo nice -n -20 ./main worker --port 9998
 ```
 7. Run root node on the root device:
 ```sh
-sudo nice -n -20 ./main inference --model ../dllama_llama-2-13b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type f32 --prompt "Hello world" --steps 16 --nthreads 4 --workers 192.168.0.1:9998
+sudo nice -n -20 ./main inference --model ../dllama_llama-2-7b_q40.bin --tokenizer ../tokenizer.bin --weights-float-type q40 --buffer-float-type q80 --prompt "Hello world" --steps 16 --nthreads 4 --workers 192.168.0.1:9998
 ```
 
 ## 💡 License
75 changes: 75 additions & 0 deletions src/funcs.cpp
@@ -7,6 +7,57 @@
 
 #if defined(__ARM_NEON)
 #include <arm_neon.h>
+#elif defined(__AVX2__)
+#include <immintrin.h>
 #endif
 
+#if defined(__AVX2__)
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+static inline __m256i bytes_from_nibbles_32(const uint8_t* rsi) {
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+#endif
+
 void softmax(float* x, const int size) {
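For readers unfamiliar with these intrinsics, here is a minimal scalar sketch of what `bytes_from_nibbles_32` produces once the matmul loop below applies its `-8` offset: each of the 16 input bytes packs two 4-bit quants, the low nibbles land in lanes 0..15, the high nibbles in lanes 16..31, and the values shift from `[0..15]` to `[-8..+7]`. The reference function name and test values here are illustrative, not part of the commit.

```cpp
#include <cassert>
#include <cstdint>

// Scalar sketch of bytes_from_nibbles_32 followed by the -8 offset
// (_mm256_sub_epi8) applied later in matmulQ40vQ80.
static void bytesFromNibbles32Ref(const uint8_t* src, int8_t* dst) {
    for (int i = 0; i < 16; i++) {
        dst[i]      = (int8_t)(src[i] & 0x0F) - 8; // low nibble  -> lanes 0..15
        dst[i + 16] = (int8_t)(src[i] >> 4)   - 8; // high nibble -> lanes 16..31
    }
}

int main() {
    uint8_t packed[16] = {0};
    packed[0] = 0xF0; // low nibble 0 -> -8, high nibble 15 -> +7
    int8_t q[32];
    bytesFromNibbles32Ref(packed, q);
    assert(q[0] == -8 && q[16] == 7);
    return 0;
}
```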
@@ -253,6 +304,30 @@ void matmulQ40vQ80(MatmulThreadInfo* a) {
         }
         a->output[d] = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
     }
+#elif defined(__AVX2__)
+    for (int d = a->ds; d < a->de; d++) {
+        __m256 acc = _mm256_setzero_ps();
+
+        for (int j = 0; j < n; j++) {
+            /* Compute combined scale for the block */
+            const __m256 cd = _mm256_set1_ps( convertF16ToF32(w[d * n + j].d) * convertF16ToF32(input[j].d) );
+
+            __m256i bx = bytes_from_nibbles_32(w[d * n + j].qs);
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m256i off = _mm256_set1_epi8( 8 );
+            bx = _mm256_sub_epi8(bx, off);
+
+            __m256i by = _mm256_loadu_si256((const __m256i *)input[j].qs);
+
+            const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+            /* Multiply q with scale and accumulate */
+            acc = _mm256_fmadd_ps( cd, q, acc );
+        }
+
+        a->output[d] = hsum_float_8(acc);
+    }
 #else
     printf("matmulQ40vQ80 - not implemented\n");
     exit(EXIT_FAILURE);
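As a sanity check on the new path, the per-block math can be written in scalar form: each 32-element block contributes `(weight scale × input scale) × Σ q_w[i]·q_in[i]`, with the Q40 weight quants unpacked and offset by 8 as in the helper above and the Q80 input quants taken as signed bytes. The sketch below is illustrative, not part of the commit; it takes the two scales as already-converted floats (standing in for `convertF16ToF32(w[...].d)` and `convertF16ToF32(input[j].d)`), and only the field usage (`d`, `qs`) and the 32-element block size are taken from the diff.

```cpp
#include <cstdint>

// Scalar sketch of one block of the AVX2 loop in matmulQ40vQ80:
//   contribution = wScale * inScale * sum_{i=0}^{31} q_w[i] * q_in[i]
// wqs holds 32 packed 4-bit weight quants (16 bytes), inqs holds 32 signed bytes.
static float q40q80BlockDotRef(const uint8_t wqs[16], float wScale,
                               const int8_t inqs[32], float inScale) {
    int32_t sum = 0;
    for (int i = 0; i < 16; i++) {
        // Lane order matches bytes_from_nibbles_32: low nibbles pair with
        // inqs[0..15], high nibbles pair with inqs[16..31].
        const int32_t lo = (int32_t)(wqs[i] & 0x0F) - 8;
        const int32_t hi = (int32_t)(wqs[i] >> 4) - 8;
        sum += lo * inqs[i] + hi * inqs[i + 16];
    }
    return wScale * inScale * (float)sum;
}
```

Summing this over the `n` blocks of a row should agree with `hsum_float_8(acc)` for that row up to floating-point rounding, which makes it a convenient reference for a unit test of the AVX2 path.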
