Remove activation quantizers from Keras wrapper (#25)
* Remove activation quantizers from Keras wrapper

* Remove call to _set_activation_vars in Keras wrapper

---------

Co-authored-by: reuvenp <[email protected]>
reuvenperetz and reuvenp committed Jun 12, 2023
1 parent 9c2413f commit 20c9c4d
Showing 2 changed files with 5 additions and 102 deletions.
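After this commit, KerasQuantizationWrapper handles weights quantization only. Below is a minimal usage sketch of the post-change constructor, under stated assumptions: `IdentityWeightsQuantizer` is a hypothetical stand-in modeled on the dummy quantizers in the test file (not part of the library), and `'kernel'` is the assumed weight attribute name for a Conv2D layer.

```python
# Minimal sketch of the wrapper after this change (weights quantization only).
# IdentityWeightsQuantizer is a hypothetical dummy; real code would pass the
# library's inferable/trainable weight quantizers instead.
import numpy as np
import tensorflow as tf

from mct_quantizers.keras.quantize_wrapper import KerasQuantizationWrapper


class IdentityWeightsQuantizer:
    """Dummy weights quantizer: returns the weight tensor unchanged."""

    def __call__(self, inputs: tf.Tensor, training: bool = True) -> tf.Tensor:
        return inputs

    def initialize_quantization(self, tensor_shape, name, layer):
        return {}


input_shape = (1, 8, 8, 3)
conv = tf.keras.layers.Conv2D(filters=4, kernel_size=3)
conv.build(input_shape)  # the wrapped layer must be built so its 'kernel' attribute exists

# Only weights quantizers are accepted now; the activation_quantizers argument is gone.
wrapper = KerasQuantizationWrapper(conv, weights_quantizers={'kernel': IdentityWeightsQuantizer()})
wrapper.build(input_shape)

x = np.random.randn(*input_shape).astype('float32')
y = wrapper.call(x)  # with the identity quantizer this matches the unwrapped layer's output
```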
79 changes: 5 additions & 74 deletions mct_quantizers/keras/quantize_wrapper.py
@@ -15,7 +15,7 @@
from typing import Dict, List, Any, Tuple

from mct_quantizers.common.base_inferable_quantizer import BaseInferableQuantizer
from mct_quantizers.common.constants import FOUND_TF, ACTIVATION_QUANTIZERS, WEIGHTS_QUANTIZERS, STEPS, LAYER, TRAINING
from mct_quantizers.common.constants import FOUND_TF, WEIGHTS_QUANTIZERS, STEPS, LAYER, TRAINING
from mct_quantizers.logger import Logger
from mct_quantizers.common.get_all_subclasses import get_all_subclasses

@@ -54,20 +54,17 @@ class KerasQuantizationWrapper(tf.keras.layers.Wrapper):
def __init__(self,
layer,
weights_quantizers: Dict[str, BaseInferableQuantizer] = None,
activation_quantizers: List[BaseInferableQuantizer] = None,
**kwargs):
"""
Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.
Args:
layer: A keras layer.
weights_quantizers: A dictionary between a weight's name to its quantizer.
activation_quantizers: A list of activations quantization, one for each layer output.
"""
super(KerasQuantizationWrapper, self).__init__(layer, **kwargs)
self._track_trackable(layer, name='layer')
self.weights_quantizers = weights_quantizers if weights_quantizers is not None else dict()
self.activation_quantizers = activation_quantizers if activation_quantizers is not None else list()

def add_weights_quantizer(self, param_name: str, quantizer: BaseInferableQuantizer):
"""
@@ -82,15 +79,6 @@ def add_weights_quantizer(self, param_name: str, quantizer: BaseInferableQuantizer):
"""
self.weights_quantizers.update({param_name: quantizer})

@property
def is_activation_quantization(self) -> bool:
"""
This function check activation quantizer exists in wrapper.
Returns: a boolean if activation quantizer exists
"""
return self.num_activation_quantizers > 0

@property
def is_weights_quantization(self) -> bool:
"""
@@ -108,22 +96,13 @@ def num_weights_quantizers(self) -> int:
"""
return len(self.weights_quantizers)

@property
def num_activation_quantizers(self) -> int:
"""
Returns: number of activations quantizers
"""
return len(self.activation_quantizers)

def get_config(self):
"""
Returns: Configuration of KerasQuantizationWrapper.
"""
base_config = super(KerasQuantizationWrapper, self).get_config()
config = {
ACTIVATION_QUANTIZERS: [keras.utils.serialize_keras_object(act) for act in self.activation_quantizers],
WEIGHTS_QUANTIZERS: {k: keras.utils.serialize_keras_object(v) for k, v in self.weights_quantizers.items()}}
config = {WEIGHTS_QUANTIZERS: {k: keras.utils.serialize_keras_object(v) for k, v in self.weights_quantizers.items()}}
return dict(list(base_config.items()) + list(config.items()))

def _set_weights_vars(self, is_training: bool = True):
@@ -143,17 +122,6 @@ def _set_weights_vars(self, is_training: bool = True):
self._weights_vars.append((name, weight, quantizer))
self._trainable_weights.append(weight) # Must when inherit from tf.keras.layers.Wrapper in tf2.10 and below

def _set_activations_vars(self):
"""
This function sets activations quantizers vars to the layer
Returns: None
"""
self._activation_vars = []
for i, quantizer in enumerate(self.activation_quantizers):
quantizer.initialize_quantization(None, self.layer.name + f'/out{i}', self)
self._activation_vars.append(quantizer)

@classmethod
def from_config(cls, config):
"""
@@ -167,14 +135,11 @@ def from_config(cls, config):
config = config.copy()
qi_inferable_custom_objects = {subclass.__name__: subclass for subclass in
get_all_subclasses(BaseKerasInferableQuantizer)}
activation_quantizers = [keras.utils.deserialize_keras_object(act,
module_objects=globals(),
custom_objects=None) for act in config.pop(ACTIVATION_QUANTIZERS)]
weights_quantizers = {k: keras.utils.deserialize_keras_object(v,
module_objects=globals(),
custom_objects=qi_inferable_custom_objects) for k, v in config.pop(WEIGHTS_QUANTIZERS).items()}
layer = tf.keras.layers.deserialize(config.pop(LAYER))
return cls(layer=layer, weights_quantizers=weights_quantizers, activation_quantizers=activation_quantizers, **config)
return cls(layer=layer, weights_quantizers=weights_quantizers, **config)

def build(self, input_shape):
"""
@@ -194,7 +159,6 @@ def build(self, input_shape):
trainable=False)

self._set_weights_vars()
self._set_activations_vars()

def set_quantize_weights(self, quantized_weights: dict):
"""
@@ -254,29 +218,6 @@ def call(self, inputs, training=None, **kwargs):
else:
outputs = self.layer.call(inputs, **kwargs)

# Quantize all activations if quantizers exist.
if self.is_activation_quantization:
num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1
if self.num_activation_quantizers != num_outputs:
Logger.error('Quantization wrapper output quantization error: '
f'number of outputs and quantizers mismatch ({num_outputs}!='
f'{self.num_activation_quantizers}')
if num_outputs == 1:
outputs = [outputs]

_outputs = []
for _output, act_quant in zip(outputs, self.activation_quantizers):
activation_quantizer_args_spec = tf_inspect.getfullargspec(act_quant.__call__).args
if TRAINING in activation_quantizer_args_spec:
_outputs.append(utils.smart_cond(
training,
_make_quantizer_fn(act_quant, _output, True),
_make_quantizer_fn(act_quant, _output, False)))
else:
# Keras activation inferable quantizer.
_outputs.append(act_quant(_output))
outputs = _outputs[0] if num_outputs == 1 else _outputs

return outputs

def convert_to_inferable_quantizers(self):
@@ -286,14 +227,6 @@ def convert_to_inferable_quantizers(self):
Returns:
None
"""
# Activations quantizers
inferable_activation_quantizers = []
if self.is_activation_quantization:
for quantizer in self.activation_quantizers:
if hasattr(quantizer, 'convert2inferable') and callable(quantizer.convert2inferable):
inferable_activation_quantizers.append(quantizer.convert2inferable())
self.activation_quantizers = inferable_activation_quantizers

# Weight quantizers
inferable_weight_quantizers = {}
if self.is_weights_quantization:
@@ -310,7 +243,7 @@ def convert_to_inferable_quantizers(self):
layer_weights_list.append(getattr(self.layer, weight_attr)) # quantized weights
layer_weights_list.extend(self.layer.get_weights()) # non quantized weights
inferable_quantizers_wrapper.layer.set_weights(layer_weights_list)
inferable_quantizers_wrapper._set_activations_vars()

# The wrapper inference is using the weights of the quantizers so it expectes to create them by running _set_weights_vars
inferable_quantizers_wrapper._set_weights_vars(False)
return inferable_quantizers_wrapper
@@ -342,15 +275,13 @@ def get_quantized_weights(self) -> Dict[str, tf.Tensor]:
class KerasQuantizationWrapper(object):
def __init__(self,
layer,
weights_quantizers: Dict[str, BaseInferableQuantizer] = None,
activation_quantizers: List[BaseInferableQuantizer] = None):
weights_quantizers: Dict[str, BaseInferableQuantizer] = None):
"""
Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.
Args:
layer: A keras layer.
weights_quantizers: A dictionary between a weight's name to its quantizer.
activation_quantizers: A list of activations quantization, one for each layer output.
"""
Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
'when using KerasQuantizationWrapper. '
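Serialization changes accordingly: get_config now records only the weights quantizers, and from_config rebuilds the wrapper from them alone. A hedged round-trip sketch follows, assuming `wrapper` is a KerasQuantizationWrapper whose weight quantizers are the library's inferable quantizers (from_config deserializes against BaseKerasInferableQuantizer subclasses, so the dummy quantizer from the sketch above would not round-trip).

```python
# Hedged sketch of the post-change (de)serialization path; `wrapper` is assumed
# to exist and to hold inferable weight quantizers.
cfg = wrapper.get_config()                           # base wrapper config + weights quantizers only
restored = KerasQuantizationWrapper.from_config(cfg)
assert restored.num_weights_quantizers == wrapper.num_weights_quantizers
assert not hasattr(restored, 'activation_quantizers')  # attribute no longer set by __init__
```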
28 changes: 0 additions & 28 deletions tests/keras_tests/test_keras_quantization_wrapper.py
@@ -42,20 +42,6 @@ def initialize_quantization(self, tensor_shape, name, layer):
return {}


class ZeroActivationsQuantizer:
"""
A dummy quantizer for test usage - "quantize" the layer's activation to 0
"""

def __call__(self,
inputs: tf.Tensor,
training: bool = True) -> tf.Tensor:
return inputs * 0

def initialize_quantization(self, tensor_shape, name, layer):
return {}


class TestKerasQuantizationWrapper(unittest.TestCase):

def setUp(self):
@@ -86,17 +72,3 @@ def test_weights_quantization_wrapper(self):
outputs = wrapper.call(call_inputs.astype('float32'))
self.assertTrue((outputs == conv_layer(call_inputs)).numpy().all())

def test_activation_quantization_wrapper(self):
conv_layer = self.model.layers[1]

wrapper = KerasQuantizationWrapper(conv_layer, activation_quantizers=[ZeroActivationsQuantizer()])

# build
wrapper.build(self.input_shapes)
(act_quantizer) = wrapper._activation_vars[0]
self.assertTrue(isinstance(act_quantizer, ZeroActivationsQuantizer))

# apply the wrapper on inputs
call_inputs = self.inputs[0]
outputs = wrapper.call(call_inputs.astype('float32'))
self.assertTrue((outputs == 0).numpy().all())
