diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc index 881c1a1bcf3..f9a50002ccc 100644 --- a/onnx/defs/nn/defs.cc +++ b/onnx/defs/nn/defs.cc @@ -741,6 +741,8 @@ void convTransposeShapeInference(InferenceContext& ctx) { return; } + int64_t group = getAttribute(ctx, "group", 1); + auto input_shape = ctx.getInputType(0)->tensor_type().shape(); if (input_shape.dim_size() < 2) { return; // Input tensor should have at least two dimensions. @@ -751,7 +753,11 @@ void convTransposeShapeInference(InferenceContext& ctx) { std::vector dilations; if (getRepeatedAttribute(ctx, "dilations", dilations)) { - return; // we don't handle the dialations. + for (auto i : dilations) + { + if (i != 1) + return; // we don't handle dilations that are not 1. + } } std::vector pads; @@ -788,10 +794,13 @@ void convTransposeShapeInference(InferenceContext& ctx) { } std::vector output_shape; + bool output_shape_presented = true; if (getRepeatedAttribute(ctx, "output_shape", output_shape)) { if (output_shape.size() != n_input_dims) { return; } + } else { + output_shape_presented = false; } std::vector output_padding; @@ -809,37 +818,39 @@ void convTransposeShapeInference(InferenceContext& ctx) { *final_output_shape->add_dim() = input_shape.dim(0); *final_output_shape->add_dim() = ctx.getInputType(1)->tensor_type().shape().dim( - 1); // channels should be the second dim of second input. + 1) * group; // channels should be the second dim of second input multiplied by group. - int size_of_output = static_cast(output_shape.size()); - if (size_of_output > 0) { + int size_of_output; + if (output_shape_presented) { + size_of_output = static_cast(output_shape.size()); for (int i = 0; i < size_of_output; ++i) { - if (output_shape[i] < input_shape.dim(i + 2).dim_value()) { - // TODO: throw exception? 
- return; // output shape value cannot be smaller than the input shape - // value + if (input_shape.dim(i + 2).has_dim_value()) { + if (output_shape[i] < input_shape.dim(i + 2).dim_value()) { + // TODO: throw exception? + return; // output shape value cannot be smaller than the input shape + // value + } } - final_output_shape->add_dim()->set_dim_value(output_shape[i]); } - return; // assume no need to proceed further when the output shape is given. + return; } - - int kernel_shape_size = static_cast(kernel_shape.size()); - for (int i = 0; i < kernel_shape_size; ++i) { - auto newdim = final_output_shape->add_dim(); - if (!input_shape.dim(2 + i).has_dim_value()) { - continue; + else + { + size_of_output = input_shape.dim_size() - 2; + for (int i = 0; i < size_of_output; ++i) + { + if (input_shape.dim(i + 2).has_dim_value()) { + int64_t output_shape_dim = + strides[i] * (input_shape.dim(i + 2).dim_value() - 1) + + output_padding[i] + kernel_shape[i] - pads[i] - + pads[i + n_input_dims]; + final_output_shape->add_dim()->set_dim_value(output_shape_dim); + } else{ + final_output_shape->add_dim(); + } } - - int64_t newdim_value = - strides[i] * (input_shape.dim(2 + i).dim_value() - 1); - newdim_value += (output_padding[i] + kernel_shape[i]); - newdim_value -= pads[i]; - newdim_value -= pads[i + kernel_shape_size]; - - // add in the initial position - newdim->set_dim_value(newdim_value); + return; } } diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index 473438e60c6..a3773d94bd8 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -993,6 +993,22 @@ def test_conv_transpose_with_kernel_shape(self): # type: () -> None []) self._assert_inferred(graph, [make_tensor_value_info('Y', TensorProto.FLOAT, (25, 32, 30, 30))]) + def test_conv_transpose_with_group(self): # type: () -> None + graph = self._make_graph( + [('X', TensorProto.FLOAT, (25, 48, 16, 16)), + ('W', TensorProto.FLOAT, (48, 32, 3, 3))], + 
[make_node('ConvTranspose', ['X', 'W'], 'Y', strides=[2, 2], pads=[1, 1, 2, 2], group=2)], + []) + self._assert_inferred(graph, [make_tensor_value_info('Y', TensorProto.FLOAT, (25, 64, 30, 30))]) + + def test_conv_transpose_with_group_and_output_shape(self): # type: () -> None + graph = self._make_graph( + [('X', TensorProto.FLOAT, (25, 48, 16, 16)), + ('W', TensorProto.FLOAT, (48, 32, 3, 3))], + [make_node('ConvTranspose', ['X', 'W'], 'Y', strides=[2, 2], pads=[1, 1, 2, 2], group=2, output_shape=[36, 36])], + []) + self._assert_inferred(graph, [make_tensor_value_info('Y', TensorProto.FLOAT, (25, 64, 36, 36))]) + def test_mvn_function_output_shape(self): # type: () -> None graph = self._make_graph( [('X', TensorProto.FLOAT, (25, 48, 16, 16))],