Skip to content

Commit

Permalink
Use fixed-size arrays in Neuron functions (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
devfacet committed Apr 7, 2024
1 parent 89c16ee commit 00c9cf0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 16 deletions.
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ git submodule update --init
### Build and run examples

```shell
# Generic examples
make build-examples ARCH=generic
make run-examples ARCH=generic

# Arm examples
make build-examples ARCH=arm TECH=neon,cmsis-dsp
make run-examples ARCH=arm TECH=neon,cmsis-dsp
```
Expand Down
6 changes: 3 additions & 3 deletions include/nn_neuron.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ typedef struct {
} NNNeuron;

// nn_neuron_init initializes a neuron with the given arguments.
bool nn_neuron_init(NNNeuron *neuron, const float *weights, size_t n_weights, float bias, NNActivationFunction act_func, NNDotProductFunction dot_product_func, NNError *error);
bool nn_neuron_init(NNNeuron *neuron, const float weights[NEURON_MAX_WEIGHTS], size_t n_weights, float bias, NNActivationFunction act_func, NNDotProductFunction dot_product_func, NNError *error);

// nn_neuron_set_weights sets the weights of the given neuron.
bool nn_neuron_set_weights(NNNeuron *neuron, const float *weights, NNError *error);
bool nn_neuron_set_weights(NNNeuron *neuron, const float weights[NEURON_MAX_WEIGHTS], NNError *error);

// nn_neuron_set_bias sets the bias of the given neuron.
bool nn_neuron_set_bias(NNNeuron *neuron, float bias, NNError *error);

// nn_neuron_compute computes the given neuron and returns the output.
float nn_neuron_compute(const NNNeuron *neuron, const float *inputs, NNError *error);
float nn_neuron_compute(const NNNeuron *neuron, const float inputs[NEURON_MAX_WEIGHTS], NNError *error);

#endif // NN_NEURON_H
14 changes: 3 additions & 11 deletions src/nn_neuron.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <stdio.h>

// nn_neuron_init initializes a neuron with the given arguments.
bool nn_neuron_init(NNNeuron *neuron, const float *weights, size_t n_weights, float bias, NNActivationFunction act_func, NNDotProductFunction dot_product_func, NNError *error) {
bool nn_neuron_init(NNNeuron *neuron, const float weights[NEURON_MAX_WEIGHTS], size_t n_weights, float bias, NNActivationFunction act_func, NNDotProductFunction dot_product_func, NNError *error) {
nn_error_set(error, NN_ERROR_NONE, NULL);
if (neuron == NULL) {
nn_error_set(error, NN_ERROR_INVALID_INSTANCE, "neuron is NULL");
Expand All @@ -32,16 +32,12 @@ bool nn_neuron_init(NNNeuron *neuron, const float *weights, size_t n_weights, fl
}

// nn_neuron_set_weights sets the weights of the given neuron.
bool nn_neuron_set_weights(NNNeuron *neuron, const float *weights, NNError *error) {
bool nn_neuron_set_weights(NNNeuron *neuron, const float weights[NEURON_MAX_WEIGHTS], NNError *error) {
nn_error_set(error, NN_ERROR_NONE, NULL);
if (neuron == NULL) {
nn_error_set(error, NN_ERROR_INVALID_INSTANCE, "neuron is NULL");
return false;
}
if (weights == NULL) {
nn_error_set(error, NN_ERROR_INVALID_INSTANCE, "weights is NULL");
return false;
}
for (size_t i = 0; i < neuron->n_weights; ++i) {
neuron->weights[i] = weights[i];
}
Expand All @@ -60,16 +56,12 @@ bool nn_neuron_set_bias(NNNeuron *neuron, float bias, NNError *error) {
}

// nn_neuron_compute computes the given neuron and returns the output.
float nn_neuron_compute(const NNNeuron *neuron, const float *inputs, NNError *error) {
float nn_neuron_compute(const NNNeuron *neuron, const float inputs[NEURON_MAX_WEIGHTS], NNError *error) {
nn_error_set(error, NN_ERROR_NONE, NULL);
if (neuron == NULL) {
nn_error_set(error, NN_ERROR_INVALID_INSTANCE, "neuron is NULL");
return NAN;
}
if (inputs == NULL) {
nn_error_set(error, NN_ERROR_INVALID_INSTANCE, "inputs is NULL");
return NAN;
}

// Initialize the result
float result = 0.0f;
Expand Down

0 comments on commit 00c9cf0

Please sign in to comment.