Skip to content

Commit 73124c1

Browse files
committed
using newer version of case, which uses custom section whenever possible
1 parent 0015607 commit 73124c1

File tree

3 files changed

+76
-31
lines changed

3 files changed

+76
-31
lines changed

lib/cuda/cwc_convnet.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,11 +555,11 @@ static void _cwc_convnet_convolutional_backward_propagate(ccv_convnet_layer_t* l
555555
configuration->bias, layer->net.convolutional.count);
556556
if (b)
557557
{
558-
dim3 threads_per_block(batch, 1);
558+
dim3 threads_per_block(batch, layer->input.matrix.channels / 8);
559559
dim3 num_blocks(layer->input.matrix.rows, layer->input.matrix.cols);
560-
shared_memory_size = sizeof(float) * (batch * 2 + layer->input.matrix.channels * 48);
560+
shared_memory_size = sizeof(float) * (batch * 2 + layer->input.matrix.channels * 16);
561561
_cwc_kern_convolutional_backward_propagate
562-
<1, 3, 48>
562+
<1, 8, 16>
563563
<<<num_blocks, threads_per_block, shared_memory_size, stream>>>
564564
(layer->net.convolutional.strides, layer->net.convolutional.border, batch,
565565
b, layer->input.matrix.rows, layer->input.matrix.cols, layer->input.matrix.channels,

test/case.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,32 @@
4343

4444
typedef void (*case_f)(char*, int*);
4545

46+
#ifdef __ELF__
47+
// in ELF object format, we can simply query custom section rather than scan through the whole binary memory
48+
// to find function pointer. We do this whenever possible because in this way, we don't have access error
49+
// when hooking up with memory checkers such as address sanitizer or valgrind
50+
typedef struct {
51+
case_f func;
52+
char* name;
53+
} case_t;
54+
55+
#define TEST_CASE(desc) \
56+
static void __attribute__((used)) INTERNAL_CATCH_UNIQUE_NAME(__test_case_func__) (char* __case_name__, int* __case_result__); \
57+
static case_t INTERNAL_CATCH_UNIQUE_NAME(__test_case_ctx__) __attribute__((used)) __attribute__((section("case_data"))) = { .func = INTERNAL_CATCH_UNIQUE_NAME(__test_case_func__), .name = desc }; \
58+
static void INTERNAL_CATCH_UNIQUE_NAME(__test_case_func__) (char* __case_name__, int* __case_result__)
59+
#else
4660
typedef struct {
4761
uint64_t sig_head;
48-
case_f driver;
62+
case_f func;
4963
char* name;
5064
uint64_t sig_tail;
5165
} case_t;
5266

5367
#define TEST_CASE(desc) \
54-
static void __attribute__((used)) INTERNAL_CATCH_UNIQUE_NAME(__test_case_driver__) (char* __case_name__, int* __case_result__); \
55-
static case_t INTERNAL_CATCH_UNIQUE_NAME(__test_case_ctx__) __attribute__((used)) = { .driver = INTERNAL_CATCH_UNIQUE_NAME(__test_case_driver__), .sig_head = 0x883253372849284B, .name = desc, .sig_tail = 0x883253372849284B }; \
56-
static void INTERNAL_CATCH_UNIQUE_NAME(__test_case_driver__) (char* __case_name__, int* __case_result__)
68+
static void __attribute__((used)) INTERNAL_CATCH_UNIQUE_NAME(__test_case_func__) (char* __case_name__, int* __case_result__); \
69+
static case_t INTERNAL_CATCH_UNIQUE_NAME(__test_case_ctx__) __attribute__((used)) = { .func = INTERNAL_CATCH_UNIQUE_NAME(__test_case_func__), .sig_head = 0x883253372849284B, .name = desc, .sig_tail = 0x883253372849284B }; \
70+
static void INTERNAL_CATCH_UNIQUE_NAME(__test_case_func__) (char* __case_name__, int* __case_result__)
71+
#endif
5772

5873
#define ABORT_CASE (*__case_result__) = -1; return;
5974

test/case_main.h

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,52 @@
11
#ifndef _GUARD_case_main_h_
22
#define _GUARD_case_main_h_
33

4+
static void case_run(case_t* test_case, int i, int total, int* pass, int* fail)
5+
{
6+
printf("\033[0;34m[%d/%d]\033[0;0m \033[1;33m[RUN]\033[0;0m %s ...", i + 1, total, test_case->name);
7+
fflush(stdout);
8+
int result = 0;
9+
test_case->func(test_case->name, &result);
10+
if (result == 0)
11+
{
12+
(*pass)++;
13+
printf("\r\033[0;34m[%d/%d]\033[0;0m \033[1;32m[PASS]\033[0;0m %s \n", i + 1, total, test_case->name);
14+
} else {
15+
(*fail)++;
16+
printf("\n\033[0;34m[%d/%d]\033[0;0m \033[1;31m[FAIL]\033[0;0m %s\n", i + 1, total, test_case->name);
17+
}
18+
}
19+
20+
static void case_conclude(int pass, int fail)
21+
{
22+
if (fail == 0)
23+
printf("\033[0;32mall test case(s) passed, congratulations!\033[0;0m\n");
24+
else
25+
printf("\033[0;31m%d of %d test case(s) passed\033[0;0m\n", pass, fail + pass);
26+
}
27+
28+
#ifdef __ELF__
29+
// in ELF object format, we can simply query custom section rather than scan through the whole binary memory
30+
// to find function pointer. We do this whenever possible because in this way, we don't have access error
31+
// when hooking up with memory checkers such as address sanitizer or valgrind
32+
extern case_t __start_case_data[];
33+
extern case_t __stop_case_data[];
34+
35+
int main(int argc, char** argv)
36+
{
37+
int total = __stop_case_data - __start_case_data;
38+
int i, pass = 0, fail = 0;
39+
for (i = 0; i < total; i++)
40+
{
41+
case_t* test_case = __start_case_data + i;
42+
case_run(test_case, i, total, &pass, &fail);
43+
}
44+
case_conclude(pass, fail);
45+
return fail;
46+
}
47+
48+
#else
49+
450
#include <stdio.h>
551
#include <unistd.h>
652
#include <stdlib.h>
@@ -309,36 +355,20 @@ int main(int argc, char** argv)
309355
int i;
310356
for (i = 0; i < len; i++)
311357
{
312-
case_t* test_suite = (case_t*)(start_pointer + i);
313-
if (test_suite->sig_head == the_sig && test_suite->sig_tail == the_sig)
358+
case_t* test_case = (case_t*)(start_pointer + i);
359+
if (test_case->sig_head == the_sig && test_case->sig_tail == the_sig)
314360
total++;
315361
}
316-
int j = 1, pass = 0, fail = 0;
362+
int j = 0, pass = 0, fail = 0;
317363
for (i = 0; i < len; i++)
318364
{
319-
case_t* test_suite = (case_t*)(start_pointer + i);
320-
if (test_suite->sig_head == the_sig && test_suite->sig_tail == the_sig)
321-
{
322-
printf("\033[0;34m[%d/%d]\033[0;0m \033[1;33m[RUN]\033[0;0m %s ...", j, total, test_suite->name);
323-
fflush(stdout);
324-
int result = 0;
325-
test_suite->driver(test_suite->name, &result);
326-
if (result == 0)
327-
{
328-
pass++;
329-
printf("\r\033[0;34m[%d/%d]\033[0;0m \033[1;32m[PASS]\033[0;0m %s \n", j, total, test_suite->name);
330-
} else {
331-
fail++;
332-
printf("\n\033[0;34m[%d/%d]\033[0;0m \033[1;31m[FAIL]\033[0;0m %s\n", j, total, test_suite->name);
333-
}
334-
j++;
335-
}
365+
case_t* test_case = (case_t*)(start_pointer + i);
366+
if (test_case->sig_head == the_sig && test_case->sig_tail == the_sig)
367+
case_run(test_case, j++, total, &pass, &fail);
336368
}
337-
if (fail == 0)
338-
printf("\033[0;32mall test case(s) passed, congratulations!\033[0;0m\n");
339-
else
340-
printf("\033[0;31m%d of %d test case(s) passed\033[0;0m\n", pass, fail + pass);
369+
case_conclude(pass, fail);
341370
return fail;
342371
}
343372

344373
#endif
374+
#endif

0 commit comments

Comments
 (0)