Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #84

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion morph_net/framework/batch_norm_source_op_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def setUp(self):

# Declare OpSlice and OpGroup for ops that are created in the test network.
self.batch_norm_op = g.get_operation_by_name(
'conv1/BatchNorm/FusedBatchNorm')
'conv1/BatchNorm/FusedBatchNormV3')
self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

Expand Down
4 changes: 2 additions & 2 deletions morph_net/framework/concat_op_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def setUp(self):

self.axis_op = g.get_operation_by_name('concat/axis')

self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNorm')
self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNormV3')
self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 18))
self.batch_norm_op_group = orm.OpGroup(
self.batch_norm_op_slice,
Expand Down Expand Up @@ -808,7 +808,7 @@ def setUp(self):
self.relu3_op_group = orm.OpGroup(
self.relu3_op_slice, omit_source_op_slices=[self.relu3_op_slice])

self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNorm')
self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNormV3')
self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6))
self.batch_norm_op_group = orm.OpGroup(
self.batch_norm_op_slice,
Expand Down
2 changes: 1 addition & 1 deletion morph_net/framework/grouping_op_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def setUp(self):

# Declare OpSlice and OpGroup for ops of interest.
self.batch_norm_op = g.get_operation_by_name(
'conv1/BatchNorm/FusedBatchNorm')
'conv1/BatchNorm/FusedBatchNormV3')
self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

Expand Down
2 changes: 1 addition & 1 deletion morph_net/framework/leaf_op_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def setUp(self):

# Declare OpSlice and OpGroup for ops of interest.
self.batch_norm_op = g.get_operation_by_name(
'conv1/BatchNorm/FusedBatchNorm')
'conv1/BatchNorm/FusedBatchNormV3')
self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 5))
self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

Expand Down
6 changes: 3 additions & 3 deletions morph_net/framework/op_handler_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def setUp(self):

# Declare OpSlice and OpGroup for ops in the first test network.
self.batch_norm_op = g.get_operation_by_name(
'conv1/BatchNorm/FusedBatchNorm')
'conv1/BatchNorm/FusedBatchNormV3')
self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, None)
self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

Expand Down Expand Up @@ -86,7 +86,7 @@ def setUp(self):
self.relu4_op_slice, omit_source_op_slices=[self.relu4_op_slice])

self.unfused_batch_norm_op = g.get_operation_by_name(
'BatchNorm/FusedBatchNorm')
'BatchNorm/FusedBatchNormV3')
self.unfused_batch_norm_op_slice = orm.OpSlice(
self.unfused_batch_norm_op, orm.Slice(0, 18))

Expand Down Expand Up @@ -676,7 +676,7 @@ def testOpAssumptions(self):

g = tf.get_default_graph()

# Verify that FusedBatchNorm has gamma as inputs[1].
# Verify that FusedBatchNormV3 has gamma as inputs[1].
self.assertEqual('conv1/BatchNorm/gamma/read:0',
self.batch_norm_op.inputs[1].name)

Expand Down
30 changes: 18 additions & 12 deletions morph_net/framework/op_regularizer_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ def setUp(self):
self._default_op_handler_dict = collections.defaultdict(
grouping_op_handler.GroupingOpHandler)
self._default_op_handler_dict.update({
'FusedBatchNorm': IndexBatchNormSourceOpHandler(),
'FusedBatchNormV3':
IndexBatchNormSourceOpHandler(),
'Conv2D':
output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
'ConcatV2':
concat_op_handler.ConcatOpHandler(),
concat_op_handler.ConcatOpHandler(),
'DepthwiseConv2dNative':
depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
})

def _batch_norm_scope(self):
Expand Down Expand Up @@ -86,7 +87,8 @@ def testSimpleOpGetRegularizer(self, use_batch_norm, use_partitioner, scope):

# Instantiate OpRegularizerManager.
op_handler_dict = self._default_op_handler_dict
op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub)
op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
model_stub)
if not use_batch_norm:
op_handler_dict['Conv2D'] = StubConv2DSourceOpHandler(model_stub)
op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)
Expand All @@ -112,7 +114,8 @@ def testConcatOpGetRegularizer(self, use_batch_norm, use_partitioner):

# Instantiate OpRegularizerManager.
op_handler_dict = self._default_op_handler_dict
op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub)
op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
model_stub)
if not use_batch_norm:
op_handler_dict['Conv2D'] = StubConv2DSourceOpHandler(model_stub)
op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)
Expand All @@ -139,7 +142,8 @@ def testGroupConcatOpGetRegularizerValues(self, op_name, short_name):

# Instantiate OpRegularizerManager.
op_handler_dict = self._default_op_handler_dict
op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub)
op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
model_stub)

op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)

Expand All @@ -158,7 +162,8 @@ def testGroupConcatOpGetRegularizerObjects(self):

# Instantiate OpRegularizerManager.
op_handler_dict = self._default_op_handler_dict
op_handler_dict['FusedBatchNorm'] = StubBatchNormSourceOpHandler(model_stub)
op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
model_stub)

op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)
self.assertEqual(
Expand Down Expand Up @@ -1688,14 +1693,15 @@ def testDfsForSourceOps(self):

# Verify source ops were found.
expected_queue = collections.deque([
_get_op('conv3/BatchNorm/FusedBatchNorm'),
_get_op('conv2/BatchNorm/FusedBatchNorm'),
_get_op('conv1/BatchNorm/FusedBatchNorm')])
_get_op('conv3/BatchNorm/FusedBatchNormV3'),
_get_op('conv2/BatchNorm/FusedBatchNormV3'),
_get_op('conv1/BatchNorm/FusedBatchNormV3')
])
self.assertEqual(expected_queue, manager._op_deque)

# Verify extra branch was not included.
self.assertNotIn(
_get_op('conv4/BatchNorm/FusedBatchNorm'), manager._op_deque)
_get_op('conv4/BatchNorm/FusedBatchNormV3'), manager._op_deque)

def testOpGroup_NewSourceGroup(self):
inputs = tf.zeros([2, 4, 4, 3])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def setUp(self):
self.relu2_op_slice, omit_source_op_slices=[self.relu2_op_slice])

self.batch_norm_op = g.get_operation_by_name(
'conv2/BatchNorm/FusedBatchNorm')
'conv2/BatchNorm/FusedBatchNormV3')
self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6))
self.batch_norm_op_group = orm.OpGroup(self.batch_norm_op_slice)

Expand Down
1 change: 1 addition & 0 deletions morph_net/network_regularizers/activation_regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
op_handler_dict.update({
'FusedBatchNorm': source_op_handler,
'FusedBatchNormV2': source_op_handler,
'FusedBatchNormV3': source_op_handler,
})

self._manager = orm.OpRegularizerManager(
Expand Down
5 changes: 3 additions & 2 deletions morph_net/network_regularizers/cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
CONV3D_OPS = ('Conv3D',)
CONV_OPS = CONV2D_OPS + CONV3D_OPS
FLOP_OPS = CONV_OPS + ('MatMul',)
SUPPORTED_OPS = FLOP_OPS + (
'Add', 'AddN', 'ConcatV2', 'FusedBatchNorm', 'Mul', 'Relu', 'Relu6', 'Sum')
SUPPORTED_OPS = FLOP_OPS + ('Add', 'AddN', 'ConcatV2', 'FusedBatchNorm',
'FusedBatchNormV2', 'FusedBatchNormV3', 'Mul',
'Relu', 'Relu6', 'Sum')


class CostCalculator(object):
Expand Down
8 changes: 4 additions & 4 deletions morph_net/network_regularizers/cost_calculator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ def testImageIsNotZerothOutputOfOp(self):
op_handler_dict = collections.defaultdict(
grouping_op_handler.GroupingOpHandler)
op_handler_dict.update({
'FusedBatchNorm':
batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
'FusedBatchNormV3':
batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
'Conv2D':
output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
'ConcatV2':
concat_op_handler.ConcatOpHandler(),
concat_op_handler.ConcatOpHandler(),
})

# Create OpRegularizerManager and NetworkRegularizer for test.
Expand Down
1 change: 1 addition & 0 deletions morph_net/network_regularizers/flop_regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
op_handler_dict.update({
'FusedBatchNorm': source_op_handler,
'FusedBatchNormV2': source_op_handler,
'FusedBatchNormV3': source_op_handler,
})

self._manager = orm.OpRegularizerManager(
Expand Down
14 changes: 7 additions & 7 deletions morph_net/network_regularizers/flop_regularizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,17 +129,17 @@ def testInputBoundaryNone(self):
self.BuildWithBatchNorm(fused=True)
self.AddRegularizer(input_boundary=None)
self.assertCountEqual(self.GetSourceOps(), [
'conv1/BatchNorm/FusedBatchNorm', 'conv2/BatchNorm/FusedBatchNorm',
'conv3/BatchNorm/FusedBatchNorm', 'conv4/BatchNorm/FusedBatchNorm'
'conv1/BatchNorm/FusedBatchNormV3', 'conv2/BatchNorm/FusedBatchNormV3',
'conv3/BatchNorm/FusedBatchNormV3', 'conv4/BatchNorm/FusedBatchNormV3'
])

def testInputBoundaryConv3(self):
# Only block one path, can still reach all other convolutions.
self.BuildWithBatchNorm(fused=True)
self.AddRegularizer(input_boundary=[self.conv3.op])
self.assertCountEqual(self.GetSourceOps(), [
'conv1/BatchNorm/FusedBatchNorm', 'conv2/BatchNorm/FusedBatchNorm',
'conv4/BatchNorm/FusedBatchNorm'
'conv1/BatchNorm/FusedBatchNormV3', 'conv2/BatchNorm/FusedBatchNormV3',
'conv4/BatchNorm/FusedBatchNormV3'
])

def testInputBoundaryConv3And4(self):
Expand All @@ -152,9 +152,9 @@ def testInputBoundaryConcat(self):
# Block concat, can only see conv3 and conv4.
self.BuildWithBatchNorm(fused=True)
self.AddRegularizer(input_boundary=[self.concat.op])
self.assertCountEqual(
self.GetSourceOps(),
['conv3/BatchNorm/FusedBatchNorm', 'conv4/BatchNorm/FusedBatchNorm'])
self.assertCountEqual(self.GetSourceOps(), [
'conv3/BatchNorm/FusedBatchNormV3', 'conv4/BatchNorm/FusedBatchNormV3'
])

def testLossDecorated(self):
self.BuildWithBatchNorm(True)
Expand Down
1 change: 1 addition & 0 deletions morph_net/network_regularizers/latency_regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
op_handler_dict.update({
'FusedBatchNorm': source_op_handler,
'FusedBatchNormV2': source_op_handler,
'FusedBatchNormV3': source_op_handler,
})

self._manager = orm.OpRegularizerManager(
Expand Down
1 change: 1 addition & 0 deletions morph_net/network_regularizers/model_size_regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
op_handler_dict.update({
'FusedBatchNorm': source_op_handler,
'FusedBatchNormV2': source_op_handler,
'FusedBatchNormV3': source_op_handler,
})

self._manager = orm.OpRegularizerManager(
Expand Down