Skip to content

Commit 995378c

Browse files
yongtang
authored and rmlarsen committed
Switch over to max_pool_v2 in Python (tensorflow#14983)
* Switch over to max_pool_v2 in Python This fix is a follow up to 11875 so that MaxPool in Python uses the v2 version. As 11875 has been merged some time ago, this fix conforms to the deprecation policy. This fix is related to 11875 and 4746. Signed-off-by: Yong Tang <[email protected]> * Update test cases in contrib/specs/python/specs_test due to MaxPool -> MaxPoolV2 Signed-off-by: Yong Tang <[email protected]> * Update tensorflow/contrib/receptive_field Update tensorflow/contrib/receptive_field due to max_pool's strides and ksize from attr -> input Signed-off-by: Yong Tang <[email protected]> * Remove const restriction for strides and ksize Signed-off-by: Yong Tang <[email protected]> * Register MaxPoolV2 with XLA Signed-off-by: Yong Tang <[email protected]> * Reformat with clang-format -i --style=Google Signed-off-by: Yong Tang <[email protected]>
1 parent 2c50519 commit 995378c

File tree

5 files changed

+115
-57
lines changed

5 files changed

+115
-57
lines changed

tensorflow/compiler/tf2xla/kernels/pooling_ops.cc

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,23 @@ class PoolingOp : public XlaOpKernel {
3737
public:
3838
PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims)
3939
: XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
40-
std::vector<int32> ksize_int;
41-
std::vector<int32> stride_int;
42-
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
43-
OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
44-
errors::InvalidArgument("Sliding window ksize field must "
45-
"specify ",
46-
num_dims(), " dimensions"));
47-
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
48-
OP_REQUIRES(ctx, stride_int.size() == num_dims(),
49-
errors::InvalidArgument("Sliding window stride field must "
50-
"specify ",
51-
num_dims(), " dimensions"));
52-
for (int i = 0; i < num_dims(); ++i) {
53-
ksize_.push_back(ksize_int[i]);
54-
stride_.push_back(stride_int[i]);
40+
if (ctx->num_inputs() == 1) {
41+
std::vector<int32> ksize_int;
42+
std::vector<int32> stride_int;
43+
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
44+
OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
45+
errors::InvalidArgument("Sliding window ksize field must "
46+
"specify ",
47+
num_dims(), " dimensions"));
48+
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
49+
OP_REQUIRES(ctx, stride_int.size() == num_dims(),
50+
errors::InvalidArgument("Sliding window stride field must "
51+
"specify ",
52+
num_dims(), " dimensions"));
53+
for (int i = 0; i < num_dims(); ++i) {
54+
ksize_.push_back(ksize_int[i]);
55+
stride_.push_back(stride_int[i]);
56+
}
5557
}
5658
Padding padding;
5759
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
@@ -77,15 +79,42 @@ class PoolingOp : public XlaOpKernel {
7779
xla::ComputationDataHandle input = ctx->Input(0);
7880
const TensorShape input_shape = ctx->InputShape(0);
7981

82+
std::vector<int64> ksize = ksize_;
83+
std::vector<int64> stride = stride_;
84+
if (ctx->num_inputs() != 1) {
85+
const TensorShape ksize_shape = ctx->InputShape(1);
86+
// Validate input sizes.
87+
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
88+
errors::InvalidArgument("ksize must be a vector, not shape ",
89+
ksize_shape.DebugString()));
90+
OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
91+
errors::InvalidArgument("Sliding window ksize field must "
92+
"specify ",
93+
num_dims(), " dimensions"));
94+
ksize.clear();
95+
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
96+
97+
const TensorShape stride_shape = ctx->InputShape(2);
98+
// Validate input sizes.
99+
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
100+
errors::InvalidArgument("stride must be a vector, not shape ",
101+
stride_shape.DebugString()));
102+
OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
103+
errors::InvalidArgument("Sliding window stride field must "
104+
"specify ",
105+
num_dims(), " dimensions"));
106+
stride.clear();
107+
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
108+
}
80109
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
81110
errors::InvalidArgument("Input to ", type_string(),
82111
" operator must have ", num_dims(),
83112
" dimensions"));
84113

85114
const DataType type = input_type(0);
86115
xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
87-
input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_,
88-
stride_, padding_);
116+
input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize,
117+
stride, padding_);
89118
ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
90119
}
91120

@@ -130,6 +159,10 @@ class MaxPool2DOp : public MaxPoolOp {
130159
}
131160
};
132161
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
162+
REGISTER_XLA_OP(Name("MaxPoolV2")
163+
.CompileTimeConstInput("ksize")
164+
.CompileTimeConstInput("strides"),
165+
MaxPool2DOp);
133166

134167
class MaxPool3DOp : public MaxPoolOp {
135168
public:

tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
_SAME_PADDING = ["SAME", b"SAME"]
3636

3737

38-
def _stride_size(node):
38+
def _stride_size(node, name_to_node):
3939
"""Computes stride size given a TF node.
4040
4141
Args:
@@ -45,10 +45,20 @@ def _stride_size(node):
4545
stride_x: Stride size for horizontal direction (integer).
4646
stride_y: Stride size for vertical direction (integer).
4747
"""
48-
strides_attr = node.attr["strides"]
49-
logging.vlog(4, "strides_attr = %s", strides_attr)
50-
stride_y = strides_attr.list.i[1]
51-
stride_x = strides_attr.list.i[2]
48+
if node.op == "MaxPoolV2":
49+
strides_input_name = node.input[2]
50+
if not strides_input_name.endswith("/strides"):
51+
raise ValueError("Strides name does not end with '/strides'")
52+
strides_node = name_to_node[strides_input_name]
53+
value = strides_node.attr["value"]
54+
t = make_ndarray(value.tensor)
55+
stride_y = t[1]
56+
stride_x = t[2]
57+
else:
58+
strides_attr = node.attr["strides"]
59+
logging.vlog(4, "strides_attr = %s", strides_attr)
60+
stride_y = strides_attr.list.i[1]
61+
stride_x = strides_attr.list.i[2]
5262
return stride_x, stride_y
5363

5464

@@ -144,7 +154,7 @@ def _padding_size_conv_pool(node, kernel_size, stride, input_resolution=None):
144154
return total_padding, padding
145155

146156

147-
def _pool_kernel_size(node):
157+
def _pool_kernel_size(node, name_to_node):
148158
"""Computes kernel size given a TF pooling node.
149159
150160
Args:
@@ -157,13 +167,27 @@ def _pool_kernel_size(node):
157167
Raises:
158168
ValueError: If pooling is invalid.
159169
"""
160-
ksize = node.attr["ksize"]
161-
kernel_size_y = ksize.list.i[1]
162-
kernel_size_x = ksize.list.i[2]
163-
if ksize.list.i[0] != 1:
164-
raise ValueError("pool ksize for first dim is not 1")
165-
if ksize.list.i[3] != 1:
166-
raise ValueError("pool ksize for last dim is not 1")
170+
if node.op == "MaxPoolV2":
171+
ksize_input_name = node.input[1]
172+
if not ksize_input_name.endswith("/ksize"):
173+
raise ValueError("Kernel size name does not end with '/ksize'")
174+
ksize_node = name_to_node[ksize_input_name]
175+
value = ksize_node.attr["value"]
176+
t = make_ndarray(value.tensor)
177+
kernel_size_y = t[1]
178+
kernel_size_x = t[2]
179+
if t[0] != 1:
180+
raise ValueError("pool ksize for first dim is not 1")
181+
if t[3] != 1:
182+
raise ValueError("pool ksize for last dim is not 1")
183+
else:
184+
ksize = node.attr["ksize"]
185+
kernel_size_y = ksize.list.i[1]
186+
kernel_size_x = ksize.list.i[2]
187+
if ksize.list.i[0] != 1:
188+
raise ValueError("pool ksize for first dim is not 1")
189+
if ksize.list.i[3] != 1:
190+
raise ValueError("pool ksize for last dim is not 1")
167191
return kernel_size_x, kernel_size_y
168192

169193

@@ -243,7 +267,7 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False):
243267
logging.vlog(3, "node.op = %s", node.op)
244268
logging.vlog(4, "node = %s", node)
245269
if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative":
246-
stride_x, stride_y = _stride_size(node)
270+
stride_x, stride_y = _stride_size(node, name_to_node)
247271
kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_node)
248272
# Compute the padding for this node separately for each direction.
249273
total_padding_x, padding_x = _padding_size_conv_pool(
@@ -260,9 +284,9 @@ def get_layer_params(node, name_to_node, input_resolution=None, force=False):
260284
stride_y = 1
261285
total_padding_x, padding_x, total_padding_y, padding_y = (
262286
_padding_size_pad_layer(node, name_to_node))
263-
elif node.op == "MaxPool" or node.op == "AvgPool":
264-
stride_x, stride_y = _stride_size(node)
265-
kernel_size_x, kernel_size_y = _pool_kernel_size(node)
287+
elif node.op == "MaxPool" or node.op == "MaxPoolV2" or node.op == "AvgPool":
288+
stride_x, stride_y = _stride_size(node, name_to_node)
289+
kernel_size_x, kernel_size_y = _pool_kernel_size(node, name_to_node)
266290
# Compute the padding for this node separately for each direction.
267291
total_padding_x, padding_x = _padding_size_conv_pool(
268292
node, kernel_size_x, stride_x, input_resolution[1]

tensorflow/contrib/specs/python/specs_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def testMpPower(self):
8787
self.assertEqual(tuple(result.shape), (1, 8, 8, 5))
8888
self.assertEqual(
8989
summaries.tf_spec_structure(spec, inputs),
90-
"_ maxpool maxpool maxpool")
90+
"_ _ _ maxpoolv2 _ _ maxpoolv2 _ _ maxpoolv2")
9191

9292
def testAbbrevPower(self):
9393
with self.test_session():
@@ -100,10 +100,10 @@ def testAbbrevPower(self):
100100
self.assertEqual(tuple(result.shape), (1, 8, 8, 5))
101101
self.assertEqual(
102102
summaries.tf_spec_structure(spec, inputs),
103-
"_ variablev2 conv variablev2 biasadd relu maxpool"
103+
"_ variablev2 conv variablev2 biasadd relu _ _ maxpoolv2"
104104
" variablev2 conv variablev2"
105-
" biasadd relu maxpool variablev2 conv variablev2"
106-
" biasadd relu maxpool")
105+
" biasadd relu _ _ maxpoolv2 variablev2 conv variablev2"
106+
" biasadd relu _ _ maxpoolv2")
107107

108108
def testAbbrevPower2(self):
109109
with self.test_session():
@@ -117,10 +117,10 @@ def testAbbrevPower2(self):
117117
self.assertEqual(tuple(result.shape), (1, 8, 8, 5))
118118
self.assertEqual(
119119
summaries.tf_spec_structure(spec, inputs),
120-
"_ variablev2 conv variablev2 biasadd relu maxpool"
120+
"_ variablev2 conv variablev2 biasadd relu _ _ maxpoolv2"
121121
" variablev2 conv variablev2 biasadd relu"
122-
" maxpool variablev2 conv variablev2 biasadd relu"
123-
" maxpool")
122+
" _ _ maxpoolv2 variablev2 conv variablev2 biasadd relu"
123+
" _ _ maxpoolv2")
124124

125125
def testConc(self):
126126
with self.test_session():

tensorflow/python/kernel_tests/pooling_ops_test.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,16 +1811,17 @@ def testOpEdgeCases(self):
18111811
if test.is_gpu_available():
18121812
pool_funcs.append(nn_ops.max_pool_with_argmax)
18131813
for pool_func in pool_funcs:
1814-
# Illegal strides.
1815-
with self.assertRaisesRegexp(
1816-
errors_impl.UnimplementedError,
1817-
"Pooling is not yet supported on the batch"):
1818-
sess.run(
1819-
pool_func(
1820-
array_ops.placeholder(dtypes.float32),
1821-
ksize=[1, 1, 1, 1],
1822-
strides=[2, 1, 1, 1],
1823-
padding="SAME"))
1814+
if pool_func != nn_ops.max_pool:
1815+
# Illegal strides.
1816+
with self.assertRaisesRegexp(
1817+
errors_impl.UnimplementedError,
1818+
"Pooling is not yet supported on the batch"):
1819+
sess.run(
1820+
pool_func(
1821+
array_ops.placeholder(dtypes.float32),
1822+
ksize=[1, 1, 1, 1],
1823+
strides=[2, 1, 1, 1],
1824+
padding="SAME"))
18241825

18251826
# Filter too large.
18261827
with self.assertRaisesRegexp(ValueError, "Negative dimension size"):

tensorflow/python/ops/nn_ops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,12 +2070,12 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
20702070
"""
20712071
with ops.name_scope(name, "MaxPool", [value]) as name:
20722072
value = ops.convert_to_tensor(value, name="input")
2073-
return gen_nn_ops._max_pool(value,
2074-
ksize=ksize,
2075-
strides=strides,
2076-
padding=padding,
2077-
data_format=data_format,
2078-
name=name)
2073+
return gen_nn_ops._max_pool_v2(value,
2074+
ksize=ksize,
2075+
strides=strides,
2076+
padding=padding,
2077+
data_format=data_format,
2078+
name=name)
20792079

20802080

20812081
@ops.RegisterStatistics("Conv2D", "flops")

0 commit comments

Comments (0)