-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for generic matrix multiplication
- Loading branch information
1 parent
4c635a2
commit 43d7182
Showing
3 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#ifndef NN_DOT_PRODUCT_MATRIX_H | ||
#define NN_DOT_PRODUCT_MATRIX_H | ||
|
||
#include "nn_error.h" | ||
#include <stddef.h> | ||
|
||
// MATRIX_MAX_ROWS defines the maximum number of rows in a matrix. | ||
#ifndef MATRIX_MAX_ROWS | ||
#define MATRIX_MAX_ROWS 3 | ||
#endif | ||
|
||
// MATRIX_MAX_COLS defines the maximum number of columns in a matrix. | ||
#ifndef MATRIX_MAX_COLS | ||
#define MATRIX_MAX_COLS 3 | ||
#endif | ||
|
||
// NNDotProductMatrixFunction represents a function that calculates | ||
// the dot product of two matrices. | ||
typedef void (*NNDotProductMatrixFunction)(float result[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]); | ||
|
||
// nn_dot_product_matrix calculates the dot product of two square | ||
// matrices. | ||
// | ||
// The dimensions of the input matrices and the resultant matrix are | ||
// implicitly the same. | ||
void nn_dot_product_matrix(float result[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]); | ||
|
||
#endif // NN_DOT_PRODUCT_MATRIX_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#define NN_DOT_PRODUCT_MATRIX_C | ||
#include "nn_dot_product_matrix.h" | ||
#include "nn_debug.h" | ||
#include <stddef.h> | ||
#include <string.h> | ||
|
||
// nn_dot_product_matrix calculates the dot product of two square | ||
// matrices. | ||
void nn_dot_product_matrix(float result[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]) { | ||
NN_DEBUG_PRINT(5, "function %s called\n", __func__); | ||
|
||
// Initialize the result matrix. | ||
for (int i = 0; i < MATRIX_MAX_ROWS; i++) { | ||
memset(&result[i], 0, MATRIX_MAX_COLS * sizeof(float)); | ||
} | ||
|
||
// Multiply two square matrices. | ||
for (int i = 0; i < MATRIX_MAX_ROWS; i++) { | ||
for (int j = 0; j < MATRIX_MAX_COLS; j++) { | ||
for (int k = 0; k < MATRIX_MAX_ROWS; k++) { | ||
result[i][j] = result[i][j] + a[i][k] * b[k][j]; | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#include "nn_config.h" | ||
#include "nn_debug.h" | ||
#include "nn_dot_product_matrix.h" | ||
#include <assert.h> | ||
#include <math.h> | ||
#include <stdbool.h> | ||
#include <stdio.h> | ||
|
||
// N_TEST_CASES defines the number of test cases. | ||
#define N_TEST_CASES 4 | ||
// DEFAULT_OUTPUT_TOLERANCE defines the default tolerance for comparing output values. | ||
#define DEFAULT_OUTPUT_TOLERANCE 0.0001f | ||
|
||
// TestCase defines a single test case. | ||
typedef struct { | ||
float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]; | ||
float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]; | ||
float bias; | ||
NNDotProductMatrixFunction dot_product_matrix_func; | ||
float output_tolerance; | ||
float expected_output[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]; | ||
} TestCase; | ||
|
||
// run_test_cases runs the test cases. | ||
void run_test_cases(TestCase *test_cases, int n_cases, char *info, NNDotProductMatrixFunction dot_product_matrix_func) { | ||
for (int i = 0; i < n_cases; ++i) { | ||
TestCase tc = test_cases[i]; | ||
|
||
float output[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]; | ||
|
||
NN_DEBUG_PRINT(5, "A:\n"); | ||
for (int i = 0; i < MATRIX_MAX_ROWS; i++) { | ||
for (int j = 0; j < MATRIX_MAX_COLS; j++) { | ||
NN_DEBUG_PRINT(5, " %f", tc.a[i][j]); | ||
} | ||
NN_DEBUG_PRINT(5, "\n"); | ||
} | ||
|
||
NN_DEBUG_PRINT(5, "B:\n"); | ||
for (int i = 0; i < MATRIX_MAX_ROWS; i++) { | ||
for (int j = 0; j < MATRIX_MAX_COLS; j++) { | ||
NN_DEBUG_PRINT(5, " %f", tc.b[i][j]); | ||
} | ||
NN_DEBUG_PRINT(5, "\n"); | ||
} | ||
|
||
dot_product_matrix_func(output, tc.a, tc.b); | ||
|
||
NN_DEBUG_PRINT(5, "C:\n"); | ||
for (int i = 0; i < MATRIX_MAX_ROWS; i++) { | ||
for (int j = 0; j < MATRIX_MAX_COLS; j++) { | ||
NN_DEBUG_PRINT(5, " %f", tc.expected_output[i][j]); | ||
} | ||
NN_DEBUG_PRINT(5, "\n"); | ||
} | ||
|
||
for (int m = 0; m < MATRIX_MAX_ROWS; m++) { | ||
for (int n = 0; n < MATRIX_MAX_COLS; n++) { | ||
assert(isnan(output[m][n]) == false); | ||
assert(fabs(output[m][n] - tc.expected_output[m][n]) < tc.output_tolerance); | ||
} | ||
} | ||
printf("passed: %s case=%d info=%s\n", __func__, i + 1, info); | ||
} | ||
} | ||
|
||
int main() { | ||
// nn_set_debug_level(10); | ||
|
||
TestCase test_cases[N_TEST_CASES] = { | ||
{ | ||
.a = {{0}}, | ||
.b = {{0}}, | ||
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE, | ||
.expected_output = {{0}}, | ||
}, | ||
|
||
{ | ||
.a = {{ 3, 0}, | ||
{-1, 2}, | ||
{ 1, 1}}, | ||
.b = {{4, -1}, | ||
{0, 2}}, | ||
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE, | ||
.expected_output = {{ 12, -3}, | ||
{-4, 5}, | ||
{ 4, 1}}, | ||
}, | ||
|
||
{ | ||
.a = {{ 1, 5, 2}, | ||
{-1, 0, 1}, | ||
{ 3, 2, 4}}, | ||
.b = {{ 6, 1, 3}, | ||
{-1, 1, 2}, | ||
{ 4, 1, 3}}, | ||
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE, | ||
.expected_output = {{ 9, 8, 19}, | ||
{-2, 0, 0}, | ||
{32, 9, 25}}, | ||
}, | ||
|
||
{ | ||
.a = {{1, 2}, | ||
{3, 4}}, | ||
.b = {{5}, | ||
{6}}, | ||
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE, | ||
.expected_output = {{17}, | ||
{39}}, | ||
}, | ||
|
||
}; | ||
run_test_cases(test_cases, N_TEST_CASES, "nn_dot_product_matrix", nn_dot_product_matrix); | ||
return 0; | ||
} |