Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tfjs-react-native/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@
"dependencies": {
"base64-js": "^1.3.0",
"buffer": "^5.2.1",
"jpeg-js": "^0.4.3"
"jpeg-js": "^0.4.3",
"react-native-mmkv": "^4.1.2",
"react-native-nitro-modules": "^0.33.9"
},
"peerDependencies": {
"@react-native-async-storage/async-storage": "^1.13.0",
Expand Down
42 changes: 22 additions & 20 deletions tfjs-react-native/src/async_storage_io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
* =============================================================================
*/

import {AsyncStorageStatic} from '@react-native-async-storage/async-storage';
import {createMMKV} from 'react-native-mmkv';
import {io} from '@tensorflow/tfjs-core';
import {fromByteArray, toByteArray} from 'base64-js';
import type {MMKV} from 'react-native-mmkv';

type StorageKeys = {
info: string,
Expand Down Expand Up @@ -62,7 +63,7 @@ function getModelArtifactsInfoForJSON(modelArtifacts: io.ModelArtifacts):

class AsyncStorageHandler implements io.IOHandler {
protected readonly keys: StorageKeys;
protected asyncStorage: AsyncStorageStatic;
protected mmkvStore: MMKV;

constructor(protected readonly modelPath: string) {
if (modelPath == null || !modelPath) {
Expand All @@ -74,9 +75,9 @@ class AsyncStorageHandler implements io.IOHandler {
// needs to be installed by the user if they use this handler. We don't
// want users who are not using AsyncStorage to have to install this
// library.
this.asyncStorage =
// tslint:disable-next-line:no-require-imports
require('@react-native-async-storage/async-storage').default;
this.mmkvStore = createMMKV({
id: 'tfjs.react-native.store'
});
}

/**
Expand All @@ -99,22 +100,23 @@ class AsyncStorageHandler implements io.IOHandler {
const {weightData, ...modelArtifactsWithoutWeights} = modelArtifacts;

try {
this.asyncStorage.setItem(
this.keys.info, JSON.stringify(modelArtifactsInfo));
this.asyncStorage.setItem(
this.keys.modelArtifactsWithoutWeights,
JSON.stringify(modelArtifactsWithoutWeights));
this.asyncStorage.setItem(
this.keys.weightData, fromByteArray(new Uint8Array(weightData)));
this.mmkvStore.set(
this.keys.info, JSON.stringify(modelArtifacts));
this.mmkvStore.set(
this.keys.modelArtifactsWithoutWeights,
JSON.stringify(modelArtifactsWithoutWeights));
this.mmkvStore.set(
//@ts-ignore
this.keys.weightData, fromByteArray(new Uint8Array(weightData)));
return {modelArtifactsInfo};
} catch (err) {
// If saving failed, clean up all items saved so far.
this.asyncStorage.removeItem(this.keys.info);
this.asyncStorage.removeItem(this.keys.weightData);
this.asyncStorage.removeItem(this.keys.modelArtifactsWithoutWeights);
this.mmkvStore.remove(this.keys.info);
this.mmkvStore.remove(this.keys.modelArtifactsWithoutWeights);
this.mmkvStore.remove(this.keys.weightData);

throw new Error(
`Failed to save model '${this.modelPath}' to AsyncStorage.
`Failed to save model '${this.modelPath}' to key-value storage.
Error info ${err}`);
}
}
Expand All @@ -129,7 +131,7 @@ class AsyncStorageHandler implements io.IOHandler {
* @returns The loaded model (if loading succeeds).
*/
async load(): Promise<io.ModelArtifacts> {
const info = JSON.parse(await this.asyncStorage.getItem(this.keys.info)) as
const info = JSON.parse(this.mmkvStore.getString(this.keys.info)) as
io.ModelArtifactsInfo;
if (info == null) {
throw new Error(
Expand All @@ -143,18 +145,18 @@ class AsyncStorageHandler implements io.IOHandler {
}

const modelArtifacts: io.ModelArtifacts =
JSON.parse(await this.asyncStorage.getItem(
JSON.parse(this.mmkvStore.getString(
this.keys.modelArtifactsWithoutWeights));

// Load weight data.
const weightDataBase64 =
await this.asyncStorage.getItem(this.keys.weightData);
this.mmkvStore.getString(this.keys.weightData);
if (weightDataBase64 == null) {
throw new Error(
`In local storage, the binary weight values of model ` +
`'${this.modelPath}' are missing.`);
}
modelArtifacts.weightData = toByteArray(weightDataBase64).buffer;
modelArtifacts.weightData = toByteArray(weightDataBase64).buffer as ArrayBuffer;

return modelArtifacts;
}
Expand Down