Skip to content

Commit

Permalink
Merge pull request #281 from BrainJS/185-cross-validation-fixes
Browse files Browse the repository at this point in the history
fix: Fix CrossValidate to have tests for when data too small
  • Loading branch information
robertleeplummerjr authored Oct 2, 2018
2 parents f0a1a56 + ca437f3 commit fc48ae6
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 53 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ With multiple networks you can train in parallel like this:
### Cross Validation
[Cross Validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) can provide a less fragile way of training on larger data sets. The brain.js api provides Cross Validation in this example:
```js
const crossValidate = new CrossValidate(brain.NeuralNetwork, networkOptions);
const stats = crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, networkOptions);
crossValidate.train(data, trainingOptions, k); //note k (or KFolds) is optional
const json = crossValidate.toJSON(); // all stats in json as well as neural networks
const net = crossValidate.toNeuralNetwork();


Expand Down
2 changes: 1 addition & 1 deletion bower.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@
"node_modules",
"test"
],
"version": "1.4.1"
"version": "1.4.2"
}
15 changes: 10 additions & 5 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: Heather Arthur <[email protected]>
* homepage: https://github.com/brainjs/brain.js#readme
* version: 1.4.1
* version: 1.4.2
*
* acorn:
* license: MIT (http://opensource.org/licenses/MIT)
Expand Down Expand Up @@ -214,8 +214,13 @@ var CrossValidate = function () {

}, {
key: "train",
value: function train(data, trainOpts, k) {
k = k || 4;
value: function train(data) {
var trainOpts = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};
var k = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : 4;

if (data.length <= k) {
throw new Error("Training set size is too small for " + data.length + " k folds of " + k);
}
var size = data.length / k;

if (data.constructor === Array) {
Expand Down Expand Up @@ -1946,8 +1951,8 @@ var NeuralNetwork = function () {
falseNeg: falseNeg,
falsePos: falsePos,
total: data.length,
precision: truePos / (truePos + falsePos),
recall: truePos / (truePos + falseNeg),
precision: truePos > 0 ? truePos / (truePos + falsePos) : 0,
recall: truePos > 0 ? truePos / (truePos + falseNeg) : 0,
accuracy: (trueNeg + truePos) / data.length
});
}
Expand Down
13 changes: 7 additions & 6 deletions browser.min.js

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions dist/cross-validate.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/cross-validate.js.map

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions dist/neural-network.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/neural-network.js.map

Large diffs are not rendered by default.

16 changes: 2 additions & 14 deletions examples-typescript/cross-validate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,11 @@ const trainingData = [
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
// repeat xor data to have enough to train with
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },
{ input: [1, 0], output: [1] }
];

const netOptions = {
Expand Down
16 changes: 2 additions & 14 deletions examples/cross-validate.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,11 @@ const trainingData = [
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
// repeat xor data to have enough to train with
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

// xor repeats
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },
{ input: [1, 0], output: [1] }
];

const netOptions = {
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "brain.js",
"description": "Neural network library",
"version": "1.4.1",
"version": "1.4.2",
"author": "Heather Arthur <[email protected]>",
"repository": {
"type": "git",
Expand Down
6 changes: 4 additions & 2 deletions src/cross-validate.js
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ export default class CrossValidate {
* }
* }
*/
train(data, trainOpts, k) {
k = k || 4;
train(data, trainOpts = {}, k = 4) {
if (data.length <= k) {
throw new Error(`Training set size is too small for ${ data.length } k folds of ${ k }`);
}
let size = data.length / k;

if (data.constructor === Array) {
Expand Down
4 changes: 2 additions & 2 deletions src/neural-network.js
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,8 @@ export default class NeuralNetwork {
falseNeg: falseNeg,
falsePos: falsePos,
total: data.length,
precision: truePos / (truePos + falsePos),
recall: truePos / (truePos + falseNeg),
precision: truePos > 0 ? truePos / (truePos + falsePos) : 0,
recall: truePos > 0 ? truePos / (truePos + falseNeg) : 0,
accuracy: (trueNeg + truePos) / data.length
});
}
Expand Down
42 changes: 42 additions & 0 deletions test/base/cross-validation.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import assert from 'assert';
import brain from '../../src';
import CrossValidate from '../../src/cross-validate';

describe('CrossValidation', () => {
describe('simple xor example', () => {
it('throws exception when training set is too small', () => {
const xorTrainingData = [
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] }
];
const net = new CrossValidate(brain.NeuralNetwork);
assert.throws(() => {
net.train(xorTrainingData);
});
});
it('handles training and outputs values that are all numbers', () => {
const xorTrainingData = [
{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] },

{ input: [0, 1], output: [1] },
{ input: [0, 0], output: [0] },
{ input: [1, 1], output: [0] },
{ input: [1, 0], output: [1] }
];
const net = new CrossValidate(brain.NeuralNetwork);
net.train(xorTrainingData);
const json = net.toJSON();
for (let p in json.avgs) {
assert(json.avgs[p] >= 0);
}
for (let p in json.stats) {
assert(json.stats[p] >= 0);
}
});
});
});

0 comments on commit fc48ae6

Please sign in to comment.