From 63250eceec9dab31f10345e77722080b17100bc7 Mon Sep 17 00:00:00 2001 From: gaikwadrahul8 <115997457+gaikwadrahul8@users.noreply.github.com> Date: Sat, 13 Apr 2024 01:51:11 +0530 Subject: [PATCH] Address tfjs-react-native typos in documentation strings (#8217) --- tfjs-react-native/DEVELOPMENT.md | 13 +- .../integration_rn59/components/ml.ts | 69 +- .../components/tfjs_unit_test_runner.tsx | 2 +- .../src/bundle_resource_io_test.ts | 302 +++--- .../src/camera/camera_stream.tsx | 4 +- tfjs-react-native/src/camera/camera_test.ts | 979 +++++++++--------- .../src/platform_react_native.ts | 222 ++-- 7 files changed, 840 insertions(+), 751 deletions(-) diff --git a/tfjs-react-native/DEVELOPMENT.md b/tfjs-react-native/DEVELOPMENT.md index a27676e346a..d5ba2f8b93c 100644 --- a/tfjs-react-native/DEVELOPMENT.md +++ b/tfjs-react-native/DEVELOPMENT.md @@ -1,11 +1,11 @@ # Development -This file will document some of the differences from the regular developement workflow in [DEVELOPMENT.md](../DEVELOPMENT.md). You should read that document first to get familiar with typical TensorFlow.js development workflow. +This file will document some of the differences from the regular development workflow in [DEVELOPMENT.md](../DEVELOPMENT.md). You should read that document first to get familiar with typical TensorFlow.js development workflow. Development and testing for tfjs-react-native is somewhat different from the packages like tfjs-core or tfjs-layers for a few reasons: -- __Dependency on having a physical mobile device to run__: While the CPU backend can run in a simulator, the WebGL one requires running on a physical device. So most of the time you will want to test something using a mobile device connected to your computer. -- __No browser or node environment__: We are running JavaScript outside of a browser and outside of node. We thus have to make sure we don't include things that depend on those two environments. +- **Dependency on having a physical mobile device to run**: While the CPU backend can run in a simulator, the WebGL one requires running on a physical device. So most of the time you will want to test something using a mobile device connected to your computer. +- **No browser or node environment**: We are running JavaScript outside of a browser and outside of node. We thus have to make sure we don't include things that depend on those two environments. ## Key Terms & Caveats @@ -13,7 +13,7 @@ These are a few key terms/technologies to be familiar with that are different fr - [React Native](https://facebook.github.io/react-native/) — This is the framework that this package targets. - [Metro](https://facebook.github.io/metro/) — This is the bundler used to create the JavaScript bundle that is loaded into the native app by react native. - - The bundle needs to be created at 'compile time' thus all imports/requires need to be resolved. Thus _dynamic_ `import`s/`require`s are __statically resolved__. So you cannot exclude a require with a conditional in JS code. + - The bundle needs to be created at 'compile time' thus all imports/requires need to be resolved. Thus _dynamic_ `import`s/`require`s are **statically resolved**. So you cannot exclude a require with a conditional in JS code. - Since tfjs does dynamic `require`'s of certain node libraries that are not present in react native, files that do that need to be excluded from the metro build process. For end users, this is documented in the [README](../README.md), but it also happens in `integration_rn59/prep_tests.ts`. - Metro does not play well with symlinks, so if you are trying to develop against a local build of tfjs, copy the dist folder into the app's node_modules as appropriate. Do not use yalc. - [.ipa](https://en.wikipedia.org/wiki/.ipa) & [.apk](https://en.wikipedia.org/wiki/Android_application_package) — These are the formats for the final native bundle that is put on an iOS and Android device. They are created by their respective dev tools, [XCode](https://developer.apple.com/xcode/) and [Android Studio](https://developer.android.com/studio). @@ -33,8 +33,9 @@ Unit tests from tfjs-core are imported into a react native application and run a Because these are part of an app to run them you must compile and run the integration_rn59 of the target device. There is a button in that app to start the unit tests. This is _automated in CI_ and runs on: - - Changes to tfjs-core: [Tests will be run against HEAD of tfjs-core](../tfjs-core/cloudbuild.yml) - - Changes to tfjs-react-native: [Tests will be run against the **published** version](./cloudbuild.yml) of tfjs on npm that is references in `integration_rn59/package.json` + +- Changes to tfjs-core: [Tests will be run against HEAD of tfjs-core](../tfjs-core/cloudbuild.yml) +- Changes to tfjs-react-native: [Tests will be run against the **published** version](./cloudbuild.yml) of tfjs on npm that is references in `integration_rn59/package.json` ### Other integration tests diff --git a/tfjs-react-native/integration_rn59/components/ml.ts b/tfjs-react-native/integration_rn59/components/ml.ts index 7639493fbb3..8af7486f697 100644 --- a/tfjs-react-native/integration_rn59/components/ml.ts +++ b/tfjs-react-native/integration_rn59/components/ml.ts @@ -15,9 +15,12 @@ * ============================================================================= */ -import * as mobilenet from '@tensorflow-models/mobilenet'; -import * as tf from '@tensorflow/tfjs'; -import {asyncStorageIO, bundleResourceIO} from '@tensorflow/tfjs-react-native'; +import * as mobilenet from "@tensorflow-models/mobilenet"; +import * as tf from "@tensorflow/tfjs"; +import { + asyncStorageIO, + bundleResourceIO, +} from "@tensorflow/tfjs-react-native"; // All functions (i.e. 'runners") in this file are async // functions that return a function that can be invoked to @@ -64,11 +67,12 @@ export async function mobilenetRunner() { * A runner that loads a model bundled with the app and runs a prediction * through it. */ -const modelJson = require('../assets/model/bundle_model_test.json'); -const modelWeights = require('../assets/model/bundle_model_test_weights.bin'); +const modelJson = require("../assets/model/bundle_model_test.json"); +const modelWeights = require("../assets/model/bundle_model_test_weights.bin"); export async function localModelRunner() { - const model = - await tf.loadLayersModel(bundleResourceIO(modelJson, modelWeights)); + const model = await tf.loadLayersModel( + bundleResourceIO(modelJson, modelWeights) + ); return async () => { const res = model.predict(tf.randomNormal([1, 10])) as tf.Tensor; @@ -81,11 +85,12 @@ export async function localModelRunner() { * A runner that loads a model bundled with the app and runs a prediction * through it. */ -const modelJson2 = require('../assets/graph_model/model.json'); -const modelWeights2 = require('../assets/graph_model/group1-shard1of1.bin'); +const modelJson2 = require("../assets/graph_model/model.json"); +const modelWeights2 = require("../assets/graph_model/group1-shard1of1.bin"); export async function localGraphModelRunner() { - const model = - await tf.loadGraphModel(bundleResourceIO(modelJson2, modelWeights2)); + const model = await tf.loadGraphModel( + bundleResourceIO(modelJson2, modelWeights2) + ); return async () => { const res = model.predict(tf.randomNormal([1, 10])) as tf.Tensor; const data = await res.data(); @@ -97,33 +102,35 @@ export async function localGraphModelRunner() { * A runner that loads a sharded model bundled with the app and runs a * prediction through it. */ -const shardedModelJson = require('../assets/sharded_model/model.json'); -const shardedModelWeights1: number = - require('../assets/sharded_model/group1-shard1of2.bin'); -const shardedModelWeights2: number = - require('../assets/sharded_model/group1-shard2of2.bin'); +const shardedModelJson = require("../assets/sharded_model/model.json"); +const shardedModelWeights1: number = require("../assets/sharded_model/group1-shard1of2.bin"); +const shardedModelWeights2: number = require("../assets/sharded_model/group1-shard2of2.bin"); export async function localShardedGraphModelRunner() { - const model = await tf.loadGraphModel(bundleResourceIO( - shardedModelJson, [shardedModelWeights1, shardedModelWeights2])); + const model = await tf.loadGraphModel( + bundleResourceIO(shardedModelJson, [ + shardedModelWeights1, + shardedModelWeights2, + ]) + ); return async () => { const input = tf.zeros([1, 224, 224, 3]); const res = model.predict(input) as tf.Tensor; const data = await res.data(); - return JSON.stringify({predictionsLength: data.length}); + return JSON.stringify({ predictionsLength: data.length }); }; } /** - * A runner that traines a model. + * A runner that trains a model. */ export async function trainModelRunner() { // Define a model for linear regression. const model = tf.sequential(); - model.add(tf.layers.dense({units: 5, inputShape: [1]})); - model.add(tf.layers.dense({units: 1})); - model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); + model.add(tf.layers.dense({ units: 5, inputShape: [1] })); + model.add(tf.layers.dense({ units: 1 })); + model.compile({ loss: "meanSquaredError", optimizer: "sgd" }); // Generate some synthetic data for training. const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]); @@ -131,9 +138,9 @@ export async function trainModelRunner() { return async () => { // Train the model using the data. - await model.fit(xs, ys, {epochs: 20}); + await model.fit(xs, ys, { epochs: 20 }); - return 'done'; + return "done"; }; } @@ -143,14 +150,14 @@ export async function trainModelRunner() { export async function saveModelRunner() { // Define a model for linear regression. const model = tf.sequential(); - model.add(tf.layers.dense({units: 5, inputShape: [1]})); - model.add(tf.layers.dense({units: 1})); - model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); + model.add(tf.layers.dense({ units: 5, inputShape: [1] })); + model.add(tf.layers.dense({ units: 1 })); + model.compile({ loss: "meanSquaredError", optimizer: "sgd" }); return async () => { - await model.save(asyncStorageIO('custom-model-test')); - await tf.loadLayersModel(asyncStorageIO('custom-model-test')); + await model.save(asyncStorageIO("custom-model-test")); + await tf.loadLayersModel(asyncStorageIO("custom-model-test")); - return 'done'; + return "done"; }; } diff --git a/tfjs-react-native/integration_rn59/components/tfjs_unit_test_runner.tsx b/tfjs-react-native/integration_rn59/components/tfjs_unit_test_runner.tsx index c1f2792bcbb..5ec1a022340 100644 --- a/tfjs-react-native/integration_rn59/components/tfjs_unit_test_runner.tsx +++ b/tfjs-react-native/integration_rn59/components/tfjs_unit_test_runner.tsx @@ -123,7 +123,7 @@ export class TestRunner extends Component { const reactReporter: jasmine.CustomReporter = { jasmineStarted: suiteInfo => { // The console.warn below seems necessary in order for the spy on - // console.warn defined in one of the tests to run corrently. + // console.warn defined in one of the tests to run currently. console.warn('starting tests'); //@ts-ignore console.reportErrorsAsExceptions = false; diff --git a/tfjs-react-native/src/bundle_resource_io_test.ts b/tfjs-react-native/src/bundle_resource_io_test.ts index 2fc8b181def..d6c24b0e7dc 100644 --- a/tfjs-react-native/src/bundle_resource_io_test.ts +++ b/tfjs-react-native/src/bundle_resource_io_test.ts @@ -15,91 +15,95 @@ * ============================================================================= */ -import './platform_react_native'; +import "./platform_react_native"; -import * as tf from '@tensorflow/tfjs-core'; +import * as tf from "@tensorflow/tfjs-core"; // tslint:disable-next-line: no-imports-from-dist -import {describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util'; +import { describeWithFlags } from "@tensorflow/tfjs-core/dist/jasmine_util"; -import {bundleResourceIO} from './bundle_resource_io'; -import * as tfjsRn from './platform_react_native'; -import {RN_ENVS} from './test_env_registry'; +import { bundleResourceIO } from "./bundle_resource_io"; +import * as tfjsRn from "./platform_react_native"; +import { RN_ENVS } from "./test_env_registry"; -describeWithFlags('BundleResourceIO', RN_ENVS, () => { +describeWithFlags("BundleResourceIO", RN_ENVS, () => { // Test data. const modelTopology1: {} = { - 'class_name': 'Sequential', - 'keras_version': '2.1.4', - 'config': [{ - 'class_name': 'Dense', - 'config': { - 'kernel_initializer': { - 'class_name': 'VarianceScaling', - 'config': { - 'distribution': 'uniform', - 'scale': 1.0, - 'seed': null, - 'mode': 'fan_avg' - } + class_name: "Sequential", + keras_version: "2.1.4", + config: [ + { + class_name: "Dense", + config: { + kernel_initializer: { + class_name: "VarianceScaling", + config: { + distribution: "uniform", + scale: 1.0, + seed: null, + mode: "fan_avg", + }, + }, + name: "dense", + kernel_constraint: null, + bias_regularizer: null, + bias_constraint: null, + dtype: "float32", + activation: "linear", + trainable: true, + kernel_regularizer: null, + bias_initializer: { class_name: "Zeros", config: {} }, + units: 1, + batch_input_shape: [null, 3], + use_bias: true, + activity_regularizer: null, }, - 'name': 'dense', - 'kernel_constraint': null, - 'bias_regularizer': null, - 'bias_constraint': null, - 'dtype': 'float32', - 'activation': 'linear', - 'trainable': true, - 'kernel_regularizer': null, - 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, - 'units': 1, - 'batch_input_shape': [null, 3], - 'use_bias': true, - 'activity_regularizer': null - } - }], - 'backend': 'tensorflow' + }, + ], + backend: "tensorflow", }; const weightSpecs1: tf.io.WeightsManifestEntry[] = [ { - name: 'dense/kernel', + name: "dense/kernel", shape: [3, 1], - dtype: 'float32', + dtype: "float32", }, { - name: 'dense/bias', + name: "dense/bias", shape: [1], - dtype: 'float32', - } + dtype: "float32", + }, ]; const weightData1 = new ArrayBuffer(16); - it('constructs an IOHandler', async () => { + it("constructs an IOHandler", async () => { const modelJson: tf.io.ModelJSON = { modelTopology: modelTopology1, - weightsManifest: [{ - paths: [], - weights: weightSpecs1, - }] - + weightsManifest: [ + { + paths: [], + weights: weightSpecs1, + }, + ], }; const resourceId = 1; const handler = bundleResourceIO(modelJson, resourceId); - expect(typeof handler.load).toBe('function'); - expect(typeof handler.save).toBe('function'); + expect(typeof handler.load).toBe("function"); + expect(typeof handler.save).toBe("function"); }); - it('loads model artifacts', async () => { + it("loads model artifacts", async () => { const response = new Response(weightData1); - spyOn(tfjsRn, 'fetch').and.returnValue(Promise.resolve(response)); + spyOn(tfjsRn, "fetch").and.returnValue(Promise.resolve(response)); const modelJson: tf.io.ModelJSON = { modelTopology: modelTopology1, - weightsManifest: [{ - paths: [], - weights: weightSpecs1, - }] - + weightsManifest: [ + { + paths: [], + weights: weightSpecs1, + }, + ], }; const resourceId = 1; const handler = bundleResourceIO(modelJson, resourceId); @@ -111,10 +115,11 @@ describeWithFlags('BundleResourceIO', RN_ENVS, () => { expect(loaded.weightData).toEqual(weightData1); }); - it('errors on string modelJSON', async () => { + it("errors on string modelJSON", async () => { const response = new Response(weightData1); - spyOn(tf.env().platform, 'fetch') - .and.returnValue(Promise.resolve(response)); + spyOn(tf.env().platform, "fetch").and.returnValue( + Promise.resolve(response) + ); const modelJson = `{ modelTopology: modelTopology1, @@ -124,100 +129,102 @@ describeWithFlags('BundleResourceIO', RN_ENVS, () => { }] }`; const resourceId = 1; - expect( - () => bundleResourceIO( - modelJson as unknown as tf.io.ModelJSON, resourceId)) - .toThrow(new Error( - 'modelJson must be a JavaScript object (and not a string).\n' + - 'Have you wrapped yor asset path in a require() statment?')); + expect(() => + bundleResourceIO(modelJson as unknown as tf.io.ModelJSON, resourceId) + ).toThrow( + new Error( + "modelJson must be a JavaScript object (and not a string).\n" + + "Have you wrapped yor asset path in a require() statement?" + ) + ); }); }); -describeWithFlags('BundleResourceIO Sharded', RN_ENVS, () => { +describeWithFlags("BundleResourceIO Sharded", RN_ENVS, () => { // Test data. const modelTopology: {} = { - 'class_name': 'Sequential', - 'keras_version': '2.1.4', - 'config': [ + class_name: "Sequential", + keras_version: "2.1.4", + config: [ { - 'class_name': 'Dense', - 'config': { - 'kernel_initializer': { - 'class_name': 'VarianceScaling', - 'config': { - 'distribution': 'uniform', - 'scale': 1.0, - 'seed': null, - 'mode': 'fan_avg' - } + class_name: "Dense", + config: { + kernel_initializer: { + class_name: "VarianceScaling", + config: { + distribution: "uniform", + scale: 1.0, + seed: null, + mode: "fan_avg", + }, }, - 'name': 'dense', - 'kernel_constraint': null, - 'bias_regularizer': null, - 'bias_constraint': null, - 'dtype': 'float32', - 'activation': 'linear', - 'trainable': true, - 'kernel_regularizer': null, - 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, - 'units': 1, - 'batch_input_shape': [null, 3], - 'use_bias': true, - 'activity_regularizer': null - } + name: "dense", + kernel_constraint: null, + bias_regularizer: null, + bias_constraint: null, + dtype: "float32", + activation: "linear", + trainable: true, + kernel_regularizer: null, + bias_initializer: { class_name: "Zeros", config: {} }, + units: 1, + batch_input_shape: [null, 3], + use_bias: true, + activity_regularizer: null, + }, }, { - 'class_name': 'Dense', - 'config': { - 'kernel_initializer': { - 'class_name': 'VarianceScaling', - 'config': { - 'distribution': 'uniform', - 'scale': 1.0, - 'seed': null, - 'mode': 'fan_avg' - } + class_name: "Dense", + config: { + kernel_initializer: { + class_name: "VarianceScaling", + config: { + distribution: "uniform", + scale: 1.0, + seed: null, + mode: "fan_avg", + }, }, - 'name': 'dense2', - 'kernel_constraint': null, - 'bias_regularizer': null, - 'bias_constraint': null, - 'dtype': 'float32', - 'activation': 'linear', - 'trainable': true, - 'kernel_regularizer': null, - 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, - 'units': 1, - 'batch_input_shape': [null, 3], - 'use_bias': true, - 'activity_regularizer': null - } - } + name: "dense2", + kernel_constraint: null, + bias_regularizer: null, + bias_constraint: null, + dtype: "float32", + activation: "linear", + trainable: true, + kernel_regularizer: null, + bias_initializer: { class_name: "Zeros", config: {} }, + units: 1, + batch_input_shape: [null, 3], + use_bias: true, + activity_regularizer: null, + }, + }, ], - 'backend': 'tensorflow' + backend: "tensorflow", }; const weightSpecs: tf.io.WeightsManifestEntry[] = [ { - name: 'dense/kernel', + name: "dense/kernel", shape: [3, 1], - dtype: 'float32', + dtype: "float32", }, { - name: 'dense/bias', + name: "dense/bias", shape: [1], - dtype: 'float32', + dtype: "float32", }, { - name: 'dense2/kernel', + name: "dense2/kernel", shape: [3, 1], - dtype: 'float32', + dtype: "float32", }, { - name: 'dense2/bias', + name: "dense2/bias", shape: [1], - dtype: 'float32', - } + dtype: "float32", + }, ]; const weightData1 = new ArrayBuffer(16); const weightData2 = new ArrayBuffer(16); @@ -225,35 +232,36 @@ describeWithFlags('BundleResourceIO Sharded', RN_ENVS, () => { const combinedWeightsExpected = new ArrayBuffer(32); - it('constructs an IOHandler', async () => { + it("constructs an IOHandler", async () => { const modelJson: tf.io.ModelJSON = { modelTopology, - weightsManifest: [{ - paths: [], - weights: weightSpecs, - }] - + weightsManifest: [ + { + paths: [], + weights: weightSpecs, + }, + ], }; const handler = bundleResourceIO(modelJson, resourceIds); - expect(typeof handler.load).toBe('function'); - expect(typeof handler.save).toBe('function'); + expect(typeof handler.load).toBe("function"); + expect(typeof handler.save).toBe("function"); }); - it('loads model artifacts', async () => { - spyOn(tf.env().platform, 'fetch') - .and.returnValues( - Promise.resolve(new Response(weightData1)), - Promise.resolve(new Response(weightData2)), - ); + it("loads model artifacts", async () => { + spyOn(tf.env().platform, "fetch").and.returnValues( + Promise.resolve(new Response(weightData1)), + Promise.resolve(new Response(weightData2)) + ); const modelJson: tf.io.ModelJSON = { modelTopology, - weightsManifest: [{ - paths: [], - weights: weightSpecs, - }] - + weightsManifest: [ + { + paths: [], + weights: weightSpecs, + }, + ], }; const handler = bundleResourceIO(modelJson, resourceIds); diff --git a/tfjs-react-native/src/camera/camera_stream.tsx b/tfjs-react-native/src/camera/camera_stream.tsx index cf425d4805e..5eebe8498eb 100644 --- a/tfjs-react-native/src/camera/camera_stream.tsx +++ b/tfjs-react-native/src/camera/camera_stream.tsx @@ -110,7 +110,7 @@ const DEFAULT_USE_CUSTOM_SHADERS_TO_RESIZE = false; * gl: ExpoWebGLRenderingContext, * cameraTexture: WebGLTexture * ) => void — When the component is mounted and ready this callback will - * be called and recieve the following 3 elements: + * be called and receive the following 3 elements: * - __images__ is a (iterator)[https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Iterators_and_Generators] * that yields tensors representing the camera image on demand. * - __updateCameraPreview__ is a function that will update the WebGL render @@ -228,7 +228,7 @@ export function cameraWithTensors( } /** - * Callback for GL context creation. We do mose of the work of setting + * Callback for GL context creation. We do more of the work of setting * up the component here. * @param gl */ diff --git a/tfjs-react-native/src/camera/camera_test.ts b/tfjs-react-native/src/camera/camera_test.ts index 45c8b95b8db..010740e8603 100644 --- a/tfjs-react-native/src/camera/camera_test.ts +++ b/tfjs-react-native/src/camera/camera_test.ts @@ -15,15 +15,15 @@ * ============================================================================= */ -import * as tf from '@tensorflow/tfjs-core'; -import {test_util} from '@tensorflow/tfjs-core'; +import * as tf from "@tensorflow/tfjs-core"; +import { test_util } from "@tensorflow/tfjs-core"; // tslint:disable-next-line: no-imports-from-dist -import {describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util'; -import {ExpoWebGLRenderingContext, GLView} from 'expo-gl'; +import { describeWithFlags } from "@tensorflow/tfjs-core/dist/jasmine_util"; +import { ExpoWebGLRenderingContext, GLView } from "expo-gl"; -import {RN_ENVS} from '../test_env_registry'; +import { RN_ENVS } from "../test_env_registry"; -import {detectGLCapabilities, fromTexture, toTexture} from './camera'; +import { detectGLCapabilities, fromTexture, toTexture } from "./camera"; async function createGLContext(): Promise { return GLView.createContextAsync(); @@ -33,20 +33,24 @@ const expectArraysEqual = test_util.expectArraysEqual; let gl: ExpoWebGLRenderingContext; -describeWithFlags('toTexture', RN_ENVS, () => { +describeWithFlags("toTexture", RN_ENVS, () => { beforeAll(async () => { if (gl == null) { gl = await createGLContext(); } }); - it('should not throw', async () => { + it("should not throw", async () => { const height = 2; const width = 2; const depth = 4; - const inTensor: tf.Tensor3D = - tf.truncatedNormal([height, width, depth], 127, 40, 'int32'); + const inTensor: tf.Tensor3D = tf.truncatedNormal( + [height, width, depth], + 127, + 40, + "int32" + ); let texture: WebGLTexture; expect(async () => { @@ -56,57 +60,82 @@ describeWithFlags('toTexture', RN_ENVS, () => { expect(texture instanceof WebGLTexture); }); - it('should roundtrip succesfully', async () => { + it("should roundtrip successfully", async () => { const height = 2; const width = 2; const depth = 4; - const inTensor: tf.Tensor3D = - tf.truncatedNormal([height, width, depth], 127, 40, 'int32'); + const inTensor: tf.Tensor3D = tf.truncatedNormal( + [height, width, depth], + 127, + 40, + "int32" + ); const texture = await toTexture(gl, inTensor); const outTensor = fromTexture( - gl, texture, {width, height, depth}, {width, height, depth}, true); + gl, + texture, + { width, height, depth }, + { width, height, depth }, + true + ); expectArraysEqual(await inTensor.data(), await outTensor.data()); expectArraysEqual(inTensor.shape, outTensor.shape); }); - it('throws if tensor is not int32 dtype', async () => { + it("throws if tensor is not int32 dtype", async () => { const height = 2; const width = 2; const depth = 4; - const floatInput: tf.Tensor3D = - tf.truncatedNormal([height, width, depth], 127, 40, 'float32'); + const floatInput: tf.Tensor3D = tf.truncatedNormal( + [height, width, depth], + 127, + 40, + "float32" + ); expectAsync(toTexture(gl, floatInput)).toBeRejected(); }); - it('throws if tensor is not a tensor3d dtype', async () => { + it("throws if tensor is not a tensor3d dtype", async () => { const batch = 2; const height = 2; const width = 2; const depth = 4; - const oneDInput: tf.Tensor1D = - tf.truncatedNormal([height], 127, 40, 'int32'); + const oneDInput: tf.Tensor1D = tf.truncatedNormal( + [height], + 127, + 40, + "int32" + ); //@ts-ignore expectAsync(toTexture(gl, oneDInput)).toBeRejected(); - const twoDInput: tf.Tensor2D = - tf.truncatedNormal([height, width], 127, 40, 'int32'); + const twoDInput: tf.Tensor2D = tf.truncatedNormal( + [height, width], + 127, + 40, + "int32" + ); //@ts-ignore expectAsync(toTexture(gl, twoDInput)).toBeRejected(); - const fourDInput: tf.Tensor4D = - tf.truncatedNormal([batch, height, width, depth], 127, 40, 'int32'); + const fourDInput: tf.Tensor4D = tf.truncatedNormal( + [batch, height, width, depth], + 127, + 40, + "int32" + ); //@ts-ignore expectAsync(toTexture(gl, fourDInput)).toBeRejected(); }); }); -describeWithFlags('fromTexture:nearestNeighbor', RN_ENVS, () => { +describeWithFlags("fromTexture:nearestNeighbor", RN_ENVS, () => { let texture: WebGLTexture; let input: tf.Tensor3D; const inShape: [number, number, number] = [4, 4, 4]; @@ -117,33 +146,35 @@ describeWithFlags('fromTexture:nearestNeighbor', RN_ENVS, () => { } input = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [190, 191, 192, 255], + [180, 181, 182, 255], + [170, 171, 172, 255], + ], + [ + [160, 161, 162, 255], + [150, 151, 152, 255], + [140, 141, 142, 255], + [130, 131, 132, 255], + ], [ - [ - [200, 201, 202, 255], - [190, 191, 192, 255], - [180, 181, 182, 255], - [170, 171, 172, 255], - ], - [ - [160, 161, 162, 255], - [150, 151, 152, 255], - [140, 141, 142, 255], - [130, 131, 132, 255], - ], - [ - [120, 121, 122, 255], - [110, 111, 112, 255], - [100, 101, 102, 255], - [90, 91, 92, 255], - ], - [ - [80, 81, 82, 255], - [70, 71, 72, 255], - [60, 61, 62, 255], - [50, 51, 52, 255], - ] - ], - inShape, 'int32'); + [120, 121, 122, 255], + [110, 111, 112, 255], + [100, 101, 102, 255], + [90, 91, 92, 255], + ], + [ + [80, 81, 82, 255], + [70, 71, 72, 255], + [60, 61, 62, 255], + [50, 51, 52, 255], + ], + ], + inShape, + "int32" + ); }); beforeEach(async () => { @@ -154,263 +185,271 @@ describeWithFlags('fromTexture:nearestNeighbor', RN_ENVS, () => { tf.dispose(input); }); - it('same size alignCorners=false', async () => { + it("same size alignCorners=false", async () => { const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - true, - { - alignCorners: false, - interpolation: 'nearest_neighbor', - }, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + true, + { + alignCorners: false, + interpolation: "nearest_neighbor", + } ); expectArraysEqual(await output.data(), await input.data()); expectArraysEqual(output.shape, input.shape); }); - it('same size, alignCorners=true', async () => { + it("same size, alignCorners=true", async () => { const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - true, - { - alignCorners: true, - interpolation: 'nearest_neighbor', - }, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + true, + { + alignCorners: true, + interpolation: "nearest_neighbor", + } ); expectArraysEqual(await output.data(), await input.data()); expectArraysEqual(output.shape, input.shape); }); - it('smaller, resizeNearestNeighbor, same aspect ratio, alignCorners=false', - async () => { - const expectedShape: [number, number, number] = [2, 2, 4]; - const expected = tf.tensor3d( - [ - [ - [200, 201, 202, 255], - [180, 181, 182, 255], - ], - [ - [120, 121, 122, 255], - [100, 101, 102, 255], - ] - ], - expectedShape, 'int32'); - - const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: false, interpolation: 'nearest_neighbor'}, - ); - - expectArraysEqual(await output.data(), await expected.data()); - expectArraysEqual(output.shape, expected.shape); - }); - - it('smaller, resizeNearestNeighbor, same aspect ratio, alignCorners=true', - async () => { - const expectedShape: [number, number, number] = [2, 2, 4]; - const expected = tf.tensor3d( - [ - [ - [200, 201, 202, 255], - [170, 171, 172, 255], - ], - [ - [80, 81, 82, 255], - [50, 51, 52, 255], - ] - ], - expectedShape, 'int32'); - - const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: true, interpolation: 'nearest_neighbor'}, - ); - - expectArraysEqual(await output.data(), await expected.data()); - expectArraysEqual(output.shape, expected.shape); - }); - - it('smaller, resizeNearestNeighbor, wider, alignCorners=false', async () => { + it("smaller, resizeNearestNeighbor, same aspect ratio, alignCorners=false", async () => { + const expectedShape: [number, number, number] = [2, 2, 4]; + const expected = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [180, 181, 182, 255], + ], + [ + [120, 121, 122, 255], + [100, 101, 102, 255], + ], + ], + expectedShape, + "int32" + ); + + const output = fromTexture( + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: false, interpolation: "nearest_neighbor" } + ); + + expectArraysEqual(await output.data(), await expected.data()); + expectArraysEqual(output.shape, expected.shape); + }); + + it("smaller, resizeNearestNeighbor, same aspect ratio, alignCorners=true", async () => { + const expectedShape: [number, number, number] = [2, 2, 4]; + const expected = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [170, 171, 172, 255], + ], + [ + [80, 81, 82, 255], + [50, 51, 52, 255], + ], + ], + expectedShape, + "int32" + ); + + const output = fromTexture( + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: true, interpolation: "nearest_neighbor" } + ); + + expectArraysEqual(await output.data(), await expected.data()); + expectArraysEqual(output.shape, expected.shape); + }); + + it("smaller, resizeNearestNeighbor, wider, alignCorners=false", async () => { const expectedShape: [number, number, number] = [2, 3, 4]; const expected = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [190, 191, 192, 255], + [180, 181, 182, 255], + ], [ - [ - [200, 201, 202, 255], - [190, 191, 192, 255], - [180, 181, 182, 255], - ], - [ - [120, 121, 122, 255], - [110, 111, 112, 255], - [100, 101, 102, 255], - ] + [120, 121, 122, 255], + [110, 111, 112, 255], + [100, 101, 102, 255], ], - expectedShape, 'int32'); + ], + expectedShape, + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: false, interpolation: 'nearest_neighbor'}, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: false, interpolation: "nearest_neighbor" } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); }); - it('smaller, resizeNearestNeighbor, wider, alignCorners=true', async () => { + it("smaller, resizeNearestNeighbor, wider, alignCorners=true", async () => { const expectedShape: [number, number, number] = [2, 3, 4]; const expected = tf.tensor3d( + [ [ - [ - [200, 201, 202, 255], - [180, 181, 182, 255], - [170, 171, 172, 255], - ], + [200, 201, 202, 255], + [180, 181, 182, 255], + [170, 171, 172, 255], + ], - [ - [80, 81, 82, 255], - [60, 61, 62, 255], - [50, 51, 52, 255], - ] + [ + [80, 81, 82, 255], + [60, 61, 62, 255], + [50, 51, 52, 255], ], - expectedShape, 'int32'); + ], + expectedShape, + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: true, interpolation: 'nearest_neighbor'}, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: true, interpolation: "nearest_neighbor" } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); }); - it('same size, should drop alpha channel', async () => { + it("same size, should drop alpha channel", async () => { await detectGLCapabilities(gl); const expected = tf.tensor3d( + [ + [ + [200, 201, 202], + [190, 191, 192], + [180, 181, 182], + [170, 171, 172], + ], + [ + [160, 161, 162], + [150, 151, 152], + [140, 141, 142], + [130, 131, 132], + ], [ - [ - [200, 201, 202], - [190, 191, 192], - [180, 181, 182], - [170, 171, 172], - ], - [ - [160, 161, 162], - [150, 151, 152], - [140, 141, 142], - [130, 131, 132], - ], - [ - [120, 121, 122], - [110, 111, 112], - [100, 101, 102], - [90, 91, 92], - ], - [ - [80, 81, 82], - [70, 71, 72], - [60, 61, 62], - [50, 51, 52], - ] - ], - [inShape[0], inShape[1], 3], 'int32'); + [120, 121, 122], + [110, 111, 112], + [100, 101, 102], + [90, 91, 92], + ], + [ + [80, 81, 82], + [70, 71, 72], + [60, 61, 62], + [50, 51, 52], + ], + ], + [inShape[0], inShape[1], 3], + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: inShape[0], - width: inShape[1], - depth: 3, - }, - true, - { - alignCorners: true, - interpolation: 'nearest_neighbor', - }, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: inShape[0], + width: inShape[1], + depth: 3, + }, + true, + { + alignCorners: true, + interpolation: "nearest_neighbor", + } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); }); }); -describeWithFlags('fromTexture:bilinear', RN_ENVS, () => { +describeWithFlags("fromTexture:bilinear", RN_ENVS, () => { let texture: WebGLTexture; let input: tf.Tensor3D; const inShape: [number, number, number] = [4, 4, 4]; @@ -421,33 +460,35 @@ describeWithFlags('fromTexture:bilinear', RN_ENVS, () => { } input = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [190, 191, 192, 255], + [180, 181, 182, 255], + [170, 171, 172, 255], + ], [ - [ - [200, 201, 202, 255], - [190, 191, 192, 255], - [180, 181, 182, 255], - [170, 171, 172, 255], - ], - [ - [160, 161, 162, 255], - [150, 151, 152, 255], - [140, 141, 142, 255], - [130, 131, 132, 255], - ], - [ - [120, 121, 122, 255], - [110, 111, 112, 255], - [100, 101, 102, 255], - [90, 91, 92, 255], - ], - [ - [80, 81, 82, 255], - [70, 71, 72, 255], - [60, 61, 62, 255], - [50, 51, 52, 255], - ] - ], - inShape, 'int32'); + [160, 161, 162, 255], + [150, 151, 152, 255], + [140, 141, 142, 255], + [130, 131, 132, 255], + ], + [ + [120, 121, 122, 255], + [110, 111, 112, 255], + [100, 101, 102, 255], + [90, 91, 92, 255], + ], + [ + [80, 81, 82, 255], + [70, 71, 72, 255], + [60, 61, 62, 255], + [50, 51, 52, 255], + ], + ], + inShape, + "int32" + ); }); afterAll(() => { @@ -458,253 +499,263 @@ describeWithFlags('fromTexture:bilinear', RN_ENVS, () => { texture = await toTexture(gl, input); }); - it('same size alignCorners=false', async () => { + it("same size alignCorners=false", async () => { const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - true, - { - alignCorners: false, - interpolation: 'bilinear', - }, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + true, + { + alignCorners: false, + interpolation: "bilinear", + } ); expectArraysEqual(await output.data(), await input.data()); expectArraysEqual(output.shape, input.shape); }); - it('same size, alignCorners=true', async () => { + it("same size, alignCorners=true", async () => { const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - true, - { - alignCorners: true, - interpolation: 'bilinear', - }, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + true, + { + alignCorners: true, + interpolation: "bilinear", + } ); expectArraysEqual(await output.data(), await input.data()); expectArraysEqual(output.shape, input.shape); }); - it('smaller, same aspect ratio, alignCorners=false', async () => { + it("smaller, same aspect ratio, alignCorners=false", async () => { const expectedShape: [number, number, number] = [2, 2, 4]; const expected = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [180, 181, 182, 255], + ], [ - [ - [200, 201, 202, 255], - [180, 181, 182, 255], - ], - [ - [120, 121, 122, 255], - [100, 101, 102, 255], - ] + [120, 121, 122, 255], + [100, 101, 102, 255], ], - expectedShape, 'int32'); + ], + expectedShape, + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: false, interpolation: 'bilinear'}, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: false, interpolation: "bilinear" } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); }); - it('smaller, same aspect ratio, alignCorners=true', async () => { + it("smaller, same aspect ratio, alignCorners=true", async () => { const expectedShape: [number, number, number] = [2, 2, 4]; const expected = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [170, 171, 172, 255], + ], [ - [ - [200, 201, 202, 255], - [170, 171, 172, 255], - ], - [ - [80, 81, 82, 255], - [50, 51, 52, 255], - ] + [80, 81, 82, 255], + [50, 51, 52, 255], ], - expectedShape, 'int32'); + ], + expectedShape, + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: true, interpolation: 'bilinear'}, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: true, interpolation: "bilinear" } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); }); - it('smaller, wider, alignCorners=false', async () => { + it("smaller, wider, alignCorners=false", async () => { const expectedShape: [number, number, number] = [2, 3, 4]; const expected = tf.tensor3d( + [ + [ + [200, 201, 202, 255], + [187, 188, 189, 255], + [173, 174, 175, 255], + ], [ - [ - [200, 201, 202, 255], - [187, 188, 189, 255], - [173, 174, 175, 255], - ], - [ - [120, 121, 122, 255], - [107, 108, 109, 255], - [93, 94, 95, 255], - ] + [120, 121, 122, 255], + [107, 108, 109, 255], + [93, 94, 95, 255], ], - expectedShape, 'int32'); + ], + expectedShape, + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: false, interpolation: 'bilinear'}, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: false, interpolation: "bilinear" } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); }); - it('smaller, wider, alignCorners=true', async () => { + it("smaller, wider, alignCorners=true", async () => { const expectedShape: [number, number, number] = [2, 3, 4]; const expected = tf.tensor3d( + [ [ - [ - [200, 201, 202, 255], - [185, 186, 187, 255], - [170, 171, 172, 255], - ], - [ - [80, 81, 82, 255], - [65, 66, 67, 255], - [50, 51, 52, 255], - ] + [200, 201, 202, 255], + [185, 186, 187, 255], + [170, 171, 172, 255], ], - expectedShape, 'int32'); + [ + [80, 81, 82, 255], + [65, 66, 67, 255], + [50, 51, 52, 255], + ], + ], + expectedShape, + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: expectedShape[0], - width: expectedShape[1], - depth: expectedShape[2], - }, - true, - {alignCorners: true, interpolation: 'bilinear'}, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: expectedShape[0], + width: expectedShape[1], + depth: expectedShape[2], + }, + true, + { alignCorners: true, interpolation: "bilinear" } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); }); - it('same size, should drop alpha channel', async () => { + it("same size, should drop alpha channel", async () => { await detectGLCapabilities(gl); const expected = tf.tensor3d( + [ + [ + [200, 201, 202], + [190, 191, 192], + [180, 181, 182], + [170, 171, 172], + ], + [ + [160, 161, 162], + [150, 151, 152], + [140, 141, 142], + [130, 131, 132], + ], + [ + [120, 121, 122], + [110, 111, 112], + [100, 101, 102], + [90, 91, 92], + ], [ - [ - [200, 201, 202], - [190, 191, 192], - [180, 181, 182], - [170, 171, 172], - ], - [ - [160, 161, 162], - [150, 151, 152], - [140, 141, 142], - [130, 131, 132], - ], - [ - [120, 121, 122], - [110, 111, 112], - [100, 101, 102], - [90, 91, 92], - ], - [ - [80, 81, 82], - [70, 71, 72], - [60, 61, 62], - [50, 51, 52], - ] - ], - [inShape[0], inShape[1], 3], 'int32'); + [80, 81, 82], + [70, 71, 72], + [60, 61, 62], + [50, 51, 52], + ], + ], + [inShape[0], inShape[1], 3], + "int32" + ); const output = fromTexture( - gl, - texture, - { - height: inShape[0], - width: inShape[1], - depth: inShape[2], - }, - { - height: inShape[0], - width: inShape[1], - depth: 3, - }, - true, - { - alignCorners: true, - interpolation: 'bilinear', - }, + gl, + texture, + { + height: inShape[0], + width: inShape[1], + depth: inShape[2], + }, + { + height: inShape[0], + width: inShape[1], + depth: 3, + }, + true, + { + alignCorners: true, + interpolation: "bilinear", + } ); expectArraysEqual(await output.data(), await expected.data()); expectArraysEqual(output.shape, expected.shape); diff --git a/tfjs-react-native/src/platform_react_native.ts b/tfjs-react-native/src/platform_react_native.ts index 8537981d61a..fa57fa6afb6 100644 --- a/tfjs-react-native/src/platform_react_native.ts +++ b/tfjs-react-native/src/platform_react_native.ts @@ -15,27 +15,31 @@ * ============================================================================= */ -import '@tensorflow/tfjs-backend-cpu'; -import {GPGPUContext, MathBackendWebGL, setWebGLContext} from '@tensorflow/tfjs-backend-webgl'; -import * as tf from '@tensorflow/tfjs-core'; -import {Platform} from '@tensorflow/tfjs-core'; -import {Buffer} from 'buffer'; -import {GLView} from 'expo-gl'; -import {Platform as RNPlatform} from 'react-native'; +import "@tensorflow/tfjs-backend-cpu"; +import { + GPGPUContext, + MathBackendWebGL, + setWebGLContext, +} from "@tensorflow/tfjs-backend-webgl"; +import * as tf from "@tensorflow/tfjs-core"; +import { Platform } from "@tensorflow/tfjs-core"; +import { Buffer } from "buffer"; +import { GLView } from "expo-gl"; +import { Platform as RNPlatform } from "react-native"; -// See implemetation note on fetch +// See implementation note on fetch // tslint:disable-next-line:max-line-length // https://github.com/facebook/react-native/blob/0ee5f68929610106ee6864baa04ea90be0fc5160/Libraries/vendor/core/whatwg-fetch.js#L421 function parseHeaders(rawHeaders: string) { const headers = new Headers(); // Replace instances of \r\n and \n followed by at least one space or // horizontal tab with a space https://tools.ietf.org/html/rfc7230#section-3.2 - const preProcessedHeaders = rawHeaders.replace(/\r?\n[\t ]+/g, ' '); - preProcessedHeaders.split(/\r?\n/).forEach(line => { - const parts = line.split(':'); + const preProcessedHeaders = rawHeaders.replace(/\r?\n[\t ]+/g, " "); + preProcessedHeaders.split(/\r?\n/).forEach((line) => { + const parts = line.split(":"); const key = parts.shift().trim(); if (key) { - const value = parts.join(':').trim(); + const value = parts.join(":").trim(); headers.append(key, value); } }); @@ -67,8 +71,10 @@ function parseHeaders(rawHeaders: string) { * @doc {heading: 'Platform helpers', subheading: 'http'} */ export async function fetch( - path: string, init?: RequestInit, - options?: tf.io.RequestDetails): Promise { + path: string, + init?: RequestInit, + options?: tf.io.RequestDetails +): Promise { return new Promise((resolve, reject) => { const request = new Request(path, init); const xhr = new XMLHttpRequest(); @@ -77,27 +83,28 @@ export async function fetch( const reqOptions = { status: xhr.status, statusText: xhr.statusText, - headers: parseHeaders(xhr.getAllResponseHeaders() || ''), - url: '', + headers: parseHeaders(xhr.getAllResponseHeaders() || ""), + url: "", }; - reqOptions.url = 'responseURL' in xhr ? - xhr.responseURL : - reqOptions.headers.get('X-Request-URL'); + reqOptions.url = + "responseURL" in xhr + ? xhr.responseURL + : reqOptions.headers.get("X-Request-URL"); - //@ts-ignore — ts belives the latter case will never occur. - const body = 'response' in xhr ? xhr.response : xhr.responseText; + //@ts-ignore — ts believes the latter case will never occur. + const body = "response" in xhr ? xhr.response : xhr.responseText; resolve(new Response(body, reqOptions)); }; - xhr.onerror = () => reject(new TypeError('Network request failed')); - xhr.ontimeout = () => reject(new TypeError('Network request failed')); + xhr.onerror = () => reject(new TypeError("Network request failed")); + xhr.ontimeout = () => reject(new TypeError("Network request failed")); xhr.open(request.method, request.url, true); - if (request.credentials === 'include') { + if (request.credentials === "include") { xhr.withCredentials = true; - } else if (request.credentials === 'omit') { + } else if (request.credentials === "omit") { xhr.withCredentials = false; } @@ -105,7 +112,7 @@ export async function fetch( // In react native We need to set the response type to arraybuffer when // fetching binary resources in order for `.arrayBuffer` to work correctly // on the response. - xhr.responseType = 'arraybuffer'; + xhr.responseType = "arraybuffer"; } request.headers.forEach((value: string, name: string) => { @@ -113,8 +120,8 @@ export async function fetch( }); xhr.send( - //@ts-ignore - typeof request._bodyInit === 'undefined' ? null : request._bodyInit, + //@ts-ignore + typeof request._bodyInit === "undefined" ? null : request._bodyInit ); }); } @@ -126,7 +133,10 @@ export class PlatformReactNative implements Platform { * see @fetch docs above. */ async fetch( - path: string, init?: RequestInit, options?: tf.io.RequestDetails) { + path: string, + init?: RequestInit, + options?: tf.io.RequestDetails + ) { return fetch(path, init, options); } @@ -136,16 +146,16 @@ export class PlatformReactNative implements Platform { */ encode(text: string, encoding: string): Uint8Array { // See https://www.w3.org/TR/encoding/#utf-16le - if (encoding === 'utf-16') { - encoding = 'utf16le'; + if (encoding === "utf-16") { + encoding = "utf16le"; } return new Uint8Array(Buffer.from(text, encoding as BufferEncoding)); } /** Decode the provided bytes into a string using the provided encoding. */ decode(bytes: Uint8Array, encoding: string): string { // See https://www.w3.org/TR/encoding/#utf-16le - if (encoding === 'utf-16') { - encoding = 'utf16le'; + if (encoding === "utf-16") { + encoding = "utf16le"; } return Buffer.from(bytes).toString(encoding as BufferEncoding); } @@ -160,13 +170,18 @@ export class PlatformReactNative implements Platform { } setTimeoutCustom() { - throw new Error('react native does not support setTimeoutCustom'); + throw new Error("react native does not support setTimeoutCustom"); } - isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array - | Uint8ClampedArray { - return a instanceof Float32Array || a instanceof Int32Array || - a instanceof Uint8Array || a instanceof Uint8ClampedArray; + isTypedArray( + a: unknown + ): a is Uint8Array | Float32Array | Int32Array | Uint8ClampedArray { + return ( + a instanceof Float32Array || + a instanceof Int32Array || + a instanceof Uint8Array || + a instanceof Uint8ClampedArray + ); } } @@ -177,89 +192,96 @@ function setupGlobals() { function registerWebGLBackend() { try { const PRIORITY = 5; - tf.registerBackend('rn-webgl', async () => { - const glContext = await GLView.createContextAsync(); + tf.registerBackend( + "rn-webgl", + async () => { + const glContext = await GLView.createContextAsync(); - // ExpoGl getBufferSubData is not implemented yet (throws an exception). - tf.env().set('WEBGL_BUFFER_SUPPORTED', false); + // ExpoGl getBufferSubData is not implemented yet (throws an exception). + tf.env().set("WEBGL_BUFFER_SUPPORTED", false); - // - // Mock extension support for EXT_color_buffer_float and - // EXT_color_buffer_half_float on the expo-gl context object. - // In react native we do not have to get a handle to the extension - // in order to use the functionality of that extension on the device. - // - // This code block makes iOS and Android devices pass the extension checks - // used in core. After those are done core will actually test whether - // we can render/download float or half float textures. - // - // We can remove this block once we upstream checking for these - // extensions in expo. - // - // TODO look into adding support for checking these extensions in expo-gl - // - //@ts-ignore - const getExt = glContext.getExtension.bind(glContext); - const shimGetExt = (name: string) => { - if (name === 'EXT_color_buffer_float') { - if (RNPlatform.OS === 'ios') { - // iOS does not support EXT_color_buffer_float - return null; - } else { + // + // Mock extension support for EXT_color_buffer_float and + // EXT_color_buffer_half_float on the expo-gl context object. + // In react native we do not have to get a handle to the extension + // in order to use the functionality of that extension on the device. + // + // This code block makes iOS and Android devices pass the extension checks + // used in core. After those are done core will actually test whether + // we can render/download float or half float textures. + // + // We can remove this block once we upstream checking for these + // extensions in expo. + // + // TODO look into adding support for checking these extensions in expo-gl + // + //@ts-ignore + const getExt = glContext.getExtension.bind(glContext); + const shimGetExt = (name: string) => { + if (name === "EXT_color_buffer_float") { + if (RNPlatform.OS === "ios") { + // iOS does not support EXT_color_buffer_float + return null; + } else { + return {}; + } + } + + if (name === "EXT_color_buffer_half_float") { return {}; } - } + return getExt(name); + }; - if (name === 'EXT_color_buffer_half_float') { + // + // Manually make 'read' synchronous. glContext has a defined gl.fenceSync + // function that throws a "Not implemented yet" exception so core + // cannot properly detect that it is not supported. We mock + // implementations of gl.fenceSync and gl.clientWaitSync + // TODO remove once fenceSync and clientWaitSync is implemented upstream. + // + const shimFenceSync = () => { return {}; - } - return getExt(name); - }; + }; + const shimClientWaitSync = () => glContext.CONDITION_SATISFIED; - // - // Manually make 'read' synchronous. glContext has a defined gl.fenceSync - // function that throws a "Not implemented yet" exception so core - // cannot properly detect that it is not supported. We mock - // implementations of gl.fenceSync and gl.clientWaitSync - // TODO remove once fenceSync and clientWaitSync is implemented upstream. - // - const shimFenceSync = () => { - return {}; - }; - const shimClientWaitSync = () => glContext.CONDITION_SATISFIED; + // @ts-ignore + glContext.getExtension = shimGetExt.bind(glContext); + glContext.fenceSync = shimFenceSync.bind(glContext); + glContext.clientWaitSync = shimClientWaitSync.bind(glContext); - // @ts-ignore - glContext.getExtension = shimGetExt.bind(glContext); - glContext.fenceSync = shimFenceSync.bind(glContext); - glContext.clientWaitSync = shimClientWaitSync.bind(glContext); + // Set the WebGLContext before flag evaluation + setWebGLContext(2, glContext); + const context = new GPGPUContext(); + const backend = new MathBackendWebGL(context); - // Set the WebGLContext before flag evaluation - setWebGLContext(2, glContext); - const context = new GPGPUContext(); - const backend = new MathBackendWebGL(context); - - return backend; - }, PRIORITY); + return backend; + }, + PRIORITY + ); // Register all the webgl kernels on the rn-webgl backend // TODO: Use tf.copyRegisteredKernels once synced to tfjs-core 2.5.0. // tf.copyRegisteredKernels('webgl', 'rn-webgl'); - const kernels = tf.getKernelsForBackend('webgl'); - kernels.forEach(kernelConfig => { - const newKernelConfig = - Object.assign({}, kernelConfig, {backendName: 'rn-webgl'}); + const kernels = tf.getKernelsForBackend("webgl"); + kernels.forEach((kernelConfig) => { + const newKernelConfig = Object.assign({}, kernelConfig, { + backendName: "rn-webgl", + }); tf.registerKernel(newKernelConfig); }); } catch (e) { - throw (new Error(`Failed to register Webgl backend: ${e.message}`)); + throw new Error(`Failed to register Webgl backend: ${e.message}`); } } tf.env().registerFlag( - 'IS_REACT_NATIVE', () => navigator && navigator.product === 'ReactNative'); + "IS_REACT_NATIVE", + () => navigator && navigator.product === "ReactNative" +); -if (tf.env().getBool('IS_REACT_NATIVE')) { +if (tf.env().getBool("IS_REACT_NATIVE")) { setupGlobals(); registerWebGLBackend(); - tf.setPlatform('react-native', new PlatformReactNative()); + tf.setPlatform("react-native", new PlatformReactNative()); }