From c8617743866ac289029b2cc4fe47775e71de12ab Mon Sep 17 00:00:00 2001 From: kuloud Date: Fri, 27 Dec 2019 19:12:20 +0800 Subject: [PATCH] Add extensions-tensorflow-lite #24 --- Android/examples/demo/build.gradle | 5 +- .../aoe/features/mnist/MnistInterpreter.java | 4 +- .../pytorch/ClassifierPyTorchInterpreter.kt | 3 +- .../features/squeeze/SqueezeInterpreter.kt | 4 +- .../pytorch/src/main/AndroidManifest.xml | 2 +- .../aoe/{ => extensions}/pytorch/AoeExt.kt | 2 +- .../pytorch/PyTorchInterpreterWrapper.kt | 2 +- .../pytorch/PytorchConvertor.kt | 2 +- Android/extensions/tensorflow-lite/.gitignore | 1 + .../extensions/tensorflow-lite/build.gradle | 60 +++ .../tensorflow-lite/consumer-rules.pro | 0 .../tensorflow-lite/proguard-rules.pro | 21 + .../src/main/AndroidManifest.xml | 18 + .../aoe/extensions/tensorflow/lite/AoeExt.kt | 39 ++ Android/global_config.gradle | 3 +- Android/settings.gradle | 3 +- ...preter.java => TensorFlowInterpreter.java} | 80 ++-- ...FlowMultipleInputsOutputsInterpreter.java} | 413 +++++++++--------- 18 files changed, 406 insertions(+), 256 deletions(-) rename Android/extensions/pytorch/src/main/java/com/didi/aoe/{ => extensions}/pytorch/AoeExt.kt (96%) rename Android/extensions/pytorch/src/main/java/com/didi/aoe/{ => extensions}/pytorch/PyTorchInterpreterWrapper.kt (96%) rename Android/extensions/pytorch/src/main/java/com/didi/aoe/{ => extensions}/pytorch/PytorchConvertor.kt (95%) create mode 100644 Android/extensions/tensorflow-lite/.gitignore create mode 100644 Android/extensions/tensorflow-lite/build.gradle create mode 100644 Android/extensions/tensorflow-lite/consumer-rules.pro create mode 100644 Android/extensions/tensorflow-lite/proguard-rules.pro create mode 100644 Android/extensions/tensorflow-lite/src/main/AndroidManifest.xml create mode 100644 Android/extensions/tensorflow-lite/src/main/java/com/didi/aoe/extensions/tensorflow/lite/AoeExt.kt rename Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/{TensorFlowLiteInterpreter.java => TensorFlowInterpreter.java} (83%) rename Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/{TensorFlowLiteMultipleInputsOutputsInterpreter.java => TensorFlowMultipleInputsOutputsInterpreter.java} (91%) diff --git a/Android/examples/demo/build.gradle b/Android/examples/demo/build.gradle index 711093f..1612b68 100644 --- a/Android/examples/demo/build.gradle +++ b/Android/examples/demo/build.gradle @@ -72,15 +72,14 @@ dependencies { implementation deps.aoe.runtime.mnn implementation deps.aoe.runtime.ncnn + implementation deps.gson implementation deps.kotlin implementation 'com.didi.aoe:extensions-support:1.1.1.1' - implementation 'org.tensorflow:tensorflow-lite:2.0.0' - implementation 'org.tensorflow:tensorflow-lite-gpu:2.0.0' - + implementation deps.aoe.extensions.tensorflow implementation deps.aoe.extensions.pytorch } diff --git a/Android/examples/demo/src/main/java/com/didi/aoe/features/mnist/MnistInterpreter.java b/Android/examples/demo/src/main/java/com/didi/aoe/features/mnist/MnistInterpreter.java index 664261f..289a822 100644 --- a/Android/examples/demo/src/main/java/com/didi/aoe/features/mnist/MnistInterpreter.java +++ b/Android/examples/demo/src/main/java/com/didi/aoe/features/mnist/MnistInterpreter.java @@ -3,12 +3,12 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; -import com.didi.aoe.runtime.tensorflow.lite.TensorFlowLiteInterpreter; +import com.didi.aoe.runtime.tensorflow.lite.TensorFlowInterpreter; /** * @author noctis */ -public class MnistInterpreter extends TensorFlowLiteInterpreter { +public class MnistInterpreter extends TensorFlowInterpreter { @Nullable @Override diff --git a/Android/examples/demo/src/main/java/com/didi/aoe/features/pytorch/ClassifierPyTorchInterpreter.kt b/Android/examples/demo/src/main/java/com/didi/aoe/features/pytorch/ClassifierPyTorchInterpreter.kt index 8aec1a6..f3cb3ad 100644 --- a/Android/examples/demo/src/main/java/com/didi/aoe/features/pytorch/ClassifierPyTorchInterpreter.kt +++ b/Android/examples/demo/src/main/java/com/didi/aoe/features/pytorch/ClassifierPyTorchInterpreter.kt @@ -17,8 +17,7 @@ package com.didi.aoe.features.pytorch import android.graphics.Bitmap -import com.didi.aoe.pytorch.PytorchConvertor -import com.didi.aoe.runtime.pytorch.PyTorchInterpreter +import com.didi.aoe.extensions.pytorch.PytorchConvertor import org.pytorch.Tensor import org.pytorch.torchvision.TensorImageUtils diff --git a/Android/examples/demo/src/main/java/com/didi/aoe/features/squeeze/SqueezeInterpreter.kt b/Android/examples/demo/src/main/java/com/didi/aoe/features/squeeze/SqueezeInterpreter.kt index 821b242..108e991 100644 --- a/Android/examples/demo/src/main/java/com/didi/aoe/features/squeeze/SqueezeInterpreter.kt +++ b/Android/examples/demo/src/main/java/com/didi/aoe/features/squeeze/SqueezeInterpreter.kt @@ -70,11 +70,11 @@ class SqueezeInterpreter : val bmpBuffer = ByteBuffer.allocate(size) input.copyPixelsToBuffer(bmpBuffer) val rgba = bmpBuffer.array() - squeeze!!.inputRgba(rgba, input.width, input.height, INPUT_WIDTH, + squeeze?.inputRgba(rgba, input.width, input.height, INPUT_WIDTH, INPUT_HEIGHT, meanVals, norVals, 0) val buffer = ByteBuffer.allocate(4096) - squeeze!!.run(null, buffer) + squeeze?.run(null, buffer) buffer.order(ByteOrder.nativeOrder()) buffer.flip() val shape = squeeze!!.getOutputTensor(0).shape() diff --git a/Android/extensions/pytorch/src/main/AndroidManifest.xml b/Android/extensions/pytorch/src/main/AndroidManifest.xml index 31fddb9..9d7e84e 100644 --- a/Android/extensions/pytorch/src/main/AndroidManifest.xml +++ b/Android/extensions/pytorch/src/main/AndroidManifest.xml @@ -15,4 +15,4 @@ --> + package="com.didi.aoe.extensions.pytorch" /> diff --git a/Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/AoeExt.kt b/Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/AoeExt.kt similarity index 96% rename from Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/AoeExt.kt rename to Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/AoeExt.kt index 4d6253d..a78ca4b 100644 --- a/Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/AoeExt.kt +++ b/Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/AoeExt.kt @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.didi.aoe.pytorch +package com.didi.aoe.extensions.pytorch import android.content.Context import com.didi.aoe.library.api.Aoe diff --git a/Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/PyTorchInterpreterWrapper.kt b/Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/PyTorchInterpreterWrapper.kt similarity index 96% rename from Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/PyTorchInterpreterWrapper.kt rename to Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/PyTorchInterpreterWrapper.kt index c5438cb..5641ed2 100644 --- a/Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/PyTorchInterpreterWrapper.kt +++ b/Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/PyTorchInterpreterWrapper.kt @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.didi.aoe.pytorch +package com.didi.aoe.extensions.pytorch import com.didi.aoe.runtime.pytorch.PyTorchInterpreter import org.pytorch.Tensor diff --git a/Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/PytorchConvertor.kt b/Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/PytorchConvertor.kt similarity index 95% rename from Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/PytorchConvertor.kt rename to Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/PytorchConvertor.kt index f38a513..f45c8f4 100644 --- a/Android/extensions/pytorch/src/main/java/com/didi/aoe/pytorch/PytorchConvertor.kt +++ b/Android/extensions/pytorch/src/main/java/com/didi/aoe/extensions/pytorch/PytorchConvertor.kt @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.didi.aoe.pytorch +package com.didi.aoe.extensions.pytorch import com.didi.aoe.library.api.convertor.Convertor import org.pytorch.Tensor diff --git a/Android/extensions/tensorflow-lite/.gitignore b/Android/extensions/tensorflow-lite/.gitignore new file mode 100644 index 0000000..796b96d --- /dev/null +++ b/Android/extensions/tensorflow-lite/.gitignore @@ -0,0 +1 @@ +/build diff --git a/Android/extensions/tensorflow-lite/build.gradle b/Android/extensions/tensorflow-lite/build.gradle new file mode 100644 index 0000000..5dac167 --- /dev/null +++ b/Android/extensions/tensorflow-lite/build.gradle @@ -0,0 +1,60 @@ +/* + * Copyright 2019 The AoE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +apply plugin: 'com.android.library' +apply plugin: 'kotlin-android' + + +ext { + releaseArtifact = 'extensions-tensorflow-lite' + releaseDescription = 'The AoE tensorflow lite extensions library' + releaseVersion = aoe_version_name +} +apply from: rootProject.file('gradle/release.gradle') + +android { + compileSdkVersion aoe_compile_sdk_version + defaultConfig { + minSdkVersion aoe_min_sdk_version + targetSdkVersion aoe_target_sdk_version + versionName releaseVersion + } + + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' + } + } + +} + +dependencies { + implementation fileTree(dir: 'libs', include: ['*.jar']) + + implementation deps.support.annotation + + implementation deps.aoe.library.core + + implementation deps.kotlin + + implementation 'org.tensorflow:tensorflow-lite:2.0.0' + implementation 'org.tensorflow:tensorflow-lite-gpu:2.0.0' + + implementation deps.aoe.runtime.tensorflow + +} + diff --git a/Android/extensions/tensorflow-lite/consumer-rules.pro b/Android/extensions/tensorflow-lite/consumer-rules.pro new file mode 100644 index 0000000..e69de29 diff --git a/Android/extensions/tensorflow-lite/proguard-rules.pro b/Android/extensions/tensorflow-lite/proguard-rules.pro new file mode 100644 index 0000000..f1b4245 --- /dev/null +++ b/Android/extensions/tensorflow-lite/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/Android/extensions/tensorflow-lite/src/main/AndroidManifest.xml b/Android/extensions/tensorflow-lite/src/main/AndroidManifest.xml new file mode 100644 index 0000000..1c822b0 --- /dev/null +++ b/Android/extensions/tensorflow-lite/src/main/AndroidManifest.xml @@ -0,0 +1,18 @@ + + + diff --git a/Android/extensions/tensorflow-lite/src/main/java/com/didi/aoe/extensions/tensorflow/lite/AoeExt.kt b/Android/extensions/tensorflow-lite/src/main/java/com/didi/aoe/extensions/tensorflow/lite/AoeExt.kt new file mode 100644 index 0000000..fe07e92 --- /dev/null +++ b/Android/extensions/tensorflow-lite/src/main/java/com/didi/aoe/extensions/tensorflow/lite/AoeExt.kt @@ -0,0 +1,39 @@ +/* + * Copyright 2019 The AoE Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.didi.aoe.extensions.tensorflow.lite + +import android.content.Context +import com.didi.aoe.library.api.Aoe +import com.didi.aoe.library.core.AoeClient +import com.didi.aoe.runtime.tensorflow.lite.TensorFlowMultipleInputsOutputsInterpreter +import org.tensorflow.lite.gpu.GpuDelegate + +/** + * + * + * @author noctis + * @since 1.1.0 + */ + +fun Aoe.Companion.createAoeClient(context: Context, modelPath: String, + convertor: TensorFlowMultipleInputsOutputsInterpreter<*, *, *, *>, useGpu: Boolean): AoeClient { + if (useGpu) { + convertor.addDelegate(GpuDelegate()) + } + val client = AoeClient(context, convertor, modelPath) + return client +} \ No newline at end of file diff --git a/Android/global_config.gradle b/Android/global_config.gradle index 4449afb..ce9114a 100644 --- a/Android/global_config.gradle +++ b/Android/global_config.gradle @@ -45,7 +45,8 @@ ext { mnn : isDebug() ? project(':runtime-mnn') : "com.didi.aoe:runtime-mnn:$aoe_version_name", ], extensions: [ - pytorch: isDebug() ? project(':extensions-pytorch') : "com.didi.aoe:extensions-pytorch:$aoe_version_name" + pytorch : isDebug() ? project(':extensions-pytorch') : "com.didi.aoe:extensions-pytorch:$aoe_version_name", + tensorflow: isDebug() ? project(':extensions-tensorflow-lite') : "com.didi.aoe:extensions-tensorflow-lite:$aoe_version_name", ] ], // aoe 计划后续实现全量使用 kotlin,注解依赖暂时使用较低的 26.x 以保证最小的适配成本 diff --git a/Android/settings.gradle b/Android/settings.gradle index 78fa87c..abf6fad 100644 --- a/Android/settings.gradle +++ b/Android/settings.gradle @@ -1,5 +1,3 @@ -//include ':aoe-pytorch' - module('examples', 'demo') module('library', 'core') @@ -14,6 +12,7 @@ module('runtime', 'ncnn') module('runtime', 'pytorch') module('extensions', 'pytorch') +module('extensions', 'tensorflow-lite') // ------------------------------------------------------------------------------------------------- diff --git a/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowLiteInterpreter.java b/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowInterpreter.java similarity index 83% rename from Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowLiteInterpreter.java rename to Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowInterpreter.java index b74e1fd..ced7ab1 100644 --- a/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowLiteInterpreter.java +++ b/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowInterpreter.java @@ -1,40 +1,40 @@ -package com.didi.aoe.runtime.tensorflow.lite; - -import android.support.annotation.NonNull; -import android.support.annotation.Nullable; -import com.didi.aoe.library.api.convertor.Convertor; - -import java.util.Map; - -/** - * 基于TensorFlow Lite的运行时Interpreter封装,用于单输入,单输出的常见场景。多路输入的场景不要继承这个类,继承 - * 它的父类TensorFlowLiteMultipleInputsOutputsInterpreter,实现preProcessMulti和postProcessMulti即可。 - * - * @param 范型,业务输入数据 - * @param 范型,业务输出数据 - * @param 范型,模型输入数据 - * @param 范型,模型输出数据 - * @author noctis - */ -public abstract class TensorFlowLiteInterpreter extends - TensorFlowLiteMultipleInputsOutputsInterpreter implements - Convertor { - - @Nullable - @Override - public final Object[] preProcessMulti(@NonNull TInput tInput) { - Object[] inputs = new Object[1]; - inputs[0] = preProcess(tInput); - return inputs; - } - - @Nullable - @Override - public final TOutput postProcessMulti(@Nullable Map modelOutput) { - if (modelOutput != null && !modelOutput.isEmpty()) { - return postProcess(modelOutput.get(0)); - } - return null; - } - -} +package com.didi.aoe.runtime.tensorflow.lite; + +import android.support.annotation.NonNull; +import android.support.annotation.Nullable; +import com.didi.aoe.library.api.convertor.Convertor; + +import java.util.Map; + +/** + * 基于TensorFlow Lite的运行时Interpreter封装,用于单输入,单输出的常见场景。多路输入的场景不要继承这个类,继承 + * 它的父类TensorFlowLiteMultipleInputsOutputsInterpreter,实现preProcessMulti和postProcessMulti即可。 + * + * @param 范型,业务输入数据 + * @param 范型,业务输出数据 + * @param 范型,模型输入数据 + * @param 范型,模型输出数据 + * @author noctis + */ +public abstract class TensorFlowInterpreter extends + TensorFlowMultipleInputsOutputsInterpreter implements + Convertor { + + @Nullable + @Override + public final Object[] preProcessMulti(@NonNull TInput tInput) { + Object[] inputs = new Object[1]; + inputs[0] = preProcess(tInput); + return inputs; + } + + @Nullable + @Override + public final TOutput postProcessMulti(@Nullable Map modelOutput) { + if (modelOutput != null && !modelOutput.isEmpty()) { + return postProcess(modelOutput.get(0)); + } + return null; + } + +} diff --git a/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowLiteMultipleInputsOutputsInterpreter.java b/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowMultipleInputsOutputsInterpreter.java similarity index 91% rename from Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowLiteMultipleInputsOutputsInterpreter.java rename to Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowMultipleInputsOutputsInterpreter.java index b1358d6..0c85228 100644 --- a/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowLiteMultipleInputsOutputsInterpreter.java +++ b/Android/third_party/tensorflow-lite/src/main/java/com/didi/aoe/runtime/tensorflow/lite/TensorFlowMultipleInputsOutputsInterpreter.java @@ -1,200 +1,213 @@ -package com.didi.aoe.runtime.tensorflow.lite; - -import android.annotation.SuppressLint; -import android.content.Context; -import android.support.annotation.NonNull; -import android.support.annotation.Nullable; -import com.didi.aoe.library.api.AoeModelOption; -import com.didi.aoe.library.api.AoeProcessor; -import com.didi.aoe.library.api.StatusCode; -import com.didi.aoe.library.api.convertor.MultiConvertor; -import com.didi.aoe.library.api.domain.ModelSource; -import com.didi.aoe.library.api.interpreter.InterpreterInitResult; -import com.didi.aoe.library.api.interpreter.OnInterpreterInitListener; -import com.didi.aoe.library.api.interpreter.SingleInterpreterComponent; -import com.didi.aoe.library.common.util.FileUtils; -import com.didi.aoe.library.logging.Logger; -import com.didi.aoe.library.logging.LoggerFactory; -import org.tensorflow.lite.Interpreter; -import org.tensorflow.lite.Tensor; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Array; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.channels.FileChannel; -import java.util.HashMap; -import java.util.Map; - -/** - * 提供单模型的TensorFlowLite实现,用于多输入、多输出方式调用。 - * - * @param 范型,业务输入数据 - * @param 范型,业务输出数据 - * @param 范型,模型输入数据 - * @param 范型,模型输出数据 - * @author noctis - */ -public abstract class TensorFlowLiteMultipleInputsOutputsInterpreter - extends SingleInterpreterComponent implements - MultiConvertor { - private final Logger mLogger = LoggerFactory.getLogger("TFLite.Interpreter"); - private Interpreter mInterpreter; - private Map mOutputPlaceholder; - - @Override - public void init(@NonNull Context context, - @Nullable AoeProcessor.InterpreterComponent.Options interpreterOptions, - @NonNull AoeModelOption modelOptions, - @Nullable OnInterpreterInitListener listener) { - - @ModelSource - String modelSource = modelOptions.getModelSource(); - ByteBuffer bb = null; - if (ModelSource.CLOUD.equals(modelSource)) { - String modelFilePath = - modelOptions.getModelDir() + "_" + modelOptions.getVersion() + File.separator + modelOptions - .getModelName(); - File modelFile = new File(FileUtils.getFilesDir(context), modelFilePath); - if (modelFile.exists()) { - try { - bb = loadFromExternal(context, modelFilePath); - } catch (Exception e) { - mLogger.warn("IOException", e); - } - } else { - // 配置为云端模型,本地无文件,返回等待中状态 - if (listener != null) { - listener.onInitResult(InterpreterInitResult.create(StatusCode.STATUS_MODEL_DOWNLOAD_WAITING)); - } - return; - } - - - } else { - String modelFilePath = modelOptions.getModelDir() + File.separator + modelOptions.getModelName(); - // local default - bb = loadFromAssets(context, modelFilePath); - } - - if (bb != null) { - Interpreter.Options options = null; - if (interpreterOptions != null) { - options = new Interpreter.Options().setNumThreads(interpreterOptions.getNumThreads()); - } - mInterpreter = new Interpreter(bb, options); - - mOutputPlaceholder = generalOutputPlaceholder(mInterpreter); - if (listener != null) { - listener.onInitResult(InterpreterInitResult.create(StatusCode.STATUS_OK)); - } - return; - } else { - if (listener != null) { - listener.onInitResult(InterpreterInitResult.create(StatusCode.STATUS_INNER_ERROR)); - } - } - } - - private ByteBuffer loadFromExternal(Context context, String modelFilePath) throws IOException { - FileInputStream fis = new FileInputStream(FileUtils.getFilesDir(context) + File.separator + modelFilePath); - FileChannel fileChannel = fis.getChannel(); - long startOffset = fileChannel.position(); - long declaredLength = fileChannel.size(); - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); - } - - private Map generalOutputPlaceholder(@NonNull Interpreter interpreter) { - @SuppressLint("UseSparseArrays") - Map out = new HashMap<>(interpreter.getOutputTensorCount()); - for (int i = 0; i < interpreter.getOutputTensorCount(); i++) { - Tensor tensor = interpreter.getOutputTensor(i); - Object data = null; - switch (tensor.dataType()) { - case FLOAT32: - data = Array.newInstance(Float.TYPE, tensor.shape()); - break; - case INT32: - data = Array.newInstance(Integer.TYPE, tensor.shape()); - break; - case UINT8: - data = Array.newInstance(Byte.TYPE, tensor.shape()); - break; - case INT64: - data = Array.newInstance(Long.TYPE, tensor.shape()); - break; - case STRING: - data = Array.newInstance(String.class, tensor.shape()); - break; - default: - // ignore - break; - } - out.put(i, data); - } - - return out; - } - - @Override - @Nullable - public TOutput run(@NonNull TInput input) { - if (isReady()) { - Object[] modelInput = preProcessMulti(input); - - if (modelInput != null) { - - mInterpreter.runForMultipleInputsOutputs(modelInput, mOutputPlaceholder); - - //noinspection unchecked - return postProcessMulti((Map) mOutputPlaceholder); - } - - } - return null; - } - - @Override - public void release() { - if (mInterpreter != null) { - mInterpreter.close(); - } - } - - @Override - public boolean isReady() { - return mInterpreter != null && mOutputPlaceholder != null; - } - - private ByteBuffer loadFromAssets(Context context, String modelFilePath) { - InputStream is = null; - try { - is = context.getAssets().open(modelFilePath); - byte[] bytes = FileUtils.read(is); - if (bytes == null) { - return null; - } - ByteBuffer bf = ByteBuffer.allocateDirect(bytes.length); - bf.order(ByteOrder.nativeOrder()); - bf.put(bytes); - - return bf; - } catch (IOException e) { - mLogger.error("loadFromAssets error", e); - } finally { - if (is != null) { - try { - is.close(); - } catch (IOException e) { - // ignore - } - } - } - - return null; - - } -} +package com.didi.aoe.runtime.tensorflow.lite; + +import android.annotation.SuppressLint; +import android.content.Context; +import android.support.annotation.NonNull; +import android.support.annotation.Nullable; +import com.didi.aoe.library.api.AoeModelOption; +import com.didi.aoe.library.api.AoeProcessor; +import com.didi.aoe.library.api.StatusCode; +import com.didi.aoe.library.api.convertor.MultiConvertor; +import com.didi.aoe.library.api.domain.ModelSource; +import com.didi.aoe.library.api.interpreter.InterpreterInitResult; +import com.didi.aoe.library.api.interpreter.OnInterpreterInitListener; +import com.didi.aoe.library.api.interpreter.SingleInterpreterComponent; +import com.didi.aoe.library.common.util.FileUtils; +import com.didi.aoe.library.logging.Logger; +import com.didi.aoe.library.logging.LoggerFactory; +import org.tensorflow.lite.Delegate; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.Tensor; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.util.*; + +/** + * 提供单模型的TensorFlowLite实现,用于多输入、多输出方式调用。 + * + * @param 范型,业务输入数据 + * @param 范型,业务输出数据 + * @param 范型,模型输入数据 + * @param 范型,模型输出数据 + * @author noctis + */ +public abstract class TensorFlowMultipleInputsOutputsInterpreter + extends SingleInterpreterComponent implements + MultiConvertor { + private final Logger mLogger = LoggerFactory.getLogger("TFLite.Interpreter"); + private Interpreter mInterpreter; + private Map mOutputPlaceholder; + private List mDelegates = new ArrayList<>(); + + @Override + public void init(@NonNull Context context, + @Nullable AoeProcessor.InterpreterComponent.Options interpreterOptions, + @NonNull AoeModelOption modelOptions, + @Nullable OnInterpreterInitListener listener) { + + @ModelSource + String modelSource = modelOptions.getModelSource(); + ByteBuffer bb = null; + if (ModelSource.CLOUD.equals(modelSource)) { + String modelFilePath = + modelOptions.getModelDir() + "_" + modelOptions.getVersion() + File.separator + modelOptions + .getModelName(); + File modelFile = new File(FileUtils.getFilesDir(context), modelFilePath); + if (modelFile.exists()) { + try { + bb = loadFromExternal(context, modelFilePath); + } catch (Exception e) { + mLogger.warn("IOException", e); + } + } else { + // 配置为云端模型,本地无文件,返回等待中状态 + if (listener != null) { + listener.onInitResult(InterpreterInitResult.create(StatusCode.STATUS_MODEL_DOWNLOAD_WAITING)); + } + return; + } + + + } else { + String modelFilePath = modelOptions.getModelDir() + File.separator + modelOptions.getModelName(); + // local default + bb = loadFromAssets(context, modelFilePath); + } + + if (bb != null) { + Interpreter.Options options = null; + if (interpreterOptions != null) { + options = new Interpreter.Options().setNumThreads(interpreterOptions.getNumThreads()); + } + if (!mDelegates.isEmpty()) { + Iterator it = mDelegates.iterator(); + while (it.hasNext()) { + Delegate delegate = it.next(); + mLogger.debug("addDelegate: " + delegate); + options.addDelegate(delegate); + } + } + mInterpreter = new Interpreter(bb, options); + + mOutputPlaceholder = generalOutputPlaceholder(mInterpreter); + if (listener != null) { + listener.onInitResult(InterpreterInitResult.create(StatusCode.STATUS_OK)); + } + return; + } else { + if (listener != null) { + listener.onInitResult(InterpreterInitResult.create(StatusCode.STATUS_INNER_ERROR)); + } + } + } + + private ByteBuffer loadFromExternal(Context context, String modelFilePath) throws IOException { + FileInputStream fis = new FileInputStream(FileUtils.getFilesDir(context) + File.separator + modelFilePath); + FileChannel fileChannel = fis.getChannel(); + long startOffset = fileChannel.position(); + long declaredLength = fileChannel.size(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + + private Map generalOutputPlaceholder(@NonNull Interpreter interpreter) { + @SuppressLint("UseSparseArrays") + Map out = new HashMap<>(interpreter.getOutputTensorCount()); + for (int i = 0; i < interpreter.getOutputTensorCount(); i++) { + Tensor tensor = interpreter.getOutputTensor(i); + Object data = null; + switch (tensor.dataType()) { + case FLOAT32: + data = Array.newInstance(Float.TYPE, tensor.shape()); + break; + case INT32: + data = Array.newInstance(Integer.TYPE, tensor.shape()); + break; + case UINT8: + data = Array.newInstance(Byte.TYPE, tensor.shape()); + break; + case INT64: + data = Array.newInstance(Long.TYPE, tensor.shape()); + break; + case STRING: + data = Array.newInstance(String.class, tensor.shape()); + break; + default: + // ignore + break; + } + out.put(i, data); + } + + return out; + } + + @Override + @Nullable + public TOutput run(@NonNull TInput input) { + if (isReady()) { + Object[] modelInput = preProcessMulti(input); + + if (modelInput != null) { + + mInterpreter.runForMultipleInputsOutputs(modelInput, mOutputPlaceholder); + + //noinspection unchecked + return postProcessMulti((Map) mOutputPlaceholder); + } + + } + return null; + } + + @Override + public void release() { + if (mInterpreter != null) { + mInterpreter.close(); + } + } + + @Override + public boolean isReady() { + return mInterpreter != null && mOutputPlaceholder != null; + } + + private ByteBuffer loadFromAssets(Context context, String modelFilePath) { + InputStream is = null; + try { + is = context.getAssets().open(modelFilePath); + byte[] bytes = FileUtils.read(is); + if (bytes == null) { + return null; + } + ByteBuffer bf = ByteBuffer.allocateDirect(bytes.length); + bf.order(ByteOrder.nativeOrder()); + bf.put(bytes); + + return bf; + } catch (IOException e) { + mLogger.error("loadFromAssets error", e); + } finally { + if (is != null) { + try { + is.close(); + } catch (IOException e) { + // ignore + } + } + } + + return null; + + } + + public void addDelegate(@NonNull Delegate delegate) { + mDelegates.add(delegate); + } +}