Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update for tensorflow v2 compatibility #917

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmdnn/conversion/tensorflow/saver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import tensorflow as tf
import tensorflow.compat.v1 as tf


def save_model(MainModel, network_filepath, weight_filepath, dump_filepath, dump_tag = 'SERVING'):
Expand Down
13 changes: 9 additions & 4 deletions mmdnn/conversion/tensorflow/tensorflow_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class TensorflowEmitter(Emitter):

@property
def header_code(self):
return """import tensorflow as tf
return """import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.compat.v1.enable_resource_variables()

_weights_dict = dict()

Expand Down Expand Up @@ -183,7 +185,10 @@ def emit_Pool(self, IR_node):
pooling_type = IR_node.get_attr('pooling_type')
if pooling_type == 'MAX':
op = 'max_pool'
padding_const = ", constant_values=float('-Inf')"
# depending on the TensorFlow backend, constant -Inf may not be resolvable
# this change makes resulting tf saved model further compatible with tfjs graph model which can be used with webgl backend
# -3.4028235e+38 is a value of -tf.float32.max
padding_const = ", constant_values=float('-3.4028235e+38')"
elif pooling_type == 'AVG':
op = 'avg_pool'
padding_const = ""
Expand Down Expand Up @@ -292,7 +297,7 @@ def emit_FullyConnected(self, IR_node):
parent_shape = shape_to_list(parent.get_attr('_output_shapes')[0])
if len(parent_shape) > 2:
# flatten is needed
self.add_body(1, "{:<15} = tf.contrib.layers.flatten({})".format(
self.add_body(1, "{:<15} = tf.compat.v1.layers.Flatten()({})".format(
IR_node.variable_name + '_flatten',
self.parent_variable_name(IR_node)))

Expand Down Expand Up @@ -334,7 +339,7 @@ def emit_UpSampling2D(self, IR_node):

def emit_Flatten(self, IR_node):
#self._emit_unary_operation(IR_node, "contrib.layers.flatten")
code = "{:<15} = tf.contrib.layers.flatten({})".format(
code = "{:<15} = tf.compat.v1.layers.Flatten()({})".format(
IR_node.variable_name,
self.parent_variable_name(IR_node))
return code
Expand Down
2 changes: 1 addition & 1 deletion mmdnn/conversion/tensorflow/tensorflow_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#----------------------------------------------------------------------------------------------

import numpy as np
import tensorflow
import tensorflow.compat.v1 as tensorflow
from tensorflow.python.framework import tensor_util
from tensorflow.core.framework import attr_value_pb2
from mmdnn.conversion.tensorflow.tensorflow_graph import TensorflowGraph
Expand Down