diff --git a/tfjs-react-native/package.json b/tfjs-react-native/package.json index 846ac38596..5ab5eeb07b 100644 --- a/tfjs-react-native/package.json +++ b/tfjs-react-native/package.json @@ -65,7 +65,9 @@ "dependencies": { "base64-js": "^1.3.0", "buffer": "^5.2.1", - "jpeg-js": "^0.4.3" + "jpeg-js": "^0.4.3", + "react-native-mmkv": "^4.1.2", + "react-native-nitro-modules": "^0.33.9" }, "peerDependencies": { "@react-native-async-storage/async-storage": "^1.13.0", diff --git a/tfjs-react-native/src/async_storage_io.ts b/tfjs-react-native/src/async_storage_io.ts index 9a88305837..807b015f5e 100644 --- a/tfjs-react-native/src/async_storage_io.ts +++ b/tfjs-react-native/src/async_storage_io.ts @@ -15,9 +15,10 @@ * ============================================================================= */ -import {AsyncStorageStatic} from '@react-native-async-storage/async-storage'; +import {createMMKV} from 'react-native-mmkv'; import {io} from '@tensorflow/tfjs-core'; import {fromByteArray, toByteArray} from 'base64-js'; +import type {MMKV} from 'react-native-mmkv'; type StorageKeys = { info: string, @@ -62,7 +63,7 @@ function getModelArtifactsInfoForJSON(modelArtifacts: io.ModelArtifacts): class AsyncStorageHandler implements io.IOHandler { protected readonly keys: StorageKeys; - protected asyncStorage: AsyncStorageStatic; + protected mmkvStore: MMKV; constructor(protected readonly modelPath: string) { if (modelPath == null || !modelPath) { @@ -74,9 +75,9 @@ class AsyncStorageHandler implements io.IOHandler { // needs to be installed by the user if they use this handler. We don't // want users who are not using AsyncStorage to have to install this // library. - this.asyncStorage = - // tslint:disable-next-line:no-require-imports - require('@react-native-async-storage/async-storage').default; + this.mmkvStore = createMMKV({ + id: 'tfjs.react-native.store' + }); } /** @@ -99,22 +100,23 @@ class AsyncStorageHandler implements io.IOHandler { const {weightData, ...modelArtifactsWithoutWeights} = modelArtifacts; try { - this.asyncStorage.setItem( - this.keys.info, JSON.stringify(modelArtifactsInfo)); - this.asyncStorage.setItem( - this.keys.modelArtifactsWithoutWeights, - JSON.stringify(modelArtifactsWithoutWeights)); - this.asyncStorage.setItem( - this.keys.weightData, fromByteArray(new Uint8Array(weightData))); + this.mmkvStore.set( + this.keys.info, JSON.stringify(modelArtifacts)); + this.mmkvStore.set( + this.keys.modelArtifactsWithoutWeights, + JSON.stringify(modelArtifactsWithoutWeights)); + this.mmkvStore.set( + //@ts-ignore + this.keys.weightData, fromByteArray(new Uint8Array(weightData))); return {modelArtifactsInfo}; } catch (err) { // If saving failed, clean up all items saved so far. - this.asyncStorage.removeItem(this.keys.info); - this.asyncStorage.removeItem(this.keys.weightData); - this.asyncStorage.removeItem(this.keys.modelArtifactsWithoutWeights); + this.mmkvStore.remove(this.keys.info); + this.mmkvStore.remove(this.keys.modelArtifactsWithoutWeights); + this.mmkvStore.remove(this.keys.weightData); throw new Error( - `Failed to save model '${this.modelPath}' to AsyncStorage. + `Failed to save model '${this.modelPath}' to key-value storage. Error info ${err}`); } } @@ -129,7 +131,7 @@ class AsyncStorageHandler implements io.IOHandler { * @returns The loaded model (if loading succeeds). */ async load(): Promise { - const info = JSON.parse(await this.asyncStorage.getItem(this.keys.info)) as + const info = JSON.parse(this.mmkvStore.getString(this.keys.info)) as io.ModelArtifactsInfo; if (info == null) { throw new Error( @@ -143,18 +145,18 @@ class AsyncStorageHandler implements io.IOHandler { } const modelArtifacts: io.ModelArtifacts = - JSON.parse(await this.asyncStorage.getItem( + JSON.parse(this.mmkvStore.getString( this.keys.modelArtifactsWithoutWeights)); // Load weight data. const weightDataBase64 = - await this.asyncStorage.getItem(this.keys.weightData); + this.mmkvStore.getString(this.keys.weightData); if (weightDataBase64 == null) { throw new Error( `In local storage, the binary weight values of model ` + `'${this.modelPath}' are missing.`); } - modelArtifacts.weightData = toByteArray(weightDataBase64).buffer; + modelArtifacts.weightData = toByteArray(weightDataBase64).buffer as ArrayBuffer; return modelArtifacts; }