Skip to content

Commit 9966e9f

Browse files
committed
Add binary maxpool 2D op to LCE Micro (#53)
1 parent 48b9875 commit 9966e9f

11 files changed

+1196
-1030
lines changed

larq_compute_engine/micro/build_make/build_lcem.sh

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ if [[ "$projects" == "1" || "$native" == "1" || "$arduino" == "1" || "$stm" == "
100100
bgemm_functor.h \
101101
cortexm/bconv2d_impl.h \
102102
cortexm/bgemv.h \
103+
bmaxpool.h \
103104
packbits.h \
104105
packbits_utils.h \
105106
types.h"

larq_compute_engine/micro/kernels/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@ cc_library(
1212
name = "lce_op_kernels",
1313
srcs = [
1414
"bconv2d.cc",
15+
"bmaxpool.cc",
1516
],
1617
hdrs = [
1718
"micro_ops.h",
1819
],
1920
copts = micro_copts(),
2021
deps = [
2122
"//larq_compute_engine/core:bconv2d_impl_ref",
23+
"//larq_compute_engine/core:bmaxpool",
2224
"//larq_compute_engine/core:packbits_utils",
2325
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
2426
"@org_tensorflow//tensorflow/lite/kernels:padding",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
2+
#include "larq_compute_engine/core/bmaxpool.h"
3+
4+
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
5+
#include "larq_compute_engine/core/packbits_utils.h"
6+
#include "tensorflow/lite/c/builtin_op_data.h"
7+
#include "tensorflow/lite/c/common.h"
8+
#include "tensorflow/lite/kernels/internal/tensor.h"
9+
#include "tensorflow/lite/kernels/kernel_util.h"
10+
#include "tensorflow/lite/kernels/op_macros.h"
11+
12+
using namespace tflite;
13+
14+
namespace compute_engine {
15+
namespace tflite {
16+
namespace maxpool {
17+
18+
using namespace compute_engine::ref;
19+
using namespace compute_engine::core;
20+
21+
using TBitpacked = std::uint32_t;
22+
23+
// Per-node op data: the core binary-maxpool parameters plus the id of the
// scratch buffer (requested in Prepare) used to hold the bitpacked input
// during Eval when the input arrives as float32 or int8.
struct MicroBMaxPoolParams : public BMaxPoolParams {
  int packed_input_id;
};
26+
27+
// True iff the flexbuffer string attribute equals the C string literal.
// C-string comparison is used instead of constructing a `std::string`,
// which would dynamically allocate.
bool StringEquals(const flexbuffers::String& a, const char* b) {
  return !strcmp(a.c_str(), b);
}
31+
32+
// Parses the flexbuffer-encoded custom-op attributes into a persistently
// allocated MicroBMaxPoolParams.
//
// Returns the populated params, or nullptr on allocation failure or on an
// invalid `padding` attribute; in that case the framework stores nullptr as
// node->user_data and Prepare rejects the node.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  MicroBMaxPoolParams* poolparams = nullptr;
  if (context->AllocatePersistentBuffer(context, sizeof(MicroBMaxPoolParams),
                                        (void**)&poolparams) != kTfLiteOk) {
    context->ReportError(context, "Could not allocate persistent buffer.");
    return nullptr;
  }

  const std::uint8_t* buffer_t = reinterpret_cast<const std::uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

  poolparams->filter_height = m["filter_height"].AsInt32();
  poolparams->filter_width = m["filter_width"].AsInt32();
  poolparams->stride_height = m["stride_height"].AsInt32();
  poolparams->stride_width = m["stride_width"].AsInt32();

  // The padding attribute is a string; accept both upper- and lower-case.
  auto padding_str = m["padding"].AsString();
  if (StringEquals(padding_str, "VALID") ||
      StringEquals(padding_str, "valid")) {
    poolparams->padding_type = kTfLitePaddingValid;
  } else if (StringEquals(padding_str, "SAME") ||
             StringEquals(padding_str, "same")) {
    poolparams->padding_type = kTfLitePaddingSame;
  } else {
    context->ReportError(context, "Bmaxpool2d: invalid padding attribute.");
    // Bug fix: previously the params were returned here with `padding_type`
    // left uninitialized. Returning nullptr instead makes Prepare's
    // `poolparams != nullptr` check fail cleanly.
    return nullptr;
  }
  return poolparams;
}
61+
62+
// The only thing done in Prepare is asserts
63+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
64+
MicroBMaxPoolParams* poolparams =
65+
reinterpret_cast<MicroBMaxPoolParams*>(node->user_data);
66+
67+
TF_LITE_ENSURE(context, poolparams != nullptr);
68+
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
69+
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
70+
TfLiteTensor* output = GetOutput(context, node, 0);
71+
const TfLiteTensor* input = GetInput(context, node, 0);
72+
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
73+
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt32);
74+
75+
int batches = input->dims->data[0];
76+
int height = input->dims->data[1];
77+
int width = input->dims->data[2];
78+
79+
int out_width, out_height;
80+
poolparams->padding = ComputePaddingHeightWidth(
81+
poolparams->stride_height, poolparams->stride_width, 1, 1, height, width,
82+
poolparams->filter_height, poolparams->filter_width,
83+
poolparams->padding_type, &out_height, &out_width);
84+
85+
int channels_out = 0;
86+
if (input->type == kTfLiteFloat32 || input->type == kTfLiteInt8) {
87+
channels_out = GetPackedSize<TBitpacked>(input->dims->data[3]);
88+
} else {
89+
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32);
90+
channels_out = input->dims->data[3];
91+
}
92+
93+
// Use temoprary tensor for bitpacked inputs
94+
if (input->type == kTfLiteFloat32 || input->type == kTfLiteInt8) {
95+
int flat_size =
96+
batches * height * width * channels_out * sizeof(TBitpacked);
97+
98+
TF_LITE_ENSURE_OK(context,
99+
context->RequestScratchBufferInArena(
100+
context, flat_size, &poolparams->packed_input_id));
101+
}
102+
103+
return kTfLiteOk;
104+
}
105+
106+
// Runs binary max-pooling. Float32/int8 inputs are first bitpacked into the
// scratch buffer requested in Prepare; int32 inputs are treated as already
// bitpacked and pooled directly.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  MicroBMaxPoolParams* poolparams =
      reinterpret_cast<MicroBMaxPoolParams*>(node->user_data);

  TF_LITE_ENSURE(context, poolparams != nullptr);

  TfLiteTensor* output = GetOutput(context, node, 0);
  const TfLiteTensor* input = GetInput(context, node, 0);

  const TBitpacked* packed_input_data;
  RuntimeShape packed_input_shape;

  if (input->type == kTfLiteInt32) {
    // Input is already bitpacked: pool it in place.
    packed_input_shape.ReplaceWith(4, GetTensorShape(input).DimsData());
    packed_input_data = GetTensorData<TBitpacked>(input);
  } else {
    // Bitpack the input into the scratch buffer before pooling. The fetch of
    // the scratch buffer is shared between the float32 and int8 paths.
    TBitpacked* packed_input = reinterpret_cast<TBitpacked*>(
        context->GetScratchBuffer(context, poolparams->packed_input_id));
    if (input->type == kTfLiteFloat32) {
      ce::core::packbits_tensor<ce::core::BitpackOrder::Canonical>(
          GetTensorShape(input), GetTensorData<float>(input), 0,
          packed_input_shape, packed_input);
    } else {
      TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt8);
      ce::core::packbits_tensor<ce::core::BitpackOrder::Canonical>(
          GetTensorShape(input), GetTensorData<std::int8_t>(input),
          input->params.zero_point, packed_input_shape, packed_input);
    }
    packed_input_data = packed_input;
  }

  // MicroBMaxPoolParams derives from BMaxPoolParams, so the derived object
  // binds directly to the base-class parameter.
  BMaxPool(*poolparams, packed_input_shape, packed_input_data,
           GetTensorShape(output), GetTensorData<TBitpacked>(output));

  return kTfLiteOk;
}
144+
145+
} // namespace maxpool
146+
147+
// Registration entry point for the LceBMaxPool2d custom op.
TfLiteRegistration* Register_BMAXPOOL_2D() {
  static TfLiteRegistration registration = {/*init=*/maxpool::Init,
                                            /*free=*/nullptr,
                                            /*prepare=*/maxpool::Prepare,
                                            /*invoke=*/maxpool::Eval};
  return &registration;
}
152+
153+
} // namespace tflite
154+
} // namespace compute_engine

larq_compute_engine/micro/kernels/micro_ops.h

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace tflite {
88

99
TfLiteRegistration* Register_BCONV_2D();
1010
TfLiteRegistration* Register_BCONV_2D_NoFloat();
11+
TfLiteRegistration* Register_BMAXPOOL_2D();
1112

1213
} // namespace tflite
1314
} // namespace compute_engine

larq_compute_engine/micro/tests/end2end_test.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,31 @@ def quant(x):
3333
def toy_model_int8(**kwargs):
3434
img = tf.keras.layers.Input(shape=(16, 16, 3))
3535
x = quant(img)
36+
x = tf.keras.layers.MaxPooling2D((2, 2))(x) # Binary maxpool
3637
x = lq.layers.QuantConv2D(
37-
32, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign", activation=quant
38+
32,
39+
3,
40+
input_quantizer="ste_sign",
41+
kernel_quantizer="ste_sign",
42+
padding="same",
43+
pad_values=1.0,
3844
)(x)
45+
x = tf.keras.layers.MaxPooling2D((2, 2))(x) # Binary maxpool
3946
x = lq.layers.QuantConv2D(
40-
64, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign", activation=quant
47+
64,
48+
3,
49+
input_quantizer="ste_sign",
50+
kernel_quantizer="ste_sign",
51+
padding="same",
52+
pad_values=1.0,
4153
)(x)
4254
x = lq.layers.QuantConv2D(
43-
32, 3, input_quantizer="ste_sign", kernel_quantizer="ste_sign", activation=quant
55+
32,
56+
3,
57+
input_quantizer="ste_sign",
58+
kernel_quantizer="ste_sign",
59+
padding="same",
60+
pad_values=1.0,
4461
)(x)
4562
x = global_pool(x)
4663
x = lq.layers.QuantDense(

larq_compute_engine/micro/tests/lce_test/lce_test.cc

+3-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
4848
model->version(), TFLITE_SCHEMA_VERSION);
4949
}
5050

51-
tflite::MicroOpResolver<9> resolver;
51+
tflite::MicroOpResolver<10> resolver;
5252
resolver.AddBuiltin(tflite::BuiltinOperator_CONV_2D,
5353
tflite::ops::micro::Register_CONV_2D(), 3, 3);
5454
resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
@@ -67,6 +67,8 @@ TF_LITE_MICRO_TEST(LoadModelAndPerformInference) {
6767
tflite::ops::micro::Register_DEQUANTIZE(), 2, 2);
6868
resolver.AddCustom("LceBconv2d",
6969
compute_engine::tflite::Register_BCONV_2D_NoFloat());
70+
resolver.AddCustom("LceBMaxPool2d",
71+
compute_engine::tflite::Register_BMAXPOOL_2D());
7072

7173
// Create an area of memory to use for input, output, and intermediate arrays.
7274
// Finding the minimum value for your model may require some trial and error.

0 commit comments

Comments
 (0)