-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add nn_accuracy * Add nn_argmax * Move _POSIX_C_SOURCE
- Loading branch information
Showing
21 changed files
with
749 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef NN_ACCURACY_H | ||
#define NN_ACCURACY_H | ||
|
||
#include "nn_tensor.h" | ||
#include "nn_error.h" | ||
|
||
/** | ||
* @brief Returns the accuracy between the predictions and actual tensors. | ||
* | ||
* @param predictions The predictions (output of the network) tensor. | ||
* @param actual The actual (ground truth) tensor (one-hot encoded or categorical). | ||
* @param error The error instance to set if an error occurs. | ||
* | ||
* @return The accuracy. | ||
*/ | ||
NNTensorUnit nn_accuracy(const NNTensor *predictions, const NNTensor *actual, NNError *error); | ||
|
||
#endif // NN_ACCURACY_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#ifndef NN_ARGMAX_H | ||
#define NN_ARGMAX_H | ||
|
||
#include "nn_tensor.h" | ||
#include "nn_error.h" | ||
#include <stdbool.h> | ||
#include <stddef.h> | ||
|
||
/** | ||
* @brief Finds the index of the maximum value in a 1-dimensional tensor. | ||
* | ||
* @param input The input tensor. | ||
* @param error The error instance to set if an error occurs. | ||
* | ||
* @return The index of the maximum value. | ||
*/ | ||
size_t nn_argmax(const NNTensor *input, NNError *error); | ||
|
||
/** | ||
* @brief Finds the indices of the maximum values in each row of a 2-dimensional tensor batch. | ||
* | ||
* @param input The input tensor. | ||
* @param output The output tensor to store the indices. | ||
* @param error The error instance to set if an error occurs. | ||
* | ||
* @return True or false. | ||
*/ | ||
bool nn_argmax_tensor_batch(const NNTensor *input, NNTensor *output, NNError *error); | ||
|
||
#endif // NN_ARGMAX_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# This script generates test cases for nn_accuracy function. | ||
|
||
import numpy as np | ||
|
||
# Returns the softmax activation function result. | ||
def nn_act_func_softmax(x): | ||
exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) | ||
return exp_x / np.sum(exp_x, axis=-1, keepdims=True) | ||
|
||
# Generates a one-hot encoded vector. | ||
def one_hot_encode(labels, num_classes): | ||
return np.eye(num_classes)[labels] | ||
|
||
# Generates a test case. | ||
def generate_test_case(batch_size, num_classes, one_hot): | ||
# Init vars | ||
predictions = np.random.uniform(0, 1, (batch_size, num_classes)) | ||
predictions = nn_act_func_softmax(predictions) # ensure the predictions are probabilities | ||
actual = np.random.randint(0, num_classes, batch_size) | ||
if one_hot: | ||
actual = one_hot_encode(actual, num_classes) | ||
accuracy = np.mean(np.argmax(predictions, axis=1) == np.argmax(actual, axis=1)) if one_hot else np.mean(np.argmax(predictions, axis=1) == actual) | ||
|
||
# Generate the partial C code | ||
predictions_c = ", ".join(map(str, predictions.flatten())) | ||
actual_c = ", ".join(map(str, actual.flatten())) | ||
return f""" | ||
{{ | ||
.predictions = nn_tensor_init_NNTensor(2, (const size_t[]){{{batch_size}, {num_classes}}}, false, (const NNTensorUnit[]){{{predictions_c}}}, NULL), | ||
.actual = nn_tensor_init_NNTensor({2 if one_hot else 1}, (const size_t[]){{{batch_size}{', ' + str(num_classes) if one_hot else ''}}}, false, (const NNTensorUnit[]){{{actual_c}}}, NULL), | ||
.expected_value = {accuracy}, | ||
.expected_tolerance = default_expected_tolerance, | ||
}}""" | ||
|
||
# Generates test cases. | ||
def generate_test_cases(batch_sizes, num_classes_list, one_hot_encodings): | ||
test_cases = [] | ||
for batch_size in batch_sizes: | ||
for num_classes in num_classes_list: | ||
for one_hot in one_hot_encodings: | ||
test_case = generate_test_case(batch_size, num_classes, one_hot) | ||
test_cases.append(test_case) | ||
|
||
return test_cases | ||
|
||
# Generate test cases | ||
np.random.seed(2024) | ||
batch_sizes = [1, 2, 3, 4, 5] | ||
num_classes_list = [2, 3, 4, 5, 6] | ||
one_hot_encodings = [True, False] | ||
test_cases = generate_test_cases(batch_sizes, num_classes_list, one_hot_encodings) | ||
|
||
print(f"TestCase test_cases[] = {{{','.join(test_cases)},\n}};") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#include "nn_accuracy.h" | ||
#include "nn_argmax.h" | ||
#include "nn_debug.h" | ||
#include "nn_error.h" | ||
#include "nn_tensor.h" | ||
|
||
NNTensorUnit nn_accuracy(const NNTensor *predictions, const NNTensor *actual, NNError *error) { | ||
NN_DEBUG_PRINT(5, "function %s called with predictions.dims=%zu actual.dims=%zu\n", __func__, predictions->dims, actual->dims); | ||
|
||
if (!(predictions->flags & NN_TENSOR_FLAG_INIT) || !(actual->flags & NN_TENSOR_FLAG_INIT)) { | ||
nn_error_set(error, NN_ERROR_INVALID_ARGUMENT, "tensor predictions or actual is not initialized"); | ||
return 0; | ||
} else if (predictions->dims != 2 || actual->dims < 1 || actual->dims > 2 || predictions->sizes[0] != actual->sizes[0]) { | ||
// Only one-hot encoded or categorical tensors with the same batch size are allowed | ||
nn_error_set(error, NN_ERROR_INVALID_ARGUMENT, "only 2-dimensional predictions tensor and 1 or 2-dimensional actual tensor with the same batch size are allowed"); | ||
return 0; | ||
} | ||
|
||
// Determine the batch size, the number of classes and if the actual tensor is one-hot encoded | ||
size_t batch_size = predictions->sizes[0]; | ||
size_t num_classes = predictions->sizes[1]; | ||
bool one_hot = (actual->dims == 2 && actual->sizes[1] == num_classes); | ||
|
||
// Find the index of the maximum value in the predictions tensor | ||
NNTensor *predictions_argmax = nn_tensor_init_NNTensor(1, (size_t[]){batch_size}, true, NULL, error); | ||
if (!predictions_argmax) { | ||
return 0; | ||
} | ||
if (!nn_argmax_tensor_batch(predictions, predictions_argmax, error)) { | ||
nn_tensor_destroy_NNTensor(predictions_argmax); | ||
return 0; | ||
} | ||
|
||
// Compute the accuracy | ||
NNTensorUnit accuracy = 0; | ||
if (one_hot) { | ||
// Find the index of the maximum value in the actual tensor | ||
NNTensor *actual_argmax = nn_tensor_init_NNTensor(1, (size_t[]){batch_size}, true, NULL, error); | ||
if (!actual_argmax) { | ||
nn_tensor_destroy_NNTensor(predictions_argmax); | ||
return 0; | ||
} | ||
if (!nn_argmax_tensor_batch(actual, actual_argmax, error)) { | ||
nn_tensor_destroy_NNTensor(predictions_argmax); | ||
nn_tensor_destroy_NNTensor(actual_argmax); | ||
return 0; | ||
} | ||
|
||
// Iterate over the batch | ||
for (size_t i = 0; i < batch_size; i++) { | ||
// If the predicted class is equal to the actual class | ||
if (predictions_argmax->data[i] == actual_argmax->data[i]) { | ||
accuracy += 1; | ||
} | ||
} | ||
nn_tensor_destroy_NNTensor(actual_argmax); | ||
} else { | ||
// Iterate over the batch | ||
for (size_t i = 0; i < batch_size; i++) { | ||
// Find the index of the maximum value in the actual tensor | ||
size_t actual_argmax = (size_t)actual->data[i]; | ||
// If the predicted class is equal to the actual class | ||
if (predictions_argmax->data[i] == actual_argmax) { | ||
accuracy += 1; | ||
} | ||
} | ||
} | ||
nn_tensor_destroy_NNTensor(predictions_argmax); | ||
|
||
// Average the accuracy | ||
accuracy /= batch_size; | ||
|
||
return accuracy; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#include "nn_argmax.h" | ||
#include "nn_debug.h" | ||
#include "nn_error.h" | ||
|
||
size_t nn_argmax(const NNTensor *input, NNError *error) { | ||
NN_DEBUG_PRINT(5, "function %s called with input.dims=%zu\n", __func__, input->dims); | ||
|
||
if (!(input->flags & NN_TENSOR_FLAG_INIT)) { | ||
nn_error_set(error, NN_ERROR_INVALID_ARGUMENT, "tensor input is not initialized"); | ||
return 0; | ||
} else if (input->dims != 1) { | ||
nn_error_set(error, NN_ERROR_INVALID_ARGUMENT, "only 1-dimensional tensors are allowed"); | ||
return 0; | ||
} | ||
|
||
// Find the index of the maximum value in the input tensor | ||
size_t max_index = 0; | ||
NNTensorUnit max_value = input->data[0]; | ||
for (size_t i = 1; i < input->sizes[0]; ++i) { | ||
if (input->data[i] > max_value) { | ||
max_value = input->data[i]; | ||
max_index = i; | ||
} | ||
} | ||
|
||
return max_index; | ||
} | ||
|
||
bool nn_argmax_tensor_batch(const NNTensor *input, NNTensor *output, NNError *error) { | ||
NN_DEBUG_PRINT(5, "function %s called with input.dims=%zu output.dims=%zu\n", __func__, input->dims, output->dims); | ||
|
||
if (!(input->flags & NN_TENSOR_FLAG_INIT) || !(output->flags & NN_TENSOR_FLAG_INIT)) { | ||
nn_error_set(error, NN_ERROR_INVALID_ARGUMENT, "tensor input or output is not initialized"); | ||
return false; | ||
} else if (input->dims != 2 || output->dims != 1 || input->sizes[0] != output->sizes[0]) { | ||
nn_error_set(error, NN_ERROR_INVALID_ARGUMENT, "only 2-dimensional input tensor and 1-dimensional output tensor with the same batch size are allowed"); | ||
return false; | ||
} | ||
|
||
// Find the index of the maximum value in each row of the input tensor batch | ||
size_t batch_size = input->sizes[0]; | ||
size_t num_classes = input->sizes[1]; | ||
for (size_t i = 0; i < batch_size; ++i) { | ||
size_t max_index = 0; | ||
NNTensorUnit max_value = input->data[i * num_classes]; // first element in the row | ||
for (size_t j = 1; j < num_classes; ++j) { | ||
if (input->data[i * num_classes + j] > max_value) { | ||
max_value = input->data[i * num_classes + j]; | ||
max_index = j; | ||
} | ||
} | ||
output->data[i] = max_index; | ||
} | ||
|
||
return true; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
void test_nn_accuracy(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
tests/arch/generic/accuracy/nn_accuracy.c | ||
src/nn_accuracy.c | ||
src/nn_app.c | ||
src/nn_argmax.c | ||
src/nn_config.c | ||
src/nn_error.c | ||
src/nn_test.c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#include "./accuracy.h" | ||
#include "nn_app.h" | ||
|
||
int main(int argc, char *argv[]) { | ||
nn_init_app(argc, argv); | ||
// nn_set_debug_level(5); // for debugging | ||
|
||
test_nn_accuracy(); | ||
|
||
return 0; | ||
} |
Oops, something went wrong.