
Commit 710e5e8

Merge pull request #39 from DanRuta/dev
v3.2.0
2 parents 4541adf + 26bd527 commit 710e5e8

39 files changed: +3888 / -238 lines changed

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
@@ -1,3 +1,23 @@
# 3.2.0 - IMG data, validation, early stopping
---
#### Network
- Added weight+bias importing and exporting via images, using IMGArrays
- Added validation config to .train(), with interval config
- Added early stopping to validation, with threshold stopping condition
- Added early stopping patience condition
- Added early stopping divergence condition
- Breaking change: the "error" key in training callbacks has been changed to "trainingError"
- Breaking change: Removed the ability to use either 'expected' or 'output' as the data key. Now just 'expected'.

#### NetUtil
- Added splitData function
- Added normalize function

#### NetMath
- Added root mean squared error cost function
- Added momentum weight update function
- Breaking change: Renamed "vanilla update fn" to "vanilla sgd"

# 3.1.0 - Optimizations
---
#### ConvLayer

README.md

Lines changed: 136 additions & 10 deletions
@@ -19,7 +19,12 @@ https://ai.danruta.co.uk/webassembly - Performance comparison between JS and Web
---
There are two different versions of jsNet: WebAssembly, and JavaScript-only. There are demos included for loading both versions, in nodejs, as well as in the browser. The WebAssembly version is a little more complex to load, due to the NetWASM files which are generated by emscripten, containing the compiled code and the glue code to manage the WASM code. The ```NetWASM.js``` file lazy loads the ```NetWASM.wasm``` file from the given path.

The API has been kept the same as the JavaScript only version. Every single value has get/set bindings to the WebAssembly variables, meaning that apart from not being able to freely browse the values in dev tools (need to call them, to see them), you should notice no API difference between the two versions.
The API has been kept the same as the JavaScript only version. Every single value has get/set bindings to the WebAssembly variables, meaning that apart from not being able to freely browse the values in dev tools (you need to call them to see them), you should notice no API difference between the two versions. One thing to note is that when changing primitive WebAssembly array values, eg setting `net.layers[1].neurons[0].weights[0]` to 1, you need to set the entire, modified weights array, not the value at an index. For example, you would do this instead:
```javascript
const weights = net.layers[1].neurons[0].weights
weights[0] = 1
net.layers[1].neurons[0].weights = weights
```

Note that you need to serve files via a server (a basic server is included) to load WebAssembly into a browser.
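As a rough sketch of what loading the WebAssembly version in a browser page might look like (the file locations and the global path variable name below are assumptions, not the confirmed API - see the included demos for the exact loading code):
```javascript
// Sketch only - the global path variable name and file paths are assumptions,
// not the confirmed API. See the included browser demo for the real loading code.
window.jsNetWASMPath = "./dist/NetWASM.wasm" // the path NetWASM.js will lazy load the .wasm file from

const glue = document.createElement("script")
glue.src = "./dist/NetWASM.js" // emscripten glue code + compiled code loader
document.head.appendChild(glue)
```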

@@ -164,11 +169,12 @@ When building a convolutional network, make sure that the number of neurons in t
### Training
---

The data structure must be an object with key ```input``` having an array of numbers, and key ```expected``` or ```output``` holding the expected output of the network. For example, the following are both valid inputs for both training and testing.
The data structure must be an object with key ```input``` having an array of numbers, and key ```expected``` holding the expected output of the network. For example, the following is a valid input for training, validation and testing.
```javascript
{input: [1,0,0.2], expected: [1, 2]}
{input: [1,0,0.2], output: [1, 2]}
```
***Tip**: You can normalize data using the ```NetUtil.normalize()``` function (see the NetUtil section at the bottom)*

You train the network by passing a set of data. The network will log the error and epoch number to the console after each epoch, as well as the time elapsed and average epoch duration.
```javascript
const {training} = mnist.set(800, 200) // Get the training data from the mnist library, linked above
@@ -187,16 +193,70 @@ By default, this is ```1``` and represents how many times the data passed will b
net.train(training, {epochs: 5}) // This will run through the training data 5 times
```
###### Callback
You can also provide a callback in the options parameter, which will get called after each iteration (Maybe updating a graph?). The callback is passed how many iterations have passed, the error, the milliseconds elapsed and the input data for that iteration.
You can also provide a callback in the options parameter, which will get called after each iteration (maybe updating a graph?). The callback is passed how many iterations have passed, the milliseconds elapsed since training started, and either the validation error or the training error along with the input data for that iteration.
```javascript
const doSomeStuff = ({iterations, error, elapsed, input}) => ....
const doSomeStuff = ({iterations, trainingError, validationError, elapsed, input}) => ....
net.train(training, {callback: doSomeStuff})
```
###### Log
You can turn off the logging by passing log: false in the options parameter.
```javascript
net.train(training, {log: false})
```

###### Validation
You can specify an array of data to use as validation. This must have the same structure as the training/test data. The validation config contains three parts: data, interval, and early stopping (see below). The data key holds the validation data. The interval is an integer representing how many training iterations pass between validations of the entire validation set. By default, this is set to 1 epoch, aka the length of the given training data set.
```javascript
// Validate every 5 training iterations
net.train(training, {validation: {
    data: [...],
    interval: 5
}})
// Validate every 3 epochs
net.train(training, {validation: {
    data: [...],
    interval: training.length * 3
}})
```
**Tip**: You can use ```NetUtil.splitData(data)``` to split a large array of data into training, validation, and test arrays, with default or specified ratios. See the NetUtil section at the bottom.

###### Early stopping
When using validation data, you can specify an extra config object, `earlyStopping`, to configure stopping the training early, once a condition has been met, to counter overfitting. By default this is turned off, but once a type is specified via the `type` key, each of its options has a default value.

| Type | What it does | Available Configurations | Default value |
|:-------------:| :-----:| :-----:| :---: |
| threshold | Stops the training the first time the validation error reaches, or goes below, the specified threshold. A final backward pass is made, and weights updated, before stopping. | threshold | 0.01 |
| patience | This backs up the weights and biases of the network when the validation error reaches a new best low. After that, if the validation error is worse a certain number of times in a row, it stops the training and reverts the network weights and biases to the backed up values. The number of times in a row to tolerate is configured via the `patience` hyperparameter. | patience | 20 |
| divergence | This backs up the weights and biases of the network when the validation error reaches a new best low. After that, if the validation error is worse than that best value by at least the specified percentage, it stops the training and reverts the network weights and biases to the backed up values. The percentage is configured via the `percent` hyperparameter. A very jittery validation error is likely to stop the training very early, when using this condition. | percent | 30 |

Examples:
```javascript
// Threshold - Training stops once the validation error reaches 0.2 or below
net.train(training, {validation: {
    data: [...],
    earlyStopping: {
        type: "threshold",
        threshold: 0.2
    }
}})
// Patience - Training stops once the validation error is worse than the best found, 10 times in a row
net.train(training, {validation: {
    data: [...],
    earlyStopping: {
        type: "patience",
        patience: 10
    }
}})
// Divergence - Training stops once the validation error is worse than the best found, by 30%
net.train(training, {validation: {
    data: [...],
    earlyStopping: {
        type: "divergence",
        percent: 30
    }
}})
```

###### Mini Batch Size
You can use mini batch SGD training by specifying a mini batch size to use (changing it from the default, 1). You can set it to true, and it will default to how many classifications there are in the training data.
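As a quick sketch (the ```miniBatchSize``` option key is assumed here, as it is not shown in this excerpt):
```javascript
// Assumed option key: miniBatchSize
net.train(training, {miniBatchSize: 32})   // update weights every 32 samples
net.train(training, {miniBatchSize: true}) // default to the number of classifications in the training data
```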

@@ -235,20 +295,46 @@ net.train(training).then(() => net.test(test, {callback: doSomeStuff}))

### Exporting
---
Weights data is exported as a JSON object.
There are two ways you can manage your data. The built-in way is to use JSON for importing and exporting. If you provide my IMGArrays library (https://github.com/DanRuta/IMGArrays), you can alternatively use images, which are much quicker and easier to use in the browser.

To export weights data as JSON:
```javascript
const data = trainedNet.toJSON()
```

See the IMGArrays library documentation for more details and nodejs instructions, but its integration into jsNet is as follows:
```javascript
const canvas = trainedNet.toIMG(IMGArrays, opts)
IMGArrays.downloadImage(canvas)
```

### Importing
---
Only the weights are exported. You still need to build the net with the same structure and configs, eg activation function.
Only the weights are exported. You still need to build the net with the same structure and configs, eg activation function. Again, data can be imported as either JSON or an image (when using IMGArrays), as above.

When using JSON:
```javascript
const freshNetwork = new Network(...)
freshNetwork.fromJSON(data)
```
If using exported data from before version 2.0.0, just do a find-replace of "neurons" -> "weights" on the exported data and it will work with the new version.

When using IMGArrays:
```javascript
const freshNetwork = new Network(...)
freshNetwork.fromIMG(document.querySelector("img"), IMGArrays, opts)
```

As an example, you can use the image below to load data for the following jsNet configuration, to get a basic model trained on MNIST.
```javascript
const net = new Network({
    layers: [new FCLayer(784), new FCLayer(100), new FCLayer(10)]
})
net.fromIMG(document.querySelector("img"), IMGArrays)
```

<img width="100%" src="fc-784f-100f-10f.png">

### Trained usage
---
Once the network has been trained, tested and imported into your page, you can use it via the ```forward``` function.
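As a minimal sketch (the return format is not shown in this excerpt; ```forward``` is assumed here to return the output layer's activations):
```javascript
// Sketch - assumes forward() returns an array of output activations
const activations = net.forward([1, 0, 0.2])
// e.g. take the index of the highest activation as the predicted class
const predictedClass = activations.indexOf(Math.max(...activations))
```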
@@ -275,7 +361,7 @@ const net = new Network({
    l2: undefined,
    l1: undefined,
    layers: [ /* 3 FCLayers */ ]
    updateFn: "vanillaupdatefn",
    updateFn: "vanillasgd",
    weightsConfig: {
        distribution: "xavieruniform"
    }
@@ -289,7 +375,7 @@ You can check the framework version via Network.version (static).
| Attribute | What it does | Available Configurations | Default value |
|:-------------:| :-----:| :-----:| :---: |
| learningRate | The speed at which the net will learn. | Any number | 0.2 (see below for exceptions) |
| cost | Cost function to use when printing out the net error | crossEntropy, meanSquaredError | meansquarederror |
| cost | Cost function to use when printing out the net error | crossEntropy, meanSquaredError, rootMeanSquaredError | meansquarederror |
| channels | Specifies the number of channels in the input data. EG, 3 for RGB images. Used by convolutional networks. | Any number | undefined |
| conv | (See ConvLayer) An object where the optional keys filterSize, zeroPadding and stride set values for all Conv layers to default to | Object | {} |
| pool | (See PoolLayer) An object where the optional keys size and stride set values for all Pool layers to default to | Object | {} |
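For example, a minimal sketch selecting the newly added cost function via the `cost` key from the table above (the layer sizes are arbitrary):
```javascript
// Use the new root mean squared error cost when logging the net error
const net = new Network({
    layers: [new FCLayer(784), new FCLayer(100), new FCLayer(10)],
    cost: "rootMeanSquaredError"
})
```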
@@ -327,9 +413,10 @@ Learning rate is 0.2 by default, except when using the following configurations:
### Weight update functions
| Attribute | What it does | Available Configurations | Default value |
|:-------------:| :-----:| :-----:| :---: |
| updateFn | The function used for updating the weights/bias. The vanillaupdatefn option just sets the network to update the weights without any changes to learning rate. | vanillaupdatefn, gain, adagrad, RMSProp, adam , adadelta| vanillaupdatefn |
| updateFn | The function used for updating the weights/bias. The vanillasgd option just sets the network to update the weights without any changes to learning rate. | vanillasgd, gain, adagrad, RMSProp, adam, adadelta, momentum | vanillasgd |
| rmsDecay | The decay rate for RMSProp, when used | Any number | 0.99 |
| rho | Momentum for Adadelta, when used | Any number | 0.95 |
| momentum | Momentum for the (sgd) momentum update function. | Any number | 0.9 |

##### Examples
```javascript
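// Sketch using the option keys listed in the table above: SGD with momentum
net = new Network({
    updateFn: "momentum",
    momentum: 0.8 // defaults to 0.9
})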
@@ -543,6 +630,45 @@ net = new Network({
    learningRate: 0.05
})
```
### NetUtil
There is a NetUtil class included, containing some potentially useful functions.

### shuffle(data)
_array_ **data** - The data array to shuffle

This randomly shuffles an array _in place_ (the data is passed by reference, so the array you pass in will be changed).
##### Example
```javascript
const data = [1,2,3,4,5]
NetUtil.shuffle(data)
// data != [1,2,3,4,5]
```

### splitData(data), splitData(data, {training=0.7, validation=0.15, test=0.15})
_array_ **data** - The data array to split
_object_ **configs** - Override values for the split ratios. The values should add up to 1.

This is used for splitting a large array of data into the different parts needed for training.
##### Example
```javascript
const data = [1,2,3,4,5]
const {training, validation, test} = NetUtil.splitData(data)
// or
const {training, validation, test} = NetUtil.splitData(data, {training: 0.5, validation: 0.25, test: 0.25})
```

### normalize(data)
_array_ **data** - The data array to normalize

This normalizes an array of positive and/or negative numbers to a [0-1] range. The data is changed in place, similar to the shuffle function.
##### Example
```javascript
const data = [1,2,3,-5,0.4,2]
const {minValue, maxValue} = NetUtil.normalize(data)
// data == [0.75, 0.875, 1, 0, 0.675, 0.875]
// minValue == -5
// maxValue == 3
```

## Future plans
---

dev/cpp/ConvLayer.cpp

Lines changed: 13 additions & 3 deletions
@@ -208,13 +208,13 @@ void ConvLayer::applyDeltaWeights (void) {
+ net->l2 * filterWeights[f][c][r][v]
+ net->l1 * (filterWeights[f][c][r][v] > 0 ? 1 : -1)) / net->miniBatchSize;

filterWeights[f][c][r][v] = NetMath::vanillaupdatefn(netInstance, filterWeights[f][c][r][v], regularized);
filterWeights[f][c][r][v] = NetMath::vanillasgd(netInstance, filterWeights[f][c][r][v], regularized);

if (net->maxNorm) net->maxNormTotal += filterWeights[f][c][r][v] * filterWeights[f][c][r][v];
}
}
}
biases[f] = NetMath::vanillaupdatefn(netInstance, biases[f], deltaBiases[f]);
biases[f] = NetMath::vanillasgd(netInstance, biases[f], deltaBiases[f]);
}
break;
case 1: // gain
@@ -318,4 +318,14 @@ void ConvLayer::applyDeltaWeights (void) {
net->maxNormTotal = sqrt(net->maxNormTotal);
NetMath::maxNorm(netInstance);
}
}

// Snapshot the current biases and filter weights, so early stopping can restore them later
void ConvLayer::backUpValidation (void) {
    validationBiases = biases;
    validationFilterWeights = filterWeights;
}

// Restore the biases and filter weights saved at the last best validation error
void ConvLayer::restoreValidation (void) {
    biases = validationBiases;
    filterWeights = validationFilterWeights;
}

dev/cpp/FCLayer.cpp

Lines changed: 32 additions & 3 deletions
@@ -188,11 +188,11 @@ void FCLayer::applyDeltaWeights (void) {
+ net->l2 * weights[n][dw]
+ net->l1 * (weights[n][dw] > 0 ? 1 : -1)) / net->miniBatchSize;

weights[n][dw] = NetMath::vanillaupdatefn(netInstance, weights[n][dw], regularized);
weights[n][dw] = NetMath::vanillasgd(netInstance, weights[n][dw], regularized);

if (net->maxNorm) net->maxNormTotal += weights[n][dw] * weights[n][dw];
}
biases[n] = NetMath::vanillaupdatefn(netInstance, biases[n], deltaBiases[n]);
biases[n] = NetMath::vanillasgd(netInstance, biases[n], deltaBiases[n]);
}
break;
case 1: // gain
@@ -276,4 +276,33 @@ void FCLayer::applyDeltaWeights (void) {
net->maxNormTotal = sqrt(net->maxNormTotal);
NetMath::maxNorm(netInstance);
}
}

// Copy the current biases and weights into backup vectors, so early stopping can restore them later
void FCLayer::backUpValidation (void) {

    validationBiases = {};
    validationWeights = {};

    for (int n=0; n<neurons.size(); n++) {
        validationBiases.push_back(biases[n]);

        std::vector<double> neuron;

        for (int w=0; w<weights[n].size(); w++) {
            neuron.push_back(weights[n][w]);
        }

        validationWeights.push_back(neuron);
    }
}

// Write the backed up biases and weights back into the layer
void FCLayer::restoreValidation (void) {

    for (int n=0; n<neurons.size(); n++) {
        biases[n] = validationBiases[n];

        for (int w=0; w<weights[n].size(); w++) {
            weights[n][w] = validationWeights[n][w];
        }
    }
}

dev/cpp/Filter.cpp

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ void Filter::init (int netInstance, int channels, int filterSize) {
case 2: // adagrad
case 3: // rmsprop
case 5: // adadelta
case 6: // momentum
biasCache = 0;
weightsCache = NetUtil::createVolume<double>(channels, filterSize, filterSize, 0);
