Skip to content

Commit 3c34194

Browse files
authored
Upgrade tfjs (#427)
* significant TFJS upgrade * adding scripting test
1 parent 8bfd8ab commit 3c34194

File tree

6 files changed

+1527
-134
lines changed

6 files changed

+1527
-134
lines changed

__tests__/regressionCheck.ts

+51-31
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1-
import * as tf from '@tensorflow/tfjs'
2-
import { load } from '../src/index'
3-
const fs = require('fs');
4-
const jpeg = require('jpeg-js');
5-
1+
import * as tf from "@tensorflow/tfjs";
2+
import { load } from "../src/index";
3+
import { exec } from "child_process";
4+
const fs = require("fs");
5+
const jpeg = require("jpeg-js");
66

77
// Fix for JEST
8-
const globalAny: any = global
9-
globalAny.fetch = require('node-fetch')
10-
const timeoutMS = 10000
11-
const NUMBER_OF_CHANNELS = 3
12-
8+
const globalAny: any = global;
9+
globalAny.fetch = require("node-fetch");
10+
const timeoutMS = 10000;
11+
const NUMBER_OF_CHANNELS = 3;
1312

1413
const readImage = (path: string) => {
15-
const buf = fs.readFileSync(path)
16-
const pixels = jpeg.decode(buf, true)
17-
return pixels
18-
}
14+
const buf = fs.readFileSync(path);
15+
const pixels = jpeg.decode(buf, true);
16+
return pixels;
17+
};
1918

2019
// @ts-ignore
2120
const imageByteArray = (image, numChannels: number) => {
22-
const pixels = image.data
21+
const pixels = image.data;
2322
const numPixels = image.width * image.height;
2423
const values = new Int32Array(numPixels * numChannels);
2524

@@ -29,22 +28,43 @@ const imageByteArray = (image, numChannels: number) => {
2928
}
3029
}
3130

32-
return values
33-
}
31+
return values;
32+
};
3433

3534
// @ts-ignore
3635
const imageToInput = (image, numChannels: number) => {
37-
const values = imageByteArray(image, numChannels)
38-
const outShape = [image.height, image.width, numChannels] as [number, number, number];
39-
const input = tf.tensor3d(values, outShape, 'int32');
40-
41-
return input
42-
}
43-
44-
it("Snapshots", async () => {
45-
const model = await load()
46-
const logo = readImage(`${__dirname}/../_art/nsfwjs_logo.jpg`)
47-
const input = imageToInput(logo, NUMBER_OF_CHANNELS)
48-
const predictions = await model.classify(input)
49-
expect(predictions).toMatchSnapshot()
50-
}, timeoutMS)
36+
const values = imageByteArray(image, numChannels);
37+
const outShape = [image.height, image.width, numChannels] as [
38+
number,
39+
number,
40+
number
41+
];
42+
const input = tf.tensor3d(values, outShape, "int32");
43+
44+
return input;
45+
};
46+
47+
it(
48+
"Snapshots",
49+
async () => {
50+
const model = await load();
51+
const logo = readImage(`${__dirname}/../_art/nsfwjs_logo.jpg`);
52+
const input = imageToInput(logo, NUMBER_OF_CHANNELS);
53+
const predictions = await model.classify(input);
54+
expect(predictions).toMatchSnapshot();
55+
},
56+
timeoutMS
57+
);
58+
59+
it(
60+
"Bundles and minifies",
61+
(done) => {
62+
const cmd = "yarn scriptbundle && yarn minbundle";
63+
exec(cmd, (err) => {
64+
if (err) done.fail("Failed to bundle and minify");
65+
// All good!
66+
done();
67+
});
68+
},
69+
timeoutMS * 6
70+
);

example/minimal_demo/index.html

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
<!-- Load TensorFlow.js. This is required -->
22
<script
3-
src="https://unpkg.com/@tensorflow/tfjs@1.2.8"
3+
src="https://unpkg.com/@tensorflow/tfjs@2.6.0"
44
type="text/javascript"
55
></script>
6-
<script src="https://unpkg.com/nsfwjs@2.1.0" type="text/javascript"></script>
6+
<script src="https://unpkg.com/nsfwjs@2.3.0" type="text/javascript"></script>
77

88
<!-- For testing: Load from local bundle `yarn scriptbundle && yarn minbundle` -->
99
<!-- <script src="../../dist/nsfwjs.min.js"></script> -->
1010

1111
<script>
1212
// const nsfwjs = require('nsfwjs')
13-
const img = new Image()
14-
img.crossOrigin = 'anonymous'
13+
const img = new Image();
14+
img.crossOrigin = "anonymous";
1515
// some image here
16-
img.src = 'https://i.imgur.com/Kwxetau.jpg'
16+
img.src = "https://i.imgur.com/Kwxetau.jpg";
1717

1818
// Load the model.
19-
nsfwjs.load().then(model => {
19+
nsfwjs.load().then((model) => {
2020
// Classify the image.
21-
model.classify(img).then(predictions => {
22-
console.log('Predictions', predictions)
23-
})
24-
})
21+
model.classify(img).then((predictions) => {
22+
console.log("Predictions", predictions);
23+
});
24+
});
2525
</script>
2626
<pre>Checkout console.log output for results!</pre>

example/nsfw_demo/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"version": "0.1.0",
44
"private": true,
55
"dependencies": {
6-
"@tensorflow/tfjs": "^1.7.4",
6+
"@tensorflow/tfjs": "^2.6.0",
77
"nsfwjs": "../../",
88
"react": "^16.8.1",
99
"react-dom": "^16.8.1",

0 commit comments

Comments
 (0)