
Commit

funcs-test. (#31)
b4rtaz authored Apr 28, 2024
1 parent b095f78 commit 45c3e5b
Showing 12 changed files with 197 additions and 40 deletions.
18 changes: 17 additions & 1 deletion .github/workflows/main.yml
@@ -12,7 +12,11 @@ jobs:
runs-on: ${{matrix.os}}
strategy:
matrix:
os: [ubuntu-latest]
os:
- ubuntu-latest
platforms:
- linux/arm64
- linux/amd64
steps:
- name: Checkout Repo
uses: actions/checkout@v3
@@ -25,3 +29,15 @@ jobs:
id: build
run: |
make main
make funcs-test
make quants-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
run: ./funcs-test
- name: quants-test
run: ./quants-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
run: ./grok1-tasks-test
4 changes: 1 addition & 3 deletions .gitignore
@@ -6,8 +6,6 @@
*.bin
__pycache__

quants-test
llama2-tasks-test
grok1-tasks-test
*-test
main
run.sh
4 changes: 4 additions & 0 deletions Makefile
@@ -7,6 +7,8 @@ quants: src/quants.cpp
$(CXX) $(CXXFLAGS) -c src/quants.cpp -o quants.o
funcs: src/funcs.cpp
$(CXX) $(CXXFLAGS) -c src/funcs.cpp -o funcs.o
funcs-test: src/funcs-test.cpp funcs
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o
socket: src/socket.cpp
$(CXX) $(CXXFLAGS) -c src/socket.cpp -o socket.o
transformer: src/utils.cpp
@@ -24,6 +26,8 @@ tokenizer: src/tokenizer.cpp

main: src/main.cpp utils quants funcs socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer
$(CXX) $(CXXFLAGS) src/main.cpp -o main utils.o quants.o funcs.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o -lpthread
funcs-test: src/funcs-test.cpp funcs utils quants
$(CXX) $(CXXFLAGS) src/funcs-test.cpp -o funcs-test funcs.o utils.o quants.o
quants-test: src/quants.cpp utils quants
$(CXX) $(CXXFLAGS) src/quants-test.cpp -o quants-test utils.o quants.o -lpthread
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs socket transformer tasks llama2-tasks tokenizer
4 changes: 2 additions & 2 deletions converter/convert-llama.py
@@ -4,7 +4,7 @@
import torch
import math
import numpy as np
from writer import writeTensor, writeHeader
from writer import writeTensor, writeHeader, isFloatTypeSupported
from pathlib import Path

LAYER_CHUNK_SIZE = 48
@@ -106,7 +106,7 @@ def usage():
modelPath = sys.argv[1]
targetFloatType = sys.argv[2]

if (not modelPath or not targetFloatType in ['f16', 'f32', 'q40']):
if (not modelPath or not isFloatTypeSupported(targetFloatType)):
usage()

modelName = modelPath.split('/')[-1]
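The change above delegates the float-type check to writer.isFloatTypeSupported, so convert-llama.py automatically accepts the new q80 type added in this commit. A minimal illustration, not part of the commit (the q71 value is a made-up unsupported type):

from writer import isFloatTypeSupported

for t in ['f16', 'f32', 'q40', 'q80', 'q71']:
    print(t, isFloatTypeSupported(t))  # True for the four supported types, False for q71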
1 change: 1 addition & 0 deletions converter/convert-tokenizer-sentencepiece.py
@@ -67,6 +67,7 @@ def export(self):
print(f"{bytes.decode('utf-8')} {score}")
f.write(struct.pack("fI", score, len(bytes)))
f.write(bytes)
print(f'Created {outputPath}')

if __name__ == "__main__":
if (len(sys.argv) < 2):
50 changes: 41 additions & 9 deletions converter/writer.py
@@ -4,10 +4,9 @@
import numpy as np

def isFloatTypeSupported(type):
return type in ['f16', 'f32', 'q40']
return type in ['f16', 'f32', 'q40', 'q80']

def writeQuantizedQ40Tensor(file, x):
t0 = time.time()
x = x.to(torch.float32).numpy().astype(np.float32)
blockSize = 32
blockHalfSize = blockSize // 2
@@ -35,28 +34,61 @@ def writeQuantizedQ40Tensor(file, x):
buffer = struct.pack(f'e{blockHalfSize}B', delta16, *block)
file.write(buffer)
nBytes += len(buffer)
t1 = time.time()
print(f'Quantized tensor to {nBytes} bytes in {t1 - t0:.2f} s')
return nBytes

def writeQuantizedQ80Tensor(file, x):
x = x.to(torch.float32).numpy().astype(np.float32)
blockSize = 32
assert(x.shape[0] % blockSize == 0)
groups = x.reshape(-1, blockSize)
gmax = np.max(groups, axis=1)
gmin = np.min(groups, axis=1)
gabsMax = np.where(-gmin > gmax, -gmin, gmax)
deltas = gabsMax / ((1 << 7) - 1)
deltas16 = deltas.astype(np.float16)
ids = np.where(deltas != 0, 1.0 / deltas, 0)
groups = groups * ids[:, np.newaxis]
groups8 = np.round(groups).astype(np.int8)

nBytes = 0
for groupIndex in range(0, len(groups)):
buffer = struct.pack(f'e{blockSize}b', deltas16[groupIndex], *groups8[groupIndex])
file.write(buffer)
nBytes += len(buffer)
return nBytes

def writeF32Tensor(file, d):
chunkSize = 10000
nBytes = 0
for i in range(0, len(d), chunkSize):
chunk = d[i:i+chunkSize].to(torch.float32).numpy().astype(np.float32)
b = struct.pack(f'{len(chunk)}f', *chunk)
nBytes += len(b)
file.write(b)
return nBytes

def writeF16Tensor(file, d):
d = d.to(torch.float16).numpy().astype(np.float16)
b = struct.pack(f'{len(d)}e', *d)
file.write(b)
return len(b)

def writeTensor(file, tensor, floatType):
d = tensor.detach().cpu().view(-1)
t0 = time.time()
nBytes = 0
if (floatType == 'f16'):
d = d.to(torch.float16).numpy().astype(np.float16)
b = struct.pack(f'{len(d)}e', *d)
file.write(b)
nBytes = writeF16Tensor(file, d)
elif (floatType == 'f32'):
writeF32Tensor(file, d)
nBytes = writeF32Tensor(file, d)
elif (floatType == 'q40'):
writeQuantizedQ40Tensor(file, d)
nBytes = writeQuantizedQ40Tensor(file, d)
elif (floatType == 'q80'):
nBytes = writeQuantizedQ80Tensor(file, d)
else:
raise Exception('Unknown float type')
t1 = time.time()
print(f'Saved {floatType} tensor in {t1 - t0:.2f}s, {nBytes} bytes')

def writeHeader(file, params):
headerKeys = {
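To make the new q80 layout concrete: each block emitted by writeQuantizedQ80Tensor is a float16 delta followed by 32 signed int8 quants (34 bytes per block), and dequantization is simply delta * q. The round trip below is an illustrative sketch only, not code from this commit; readQ80Block and the scratch file name are hypothetical, while writeTensor and the 'q80' type come from writer.py above.

import struct
import numpy as np
import torch
from writer import writeTensor

def readQ80Block(f, blockSize=32):
    # One q80 block: 2-byte float16 delta followed by 32 int8 quants.
    delta, *quants = struct.unpack(f'e{blockSize}b', f.read(2 + blockSize))
    return delta * np.array(quants, dtype=np.float32)  # approximate original values

original = torch.randn(64)                  # length must be a multiple of the block size 32
with open('tensor-q80.bin', 'wb') as f:     # hypothetical scratch file
    writeTensor(f, original, 'q80')
with open('tensor-q80.bin', 'rb') as f:
    restored = np.concatenate([readQ80Block(f) for _ in range(2)])
print(np.max(np.abs(restored - original.numpy())))  # small quantization error, roughly absmax / 254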
67 changes: 67 additions & 0 deletions src/funcs-test.cpp
@@ -0,0 +1,67 @@
#include "funcs.hpp"
#include "utils.hpp"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

void testRms() {
float x[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f};
float r = rms(x, 8);
if (fabs(r - 1.980256) > 0.001) {
printf("❌ rms() = %f\n", r);
exit(EXIT_FAILURE);
}
printf("✅ rms\n");
}

void testMatmulQ80() {
const int n = 512;
const int d = 256;
unsigned long long state = 88888888L;
float x[n];
float w[n * d];
float y[d];
float yQ0[d];
float yQ1[d];
int i;
for (i = 0; i < n; i++) x[i] = randomF32(&state) / 127.0f;
for (i = 0; i < n * d; i++) w[i] = randomF32(&state) / 127.0f;

char* xQ = new char[getBatchBytes(Q80, n, 1)];
char* wQ = new char[getBatchBytes(Q80, n, d)];
quantizeQ80Row(x, (BlockQ80*)xQ, n, 1, 0);
quantizeQ80Row(w, (BlockQ80*)wQ, n * d, 1, 0);

matmul(F32, F32, y, x, w, n, d, 1, 0);
matmul(Q80, F32, yQ0, x, wQ, n, d, 1, 0);
matmul(Q80, Q80, yQ1, xQ, wQ, n, d, 1, 0);

for (i = 0; i < d; i++) {
float diff = fabs(y[i] - yQ0[i]);
if (diff > 0.001) {
printf("❌ matmulQ80() ix=%d %f != %f diff=%f\n", i, y[i], yQ0[i], diff);
exit(EXIT_FAILURE);
}
}
printf("✅ matmulQ80\n");

for (i = 0; i < d; i++) {
float diff = fabs(y[i] - yQ1[i]);
if (diff > 0.001) {
printf("❌ matmulQ80vQ80() ix=%d %f != %f diff=%f\n", i, y[i], yQ1[i], diff);
exit(EXIT_FAILURE);
}
}
printf("✅ matmulQ80vQ80\n");

delete[] xQ;
delete[] wQ;
}

int main() {
initQuants();

testRms();
testMatmulQ80();
return EXIT_SUCCESS;
}
55 changes: 51 additions & 4 deletions src/funcs.cpp
@@ -233,6 +233,25 @@ void matmulQ40(MatmulThreadInfo* a) {
#endif
}

void matmulQ80(MatmulThreadInfo* a) {
float* input = (float*)a->input;
BlockQ80* weights = (BlockQ80*)a->weights;
assert(a->n % QK80 == 0);
int nb = a->n / QK80;

for (int d = a->ds; d < a->de; d++) {
float sum = 0.0;
for (int i = 0; i < nb; i++) {
float s = 0.0;
for (int j = 0; j < QK80; j++) {
s += input[i * QK80 + j] * (float)weights[d * nb + i].qs[j];
}
sum += s * convertF16ToF32(weights[d * nb + i].d);
}
a->output[d] = sum;
}
}

void matmulQ40vQ80(MatmulThreadInfo* a) {
const BlockQ40* w = (BlockQ40*)a->weights;
const BlockQ80* input = (BlockQ80*)a->input;
@@ -334,6 +353,25 @@ void matmulQ40vQ80(MatmulThreadInfo* a) {
#endif
}

void matmulQ80vQ80(MatmulThreadInfo* a) {
BlockQ80* input = (BlockQ80*)a->input;
BlockQ80* weights = (BlockQ80*)a->weights;
assert(a->n % QK80 == 0);
int nb = a->n / QK80;

for (int d = a->ds; d < a->de; d++) {
float sum = 0.0;
for (int i = 0; i < nb; i++) {
int s = 0;
for (int j = 0; j < QK80; j++) {
s += input[i].qs[j] * (int)weights[d * nb + i].qs[j];
}
sum += s * (convertF16ToF32(input[i].d) * convertF16ToF32(weights[d * nb + i].d));
}
a->output[d] = sum;
}
}

// weights input output
// ___________ ___ ___
// | | | | | |
@@ -363,10 +401,19 @@ void matmul(FloatType weightsFloatType, FloatType inputFloatType, float* output,
matmulQ40(&s);
return;
}
}
if (inputFloatType == Q80 && weightsFloatType == Q40) {
matmulQ40vQ80(&s);
return;
if (weightsFloatType == Q80) {
matmulQ80(&s);
return;
}
} else if (inputFloatType == Q80) {
if (weightsFloatType == Q40) {
matmulQ40vQ80(&s);
return;
}
if (weightsFloatType == Q80) {
matmulQ80vQ80(&s);
return;
}
}

printf("Unsupported float types: %d/%d\n", weightsFloatType, inputFloatType);
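The two new kernels above, matmulQ80 and matmulQ80vQ80, follow the same pattern: for each output row, accumulate per-block dot products and scale each block by its delta(s). For reference, the matmulQ80vQ80 accumulation corresponds to the following numpy sketch (illustrative only, not code from this commit; the function name and array shapes are assumptions made for the sketch):

import numpy as np

def matmulQ80vQ80Ref(inQs, inDs, wQs, wDs):
    # inQs: (nb, 32) int8 input quants,     inDs: (nb,) float32 input deltas
    # wQs:  (d, nb, 32) int8 weight quants, wDs: (d, nb) float32 weight deltas
    # Integer dot product per (row, block), then scale by both deltas and sum over blocks.
    dots = np.einsum('rbj,bj->rb', wQs.astype(np.int32), inQs.astype(np.int32))
    return (dots * wDs * inDs[np.newaxis, :]).sum(axis=1).astype(np.float32)

matmulQ80 is the mixed case: the weights stay quantized while the input is plain float32, so only the weight delta appears in the scaling.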
9 changes: 6 additions & 3 deletions src/grok1-tasks-test.cpp
@@ -16,7 +16,7 @@ float expectedOutput_5012_5016[] = { 0.0126675405, 0.0169415697, 0.0183475353, 0

void compare(float* a, float* b, int n) {
for (int i = 0; i < n; i++) {
if (fabs(a[i] - b[i]) > 0.00001) { // Optimization may cause some differences
if (std::isnan(a[i]) || fabs(a[i] - b[i]) > 0.00001) { // Optimization may cause some differences
printf("%.9g != %.9g\n", a[i], b[i]); i++;
printf("%.9g != %.9g\n", a[i], b[i]); i++;
printf("%.9g != %.9g\n", a[i], b[i]); i++;
@@ -44,6 +44,8 @@ int main() {
spec.weightsFloatType = F32;
spec.bufferFloatType = F32;
spec.nSlices = 1;
spec.hiddenAct = GELU;
spec.ropeTheta = 10000.0f;

size_t beforeBlockBytes = spec.dim * spec.vocabSize * sizeof(float);
size_t blockBytes = 956596224;
@@ -62,7 +64,7 @@ int main() {
transformer.pos = 0;

float* x = transformer.x;
for (int i = 0; i < spec.dim; i++) x[i] = randomF32(&state) / 100.0;
for (int i = 0; i < spec.dim; i++) x[i] = (randomF32(&state) / 100.0) / 78.38367176906169f;

TransformerArch arch = buildGrok1Arch(&spec);

@@ -73,7 +75,8 @@ int main() {
context.socket = NULL;
context.socketPool = &socketPool;

TaskLoop loop(nThreads, arch.inference.nTasks, TASK_N_TYPES, arch.inference.tasks, &context);
int skipLastNTasks = 4;
TaskLoop loop(nThreads, arch.inference.nTasks - skipLastNTasks, TASK_N_TYPES, arch.inference.tasks, &context);
long t0 = timeMs();
loop.run();
long t1 = timeMs();
12 changes: 4 additions & 8 deletions src/llama2-tasks-test.cpp
@@ -524,8 +524,6 @@ float expectedOutput[4096] = {
1.00493455, 1.00216055, 1.02500832, 1.01412213, 0.997673035, 1.01922369, 1.01705575, 1.01369667,
};

void nop(TASK_ARGS) {}

int main() {
TransformerSpec spec;
spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
@@ -571,18 +569,16 @@ int main() {
for (int i = 0; i < spec.dim; i++) x[i] = randomF32(&state) / 120.0;

TransformerArch arch = buildLlama2Arch(&spec);
arch.inference.tasks[arch.inference.nTasks - 3].handler = &nop;
arch.inference.tasks[arch.inference.nTasks - 2].handler = &nop;
arch.inference.tasks[arch.inference.nTasks - 1].handler = &nop;

int nThreads = 1;
int nThreads = 4;
TransformerContext context;
context.transformer = &transformer;
context.currentBlockIndex = 0;
context.socket = NULL;
context.socketPool = &socketPool;

TaskLoop loop(nThreads, arch.inference.nTasks, TASK_N_TYPES, arch.inference.tasks, &context);
int skipLastNTasks = 3;
TaskLoop loop(nThreads, arch.inference.nTasks - skipLastNTasks, TASK_N_TYPES, arch.inference.tasks, &context);
long t0 = timeMs();
loop.run();
long t1 = timeMs();
Expand All @@ -591,7 +587,7 @@ int main() {

int ix = -1;
for (int i = 0; i < spec.dim; i++) {
if (fabs(x[i] - expectedOutput[i]) > 0.00001) { // Optimization may cause some differences
if (std::isnan(x[i]) || fabs(x[i] - expectedOutput[i]) > 0.00001) { // Optimization may cause some differences
ix = i;
break;
}
