Skip to content

Commit

Permalink
fix grok1 test.
Browse files (browse the repository at this point in the history)
  • Loading branch information
b4rtaz committed Apr 28, 2024
1 parent e0a8133 commit 4366e25
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 16 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ jobs:
run: |
make main
make funcs-test
make quants-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
run: ./funcs-test
- name: quants-test
run: ./quants-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
Expand Down
13 changes: 4 additions & 9 deletions src/grok1-tasks-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ void compare(float* a, float* b, int n) {
}
}

void nop(TASK_ARGS) {}

int main() {
TransformerSpec spec;
spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
Expand Down Expand Up @@ -66,22 +64,19 @@ int main() {
transformer.pos = 0;

float* x = transformer.x;
for (int i = 0; i < spec.dim; i++) x[i] = randomF32(&state) / 100.0;
for (int i = 0; i < spec.dim; i++) x[i] = (randomF32(&state) / 100.0) / 78.38367176906169f;

TransformerArch arch = buildGrok1Arch(&spec);
arch.inference.tasks[arch.inference.nTasks - 4].handler = &nop;
arch.inference.tasks[arch.inference.nTasks - 3].handler = &nop;
arch.inference.tasks[arch.inference.nTasks - 2].handler = &nop;
arch.inference.tasks[arch.inference.nTasks - 1].handler = &nop;

int nThreads = 1;
int nThreads = 4;
TransformerContext context;
context.transformer = &transformer;
context.currentBlockIndex = 0;
context.socket = NULL;
context.socketPool = &socketPool;

TaskLoop loop(nThreads, arch.inference.nTasks, TASK_N_TYPES, arch.inference.tasks, &context);
int skipLastNTasks = 4;
TaskLoop loop(nThreads, arch.inference.nTasks - skipLastNTasks, TASK_N_TYPES, arch.inference.tasks, &context);
long t0 = timeMs();
loop.run();
long t1 = timeMs();
Expand Down
10 changes: 3 additions & 7 deletions src/llama2-tasks-test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,6 @@ float expectedOutput[4096] = {
1.00493455, 1.00216055, 1.02500832, 1.01412213, 0.997673035, 1.01922369, 1.01705575, 1.01369667,
};

void nop(TASK_ARGS) {}

int main() {
TransformerSpec spec;
spec.headerSize = sizeof(TransformerFileOldHeader) + sizeof(int);
Expand Down Expand Up @@ -571,18 +569,16 @@ int main() {
for (int i = 0; i < spec.dim; i++) x[i] = randomF32(&state) / 120.0;

TransformerArch arch = buildLlama2Arch(&spec);
arch.inference.tasks[arch.inference.nTasks - 3].handler = &nop;
arch.inference.tasks[arch.inference.nTasks - 2].handler = &nop;
arch.inference.tasks[arch.inference.nTasks - 1].handler = &nop;

int nThreads = 1;
int nThreads = 4;
TransformerContext context;
context.transformer = &transformer;
context.currentBlockIndex = 0;
context.socket = NULL;
context.socketPool = &socketPool;

TaskLoop loop(nThreads, arch.inference.nTasks, TASK_N_TYPES, arch.inference.tasks, &context);
int skipLastNTasks = 3;
TaskLoop loop(nThreads, arch.inference.nTasks - skipLastNTasks, TASK_N_TYPES, arch.inference.tasks, &context);
long t0 = timeMs();
loop.run();
long t1 = timeMs();
Expand Down

0 comments on commit 4366e25

Please sign in to comment.