QuantConv2D binarized activations with tf.int32 bitpacked output (#611)
* added function strip_lcedequantize_ops:
- strips the output LceDequantize operators of a model so that the output is a bitpacked tf.int32 tensor
- by default the lce_converter dequantizes the bitpacked output back to tf.float32, resulting in an identity tensor
- use case: larq.layers.QuantConv2D followed by a sign operation (i.e. larq.math.sign or larq.quantizers.SteSign()); a usage sketch follows the change summary below
- import using `from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops`

* reformatted using black code style

* added pytest module for verifying lce_dequantize_ops

* fixed larq import errors and renamed unit test function

* fix PyFlakes error due to typo when defining toy_model

* using Interpreter from larq_compute_engine.tflite.python.interpreter instead of tf.lite

* reformatted strip_lcedequantize_test.py using black code style

* Remove dependency of compute engine interpreter

* Add bazel target for dequantize test

* Update strip_lcedequantize_test.py

fixed the test_strip_lcedequantize_ops test, as only models with tf.float32 output result in tf.int32 tensor outputs when passed through strip_lcedequantize_ops

* Update strip_lcedequantize_test.py

refactored if-else statement

* Update strip_lcedequantize_test.py

deactivate setting default int8 ranges for `tf.float32` models as the strip_lcedequantize_ops function will not remove `LceDequantize` ops

* fix: accidentally added merge indicators

* Update strip_lcedequantize_test.py

Testing strip_lcedequantize_ops for tf.float32 output:
- fixed double allocation of the Interpreter by using tf.lite.Interpreter instead
- fixed a typo when converting the model to TFLite

* Update strip_lcedequantize_test.py

removed the import of the Larq interpreter due to failing lint tests

* Adapt unit test for output type checking

- only validate the output after LceDequantize ops have been stripped; input type tests are already covered in end2end_test.py

* Update strip_lcedequantize_test.py

fix: setting inference_input_type statically to tf.float32 as we're only validating the output

* set tf.float32 as parametrized input type

* Updated strip_lcedequantize_ops() to support more models:
- updated signature defs for TF2.5 compatibility
- support int8-quantized models when stripping the LceDequantize op for int8 output
- support int8-quantized models when using dequantized tf.float32 output: strips the Dequantize operator first, then LceDequantize

* Unit tests for tf.int8 input/output models

* Correction in toy_model_int8_sign

- fake quantize before QuantConv2D

* Extended Unit tests for test_strip_lcedequantize_ops() to parametrize experimental_enable_bitpacked_activations

* Clean up using black code style

Co-authored-by: Lukas Geiger <[email protected]>
simonmaurer and lgeiger authored Sep 8, 2021
1 parent 5b8284c commit cd041b5
Showing 4 changed files with 226 additions and 0 deletions.
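A minimal usage sketch of the new function, assuming a toy QuantConv2D + sign model as described in the commit message (the layer parameters and output filename are illustrative, not part of this change):

```python
import larq as lq
import tensorflow as tf

from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops

# Illustrative model: a QuantConv2D followed by a sign op, matching the use case above.
img = tf.keras.layers.Input(shape=(224, 224, 3))
x = lq.layers.QuantConv2D(
    32,
    kernel_size=3,
    padding="same",
    pad_values=1,
    input_quantizer="ste_sign",
    kernel_quantizer="ste_sign",
    kernel_constraint="weight_clip",
)(img)
x = lq.quantizers.SteSign()(x)
model = tf.keras.Model(inputs=img, outputs=x)

# Convert with the LCE converter, then strip the output LceDequantize ops so the
# model emits the bitpacked tf.int32 tensor directly instead of dequantized floats.
flatbuffer = convert_keras_model(model)
flatbuffer = strip_lcedequantize_ops(flatbuffer)

with open("quantconv2d_bitpacked.tflite", "wb") as f:  # illustrative path
    f.write(flatbuffer)
```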
2 changes: 2 additions & 0 deletions .github/workflows/unittests.yml
@@ -90,6 +90,8 @@ jobs:
run: bazelisk test larq_compute_engine/mlir/tests:all --test_output=all
- name: Run End2End tests
run: bazelisk test larq_compute_engine/tests:end2end_test --test_output=all
- name: Run Strip dequantize op tests
run: bazelisk test larq_compute_engine/tests:strip_lcedequantize_test --test_output=all

ConverterPython:
runs-on: ubuntu-latest
143 changes: 143 additions & 0 deletions larq_compute_engine/mlir/python/util.py
@@ -225,3 +225,146 @@ def modify_integer_quantized_model_io_type(

# Convert the model to a bytearray
return _convert_model_from_object_to_bytearray(model)


def strip_lcedequantize_ops(model):
"""Strip the LceDequantize ops to directly output bitpacked tf.int32 tensors."""
# Convert the model to an object
model = _convert_model_from_bytearray_to_object(model)

if len(model.subgraphs) > 1:
raise ValueError(
"Model must only have one subgraph. Instead, it has "
"{} subgraphs.".format(len(model.subgraphs))
)

# Ensure model has at least one LceDequantize and/or Dequantize operator
lce_dequant_opcode_idx, dequant_opcode_idx = None, None
for idx, opcode in enumerate(model.operatorCodes):
if opcode.customCode == b"LceDequantize":
lce_dequant_opcode_idx = idx
elif opcode.builtinCode == tflite_schema.BuiltinOperator.DEQUANTIZE:
dequant_opcode_idx = idx
if lce_dequant_opcode_idx is not None and dequant_opcode_idx is not None:
break
if lce_dequant_opcode_idx is None and dequant_opcode_idx is None:
raise ValueError(
"Model does not contain any LceDequantize or Dequantize operators."
)

# Ensure model outputs are dequantized and remove Dequantize ops first if any
if dequant_opcode_idx is not None:
subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
remove_tensors_idxs = set()

output_dequant_ops = []
for op in operators:
# Find output Dequantize operator
if (
op.opcodeIndex == dequant_opcode_idx
and op.outputs[0] in subgraph.outputs
):
pos, float_tensor, int_tensor = (
"output",
tensors[op.outputs[0]],
tensors[op.inputs[0]],
)
output_dequant_ops.append(op)
# Otherwise, ignore
else:
continue
# If found, validate the input/output tensor type
if float_tensor.type != tflite_schema.TensorType.FLOAT32:
raise ValueError(
"Model {} type must be tf.float32. Expected type for tensor with "
"name '{}' is tf.float32, instead type is tf.{}".format(
pos,
float_tensor.name,
_convert_tflite_enum_type_to_tf_type(float_tensor.type).name,
)
)
if int_tensor.type != tflite_schema.TensorType.INT8:
raise ValueError(
"Model is not integer quantized. Expected type for tensor with "
"name '{}' is tf.int8, instead type is tf.{}".format(
int_tensor.name,
_convert_tflite_enum_type_to_tf_type(int_tensor.type).name,
)
)

# Remove the Dequantize operators
for op in output_dequant_ops:
subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
if model.signatureDefs:
signature_def = model.signatureDefs[0]
for i in range(len(signature_def.outputs)):
if signature_def.outputs[i].tensorIndex == op.outputs[0]:
signature_def.outputs[i].tensorIndex = op.inputs[0]
remove_tensors_idxs.add(op.outputs[0])
operators.remove(op)

# Remove tensors marked for deletion.
_remove_tensors_from_model(model, remove_tensors_idxs)

subgraph = model.subgraphs[0]
tensors = subgraph.tensors
operators = subgraph.operators
remove_tensors_idxs = set()

# Ensure model outputs are Lce dequantized and remove LceDequantize ops
lce_output_dequant_ops = []
for op in operators:
# Find output LceDequantize operator
if (
op.opcodeIndex == lce_dequant_opcode_idx
and op.outputs[0] in subgraph.outputs
):
pos, output_tensor, input_tensor = (
"output",
tensors[op.outputs[0]],
tensors[op.inputs[0]],
)
lce_output_dequant_ops.append(op)
# Otherwise, ignore
else:
continue
# If found, validate the input/output tensor type
if (
output_tensor.type != tflite_schema.TensorType.FLOAT32
and output_tensor.type != tflite_schema.TensorType.INT8
):
raise ValueError(
"Model {} type must be tf.float32/tf.int8. Expected type for tensor with "
"name '{}' is tf.float32/tf.int8, instead type is tf.{}".format(
pos,
output_tensor.name,
_convert_tflite_enum_type_to_tf_type(output_tensor.type).name,
)
)
if input_tensor.type != tflite_schema.TensorType.INT32:
raise ValueError(
"Expected type for tensor with "
"name '{}' is tf.int32, instead type is tf.{}".format(
input_tensor.name,
_convert_tflite_enum_type_to_tf_type(input_tensor.type).name,
)
)

# Remove the LceDequantize operators
for op in lce_output_dequant_ops:
subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
if model.signatureDefs:
signature_def = model.signatureDefs[0]
for i in range(len(signature_def.outputs)):
if signature_def.outputs[i].tensorIndex == op.outputs[0]:
signature_def.outputs[i].tensorIndex = op.inputs[0]
remove_tensors_idxs.add(op.outputs[0])
operators.remove(op)

# Remove tensors marked for deletion.
_remove_tensors_from_model(model, remove_tensors_idxs)

# Convert the model to a bytearray
return _convert_model_from_object_to_bytearray(model)
8 changes: 8 additions & 0 deletions larq_compute_engine/tests/BUILD
@@ -20,6 +20,14 @@ py_test(
],
)

py_test(
name = "strip_lcedequantize_test",
srcs = ["strip_lcedequantize_test.py"],
deps = [
"//larq_compute_engine/mlir:converter",
],
)

py_test(
name = "convert_model",
srcs = ["convert_model.py"],
73 changes: 73 additions & 0 deletions larq_compute_engine/tests/strip_lcedequantize_test.py
@@ -0,0 +1,73 @@
import sys

import larq as lq
import pytest
import tensorflow as tf

from larq_compute_engine.mlir.python.converter import convert_keras_model
from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops


def toy_model_sign(**kwargs):
img = tf.keras.layers.Input(shape=(224, 224, 3))
x = lq.layers.QuantConv2D(
256,
kernel_size=3,
strides=1,
padding="same",
pad_values=1,
input_quantizer="ste_sign",
kernel_quantizer="ste_sign",
kernel_constraint="weight_clip",
)(img)
x = lq.quantizers.SteSign()(x)
return tf.keras.Model(inputs=img, outputs=x)


def quant(x):
return tf.quantization.fake_quant_with_min_max_vars(x, -3.0, 3.0)


def toy_model_int8_sign(**kwargs):
img = tf.keras.layers.Input(shape=(224, 224, 3))
x = quant(img)
x = lq.layers.QuantConv2D(
256,
kernel_size=3,
strides=1,
padding="same",
pad_values=1,
input_quantizer="ste_sign",
kernel_quantizer="ste_sign",
kernel_constraint="weight_clip",
)(x)
x = lq.quantizers.SteSign()(x)
x = quant(x)
return tf.keras.Model(inputs=img, outputs=x)


@pytest.mark.parametrize("model_cls", [toy_model_sign, toy_model_int8_sign])
@pytest.mark.parametrize("inference_input_type", [tf.float32, tf.int8])
@pytest.mark.parametrize("inference_output_type", [tf.float32, tf.int8])
@pytest.mark.parametrize("experimental_enable_bitpacked_activations", [True, False])
def test_strip_lcedequantize_ops(
model_cls,
inference_input_type,
inference_output_type,
experimental_enable_bitpacked_activations,
):
model_lce = convert_keras_model(
model_cls(),
inference_input_type=inference_input_type,
inference_output_type=inference_output_type,
experimental_enable_bitpacked_activations=experimental_enable_bitpacked_activations,
)
model_lce = strip_lcedequantize_ops(model_lce)
interpreter = tf.lite.Interpreter(model_content=model_lce)
output_details = interpreter.get_output_details()
assert len(output_details) == 1
assert output_details[0]["dtype"] == tf.int32.as_numpy_dtype


if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-s"]))
