Skip to content

Commit

Permalink
Add support for generic matrix multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
evanmcclure committed Apr 15, 2024
1 parent 4c635a2 commit 43d7182
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 0 deletions.
28 changes: 28 additions & 0 deletions include/nn_dot_product_matrix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef NN_DOT_PRODUCT_MATRIX_H
#define NN_DOT_PRODUCT_MATRIX_H

#include "nn_error.h"
#include <stddef.h>

// MATRIX_MAX_ROWS defines the maximum number of rows in a matrix.
#ifndef MATRIX_MAX_ROWS
#define MATRIX_MAX_ROWS 3
#endif

// MATRIX_MAX_COLS defines the maximum number of columns in a matrix.
#ifndef MATRIX_MAX_COLS
#define MATRIX_MAX_COLS 3
#endif

// NNDotProductMatrixFunction represents a function that calculates
// the dot product of two matrices.
typedef void (*NNDotProductMatrixFunction)(float result[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]);

// nn_dot_product_matrix calculates the dot product of two square
// matrices.
//
// The dimensions of the input matrices and the resultant matrix are
// implicitly the same.
void nn_dot_product_matrix(float result[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]);

#endif // NN_DOT_PRODUCT_MATRIX_H
25 changes: 25 additions & 0 deletions src/nn_dot_product matrix.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#define NN_DOT_PRODUCT_MATRIX_C
#include "nn_dot_product_matrix.h"
#include "nn_debug.h"
#include <stddef.h>
#include <string.h>

// nn_dot_product_matrix calculates the dot product of two square
// matrices.
void nn_dot_product_matrix(float result[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS], const float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS]) {
NN_DEBUG_PRINT(5, "function %s called\n", __func__);

// Initialize the result matrix.
for (int i = 0; i < MATRIX_MAX_ROWS; i++) {
memset(&result[i], 0, MATRIX_MAX_COLS * sizeof(float));
}

// Multiply two square matrices.
for (int i = 0; i < MATRIX_MAX_ROWS; i++) {
for (int j = 0; j < MATRIX_MAX_COLS; j++) {
for (int k = 0; k < MATRIX_MAX_ROWS; k++) {
result[i][j] = result[i][j] + a[i][k] * b[k][j];
}
}
}
}
116 changes: 116 additions & 0 deletions tests/arch/generic/dot_product_matrix/main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#include "nn_config.h"
#include "nn_debug.h"
#include "nn_dot_product_matrix.h"
#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>

// N_TEST_CASES defines the number of test cases.
#define N_TEST_CASES 4
// DEFAULT_OUTPUT_TOLERANCE defines the default tolerance for comparing output values.
#define DEFAULT_OUTPUT_TOLERANCE 0.0001f

// TestCase defines a single test case.
typedef struct {
float a[MATRIX_MAX_ROWS][MATRIX_MAX_COLS];
float b[MATRIX_MAX_ROWS][MATRIX_MAX_COLS];
float bias;
NNDotProductMatrixFunction dot_product_matrix_func;
float output_tolerance;
float expected_output[MATRIX_MAX_ROWS][MATRIX_MAX_COLS];
} TestCase;

// run_test_cases runs the test cases.
void run_test_cases(TestCase *test_cases, int n_cases, char *info, NNDotProductMatrixFunction dot_product_matrix_func) {
for (int i = 0; i < n_cases; ++i) {
TestCase tc = test_cases[i];

float output[MATRIX_MAX_ROWS][MATRIX_MAX_COLS];

NN_DEBUG_PRINT(5, "A:\n");
for (int i = 0; i < MATRIX_MAX_ROWS; i++) {
for (int j = 0; j < MATRIX_MAX_COLS; j++) {
NN_DEBUG_PRINT(5, " %f", tc.a[i][j]);
}
NN_DEBUG_PRINT(5, "\n");
}

NN_DEBUG_PRINT(5, "B:\n");
for (int i = 0; i < MATRIX_MAX_ROWS; i++) {
for (int j = 0; j < MATRIX_MAX_COLS; j++) {
NN_DEBUG_PRINT(5, " %f", tc.b[i][j]);
}
NN_DEBUG_PRINT(5, "\n");
}

dot_product_matrix_func(output, tc.a, tc.b);

NN_DEBUG_PRINT(5, "C:\n");
for (int i = 0; i < MATRIX_MAX_ROWS; i++) {
for (int j = 0; j < MATRIX_MAX_COLS; j++) {
NN_DEBUG_PRINT(5, " %f", tc.expected_output[i][j]);
}
NN_DEBUG_PRINT(5, "\n");
}

for (int m = 0; m < MATRIX_MAX_ROWS; m++) {
for (int n = 0; n < MATRIX_MAX_COLS; n++) {
assert(isnan(output[m][n]) == false);
assert(fabs(output[m][n] - tc.expected_output[m][n]) < tc.output_tolerance);
}
}
printf("passed: %s case=%d info=%s\n", __func__, i + 1, info);
}
}

int main() {
// nn_set_debug_level(10);

TestCase test_cases[N_TEST_CASES] = {
{
.a = {{0}},
.b = {{0}},
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE,
.expected_output = {{0}},
},

{
.a = {{ 3, 0},
{-1, 2},
{ 1, 1}},
.b = {{4, -1},
{0, 2}},
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE,
.expected_output = {{ 12, -3},
{-4, 5},
{ 4, 1}},
},

{
.a = {{ 1, 5, 2},
{-1, 0, 1},
{ 3, 2, 4}},
.b = {{ 6, 1, 3},
{-1, 1, 2},
{ 4, 1, 3}},
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE,
.expected_output = {{ 9, 8, 19},
{-2, 0, 0},
{32, 9, 25}},
},

{
.a = {{1, 2},
{3, 4}},
.b = {{5},
{6}},
.output_tolerance = DEFAULT_OUTPUT_TOLERANCE,
.expected_output = {{17},
{39}},
},

};
run_test_cases(test_cases, N_TEST_CASES, "nn_dot_product_matrix", nn_dot_product_matrix);
return 0;
}

0 comments on commit 43d7182

Please sign in to comment.