From 47df8efdcc8c3110a74e6ce8292e7913933655d1 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 1 Mar 2021 09:08:36 -0500 Subject: [PATCH] update for tensorflow v2 compatibility --- mmdnn/conversion/tensorflow/saver.py | 2 +- mmdnn/conversion/tensorflow/tensorflow_emitter.py | 13 +++++++++---- mmdnn/conversion/tensorflow/tensorflow_parser.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mmdnn/conversion/tensorflow/saver.py b/mmdnn/conversion/tensorflow/saver.py index 1cea7eb6..51a27113 100644 --- a/mmdnn/conversion/tensorflow/saver.py +++ b/mmdnn/conversion/tensorflow/saver.py @@ -1,4 +1,4 @@ -import tensorflow as tf +import tensorflow.compat.v1 as tf def save_model(MainModel, network_filepath, weight_filepath, dump_filepath, dump_tag = 'SERVING'): diff --git a/mmdnn/conversion/tensorflow/tensorflow_emitter.py b/mmdnn/conversion/tensorflow/tensorflow_emitter.py index 09fc49c7..8c0ca8d4 100644 --- a/mmdnn/conversion/tensorflow/tensorflow_emitter.py +++ b/mmdnn/conversion/tensorflow/tensorflow_emitter.py @@ -29,7 +29,9 @@ class TensorflowEmitter(Emitter): @property def header_code(self): - return """import tensorflow as tf + return """import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() +tf.compat.v1.enable_resource_variables() _weights_dict = dict() @@ -183,7 +185,10 @@ def emit_Pool(self, IR_node): pooling_type = IR_node.get_attr('pooling_type') if pooling_type == 'MAX': op = 'max_pool' - padding_const = ", constant_values=float('-Inf')" + # depending on the TensorFlow backend, constant -Inf may not be resolvable + # this change makes resulting tf saved model further compatible with tfjs graph model which can be used with webgl backend + # -3.4028235e+38 is a value of -tf.float32.max + padding_const = ", constant_values=float('-3.4028235e+38')" elif pooling_type == 'AVG': op = 'avg_pool' padding_const = "" @@ -292,7 +297,7 @@ def emit_FullyConnected(self, IR_node): parent_shape = shape_to_list(parent.get_attr('_output_shapes')[0]) if len(parent_shape) > 2: # flatten is needed - self.add_body(1, "{:<15} = tf.contrib.layers.flatten({})".format( + self.add_body(1, "{:<15} = tf.compat.v1.layers.Flatten()({})".format( IR_node.variable_name + '_flatten', self.parent_variable_name(IR_node))) @@ -334,7 +339,7 @@ def emit_UpSampling2D(self, IR_node): def emit_Flatten(self, IR_node): #self._emit_unary_operation(IR_node, "contrib.layers.flatten") - code = "{:<15} = tf.contrib.layers.flatten({})".format( + code = "{:<15} = tf.compat.v1.layers.Flatten()({})".format( IR_node.variable_name, self.parent_variable_name(IR_node)) return code diff --git a/mmdnn/conversion/tensorflow/tensorflow_parser.py b/mmdnn/conversion/tensorflow/tensorflow_parser.py index b39e6167..3ffd0fc4 100644 --- a/mmdnn/conversion/tensorflow/tensorflow_parser.py +++ b/mmdnn/conversion/tensorflow/tensorflow_parser.py @@ -4,7 +4,7 @@ #---------------------------------------------------------------------------------------------- import numpy as np -import tensorflow +import tensorflow.compat.v1 as tensorflow from tensorflow.python.framework import tensor_util from tensorflow.core.framework import attr_value_pb2 from mmdnn.conversion.tensorflow.tensorflow_graph import TensorflowGraph