|
| 1 | + |
| 2 | +#include "larq_compute_engine/core/bmaxpool.h" |
| 3 | + |
| 4 | +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers |
| 5 | +#include "larq_compute_engine/core/packbits_utils.h" |
| 6 | +#include "tensorflow/lite/c/builtin_op_data.h" |
| 7 | +#include "tensorflow/lite/c/common.h" |
| 8 | +#include "tensorflow/lite/kernels/internal/tensor.h" |
| 9 | +#include "tensorflow/lite/kernels/kernel_util.h" |
| 10 | +#include "tensorflow/lite/kernels/op_macros.h" |
| 11 | + |
| 12 | +using namespace tflite; |
| 13 | + |
| 14 | +namespace compute_engine { |
| 15 | +namespace tflite { |
| 16 | +namespace maxpool { |
| 17 | + |
| 18 | +using namespace compute_engine::ref; |
| 19 | +using namespace compute_engine::core; |
| 20 | + |
| 21 | +using TBitpacked = std::uint32_t; |
| 22 | + |
// Kernel parameters for the binary maxpool op, extending the core
// BMaxPoolParams with the arena handle of the scratch buffer that holds the
// bitpacked copy of the input. The handle is obtained in Prepare via
// RequestScratchBufferInArena and resolved in Eval via GetScratchBuffer.
// NOTE(review): it is only initialized for float32/int8 inputs; for inputs
// that arrive already bitpacked (int32) it is never set nor read.
struct MicroBMaxPoolParams : public BMaxPoolParams {
  // Scratch-buffer id assigned by RequestScratchBufferInArena in Prepare.
  int packed_input_id;
};
| 26 | + |
| 27 | +bool StringEquals(const flexbuffers::String& a, const char* b) { |
| 28 | + // We use `strcmp` instead of `std::string` to avoid dynamic memory allocation |
| 29 | + return strcmp(a.c_str(), b) == 0; |
| 30 | +} |
| 31 | + |
| 32 | +void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
| 33 | + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); |
| 34 | + MicroBMaxPoolParams* poolparams = nullptr; |
| 35 | + if (context->AllocatePersistentBuffer(context, sizeof(MicroBMaxPoolParams), |
| 36 | + (void**)&poolparams) != kTfLiteOk) { |
| 37 | + context->ReportError(context, "Could not allocate persistent buffer."); |
| 38 | + return nullptr; |
| 39 | + } |
| 40 | + |
| 41 | + const std::uint8_t* buffer_t = reinterpret_cast<const std::uint8_t*>(buffer); |
| 42 | + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); |
| 43 | + |
| 44 | + poolparams->filter_height = m["filter_height"].AsInt32(); |
| 45 | + poolparams->filter_width = m["filter_width"].AsInt32(); |
| 46 | + poolparams->stride_height = m["stride_height"].AsInt32(); |
| 47 | + poolparams->stride_width = m["stride_width"].AsInt32(); |
| 48 | + |
| 49 | + auto padding_str = m["padding"].AsString(); |
| 50 | + if (StringEquals(padding_str, "VALID") || |
| 51 | + StringEquals(padding_str, "valid")) { |
| 52 | + poolparams->padding_type = kTfLitePaddingValid; |
| 53 | + } else if (StringEquals(padding_str, "SAME") || |
| 54 | + StringEquals(padding_str, "same")) { |
| 55 | + poolparams->padding_type = kTfLitePaddingSame; |
| 56 | + } else { |
| 57 | + context->ReportError(context, "Bmaxpool2d: invalid padding attribute."); |
| 58 | + } |
| 59 | + return poolparams; |
| 60 | +} |
| 61 | + |
| 62 | +// The only thing done in Prepare is asserts |
| 63 | +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
| 64 | + MicroBMaxPoolParams* poolparams = |
| 65 | + reinterpret_cast<MicroBMaxPoolParams*>(node->user_data); |
| 66 | + |
| 67 | + TF_LITE_ENSURE(context, poolparams != nullptr); |
| 68 | + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
| 69 | + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
| 70 | + TfLiteTensor* output = GetOutput(context, node, 0); |
| 71 | + const TfLiteTensor* input = GetInput(context, node, 0); |
| 72 | + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); |
| 73 | + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt32); |
| 74 | + |
| 75 | + int batches = input->dims->data[0]; |
| 76 | + int height = input->dims->data[1]; |
| 77 | + int width = input->dims->data[2]; |
| 78 | + |
| 79 | + int out_width, out_height; |
| 80 | + poolparams->padding = ComputePaddingHeightWidth( |
| 81 | + poolparams->stride_height, poolparams->stride_width, 1, 1, height, width, |
| 82 | + poolparams->filter_height, poolparams->filter_width, |
| 83 | + poolparams->padding_type, &out_height, &out_width); |
| 84 | + |
| 85 | + int channels_out = 0; |
| 86 | + if (input->type == kTfLiteFloat32 || input->type == kTfLiteInt8) { |
| 87 | + channels_out = GetPackedSize<TBitpacked>(input->dims->data[3]); |
| 88 | + } else { |
| 89 | + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32); |
| 90 | + channels_out = input->dims->data[3]; |
| 91 | + } |
| 92 | + |
| 93 | + // Use temoprary tensor for bitpacked inputs |
| 94 | + if (input->type == kTfLiteFloat32 || input->type == kTfLiteInt8) { |
| 95 | + int flat_size = |
| 96 | + batches * height * width * channels_out * sizeof(TBitpacked); |
| 97 | + |
| 98 | + TF_LITE_ENSURE_OK(context, |
| 99 | + context->RequestScratchBufferInArena( |
| 100 | + context, flat_size, &poolparams->packed_input_id)); |
| 101 | + } |
| 102 | + |
| 103 | + return kTfLiteOk; |
| 104 | +} |
| 105 | + |
| 106 | +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
| 107 | + MicroBMaxPoolParams* poolparams = |
| 108 | + reinterpret_cast<MicroBMaxPoolParams*>(node->user_data); |
| 109 | + |
| 110 | + TF_LITE_ENSURE(context, poolparams != nullptr); |
| 111 | + |
| 112 | + TfLiteTensor* output = GetOutput(context, node, 0); |
| 113 | + const TfLiteTensor* input = GetInput(context, node, 0); |
| 114 | + |
| 115 | + const TBitpacked* packed_input_data; |
| 116 | + RuntimeShape packed_input_shape; |
| 117 | + |
| 118 | + if (input->type == kTfLiteFloat32) { |
| 119 | + TBitpacked* packed_input = reinterpret_cast<TBitpacked*>( |
| 120 | + context->GetScratchBuffer(context, poolparams->packed_input_id)); |
| 121 | + ce::core::packbits_tensor<ce::core::BitpackOrder::Canonical>( |
| 122 | + GetTensorShape(input), GetTensorData<float>(input), 0, |
| 123 | + packed_input_shape, packed_input); |
| 124 | + packed_input_data = packed_input; |
| 125 | + } else if (input->type == kTfLiteInt8) { |
| 126 | + TBitpacked* packed_input = reinterpret_cast<TBitpacked*>( |
| 127 | + context->GetScratchBuffer(context, poolparams->packed_input_id)); |
| 128 | + ce::core::packbits_tensor<ce::core::BitpackOrder::Canonical>( |
| 129 | + GetTensorShape(input), GetTensorData<std::int8_t>(input), |
| 130 | + input->params.zero_point, packed_input_shape, packed_input); |
| 131 | + packed_input_data = packed_input; |
| 132 | + } else { |
| 133 | + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32); |
| 134 | + packed_input_shape.ReplaceWith(4, GetTensorShape(input).DimsData()); |
| 135 | + packed_input_data = GetTensorData<TBitpacked>(input); |
| 136 | + } |
| 137 | + |
| 138 | + BMaxPoolParams* pp = poolparams; |
| 139 | + BMaxPool(*pp, packed_input_shape, packed_input_data, GetTensorShape(output), |
| 140 | + GetTensorData<TBitpacked>(output)); |
| 141 | + |
| 142 | + return kTfLiteOk; |
| 143 | +} |
| 144 | + |
| 145 | +} // namespace maxpool |
| 146 | + |
| 147 | +TfLiteRegistration* Register_BMAXPOOL_2D() { |
| 148 | + static TfLiteRegistration r = {maxpool::Init, nullptr, maxpool::Prepare, |
| 149 | + maxpool::Eval}; |
| 150 | + return &r; |
| 151 | +} |
| 152 | + |
| 153 | +} // namespace tflite |
| 154 | +} // namespace compute_engine |
// NOTE(review): removed trailing web-scrape residue ("0 commit comments") —
// it was not part of the source and made the file invalid C++.