Skip to content

Commit e1cf526

Browse files
BowenBaojignparmgramalingam
authored
Update Reshape op: Add 'allowzero' attribute - Cont (onnx#3113)
* Add 'allowzero' flag to reshape operator Signed-off-by: BowenBao <[email protected]> * Add shape inference test Signed-off-by: BowenBao <[email protected]> * fix flake errors Signed-off-by: BowenBao <[email protected]> * fix mypy complaint Signed-off-by: BowenBao <[email protected]> * Update TestCoverage Signed-off-by: BowenBao <[email protected]> * Fix tabs Signed-off-by: BowenBao <[email protected]> * Fix tabs Signed-off-by: BowenBao <[email protected]> * fix typo Signed-off-by: BowenBao <[email protected]> * restart build Signed-off-by: BowenBao <[email protected]> * restart build Signed-off-by: BowenBao <[email protected]> * restart build Signed-off-by: BowenBao <[email protected]> * kick off build Signed-off-by: BowenBao <[email protected]> * kick off build Signed-off-by: BowenBao <[email protected]> * PR feedback Signed-off-by: BowenBao <[email protected]> * doc changes after merge Signed-off-by: BowenBao <[email protected]> * PR feedback Signed-off-by: BowenBao <[email protected]> * update documentation Signed-off-by: BowenBao <[email protected]> * Upgrade to opset 14 Signed-off-by: BowenBao <[email protected]> * update test case Signed-off-by: BowenBao <[email protected]> * revert unnecessary version bump Signed-off-by: BowenBao <[email protected]> * fix test case Signed-off-by: BowenBao <[email protected]> * rebase with master Signed-off-by: BowenBao <[email protected]> Co-authored-by: Jignesh Parmar <[email protected]> Co-authored-by: G. Ramalingam <[email protected]>
1 parent bf2f97d commit e1cf526

File tree

12 files changed

+361
-34
lines changed

12 files changed

+361
-34
lines changed

docs/Changelog.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18679,6 +18679,50 @@ This version of the operator has been available since version 14 of the default
1867918679
<dd>Constrain input and output types to signed numeric tensors.</dd>
1868018680
</dl>
1868118681

18682+
### <a name="Reshape-14"></a>**Reshape-14**</a>
18683+
18684+
Reshape the input tensor similar to numpy.reshape.
18685+
First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor.
18686+
At most one dimension of the new shape can be -1. In this case, the value is
18687+
inferred from the size of the tensor and the remaining dimensions. A dimension
18688+
could also be 0, in which case the actual dimension value is unchanged (i.e. taken
18689+
from the input tensor). If 'allowzero' is set, and the new shape includes 0, the
18690+
dimension will be set explicitly to zero (i.e. not taken from input tensor)
18691+
18692+
#### Version
18693+
18694+
This version of the operator has been available since version 14 of the default ONNX operator set.
18695+
18696+
#### Attributes
18697+
18698+
<dl>
18699+
<dt><tt>allowzero</tt> : int (default is 0)</dt>
18700+
<dd>(Optional) By default, when any value in the 'shape' input is equal to zero the corresponding dimension value is copied from the input tensor dynamically. allowzero=1 indicates that if any value in the 'shape' input is set to zero, the zero value is honored, similar to NumPy.</dd>
18701+
</dl>
18702+
18703+
#### Inputs
18704+
18705+
<dl>
18706+
<dt><tt>data</tt> (differentiable) : T</dt>
18707+
<dd>An input tensor.</dd>
18708+
<dt><tt>shape</tt> (non-differentiable) : tensor(int64)</dt>
18709+
<dd>Specified shape for output.</dd>
18710+
</dl>
18711+
18712+
#### Outputs
18713+
18714+
<dl>
18715+
<dt><tt>reshaped</tt> (differentiable) : T</dt>
18716+
<dd>Reshaped data.</dd>
18717+
</dl>
18718+
18719+
#### Type Constraints
18720+
18721+
<dl>
18722+
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
18723+
<dd>Constrain input and output types to all tensor types.</dd>
18724+
</dl>
18725+
1868218726
# ai.onnx.preview.training
1868318727
## Version 1 of the 'ai.onnx.preview.training' operator set
1868418728
### <a name="ai.onnx.preview.training.Adagrad-1"></a>**ai.onnx.preview.training.Adagrad-1**</a>

docs/Operators.md

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ For an operator input/output's differentiability, it can be differentiable,
117117
|<a href="#ReduceSum">ReduceSum</a>|<a href="Changelog.md#ReduceSum-13">13</a>, <a href="Changelog.md#ReduceSum-11">11</a>, <a href="Changelog.md#ReduceSum-1">1</a>|
118118
|<a href="#ReduceSumSquare">ReduceSumSquare</a>|<a href="Changelog.md#ReduceSumSquare-13">13</a>, <a href="Changelog.md#ReduceSumSquare-11">11</a>, <a href="Changelog.md#ReduceSumSquare-1">1</a>|
119119
|<a href="#Relu">Relu</a>|<a href="Changelog.md#Relu-14">14</a>, <a href="Changelog.md#Relu-13">13</a>, <a href="Changelog.md#Relu-6">6</a>, <a href="Changelog.md#Relu-1">1</a>|
120-
|<a href="#Reshape">Reshape</a>|<a href="Changelog.md#Reshape-13">13</a>, <a href="Changelog.md#Reshape-5">5</a>, <a href="Changelog.md#Reshape-1">1</a>|
120+
|<a href="#Reshape">Reshape</a>|<a href="Changelog.md#Reshape-14">14</a>, <a href="Changelog.md#Reshape-13">13</a>, <a href="Changelog.md#Reshape-5">5</a>, <a href="Changelog.md#Reshape-1">1</a>|
121121
|<a href="#Resize">Resize</a>|<a href="Changelog.md#Resize-13">13</a>, <a href="Changelog.md#Resize-11">11</a>, <a href="Changelog.md#Resize-10">10</a>|
122122
|<a href="#ReverseSequence">ReverseSequence</a>|<a href="Changelog.md#ReverseSequence-10">10</a>|
123123
|<a href="#RoiAlign">RoiAlign</a>|<a href="Changelog.md#RoiAlign-10">10</a>|
@@ -15673,13 +15673,21 @@ expect(node, inputs=[x], outputs=[y],
1567315673
At most one dimension of the new shape can be -1. In this case, the value is
1567415674
inferred from the size of the tensor and the remaining dimensions. A dimension
1567515675
could also be 0, in which case the actual dimension value is unchanged (i.e. taken
15676-
from the input tensor).
15676+
from the input tensor). If 'allowzero' is set, and the new shape includes 0, the
15677+
dimension will be set explicitly to zero (i.e. not taken from input tensor)
1567715678

1567815679
#### Version
1567915680

15680-
This version of the operator has been available since version 13 of the default ONNX operator set.
15681+
This version of the operator has been available since version 14 of the default ONNX operator set.
15682+
15683+
Other versions of this operator: <a href="Changelog.md#Reshape-1">1</a>, <a href="Changelog.md#Reshape-5">5</a>, <a href="Changelog.md#Reshape-13">13</a>
15684+
15685+
#### Attributes
1568115686

15682-
Other versions of this operator: <a href="Changelog.md#Reshape-1">1</a>, <a href="Changelog.md#Reshape-5">5</a>
15687+
<dl>
15688+
<dt><tt>allowzero</tt> : int (default is 0)</dt>
15689+
<dd>(Optional) By default, when any value in the 'shape' input is equal to zero the corresponding dimension value is copied from the input tensor dynamically. allowzero=1 indicates that if any value in the 'shape' input is set to zero, the zero value is honored, similar to NumPy.</dd>
15690+
</dl>
1568315691

1568415692
#### Inputs
1568515693

@@ -15707,6 +15715,34 @@ Other versions of this operator: <a href="Changelog.md#Reshape-1">1</a>, <a href
1570715715

1570815716
#### Examples
1570915717

15718+
<details>
15719+
<summary>allowzero</summary>
15720+
15721+
```python
15722+
original_shape = [0, 3, 4]
15723+
test_cases = {
15724+
'allowzero_reordered': np.array([3, 4, 0], dtype=np.int64),
15725+
}
15726+
data = np.random.random_sample(original_shape).astype(np.float32)
15727+
15728+
for test_name, shape in test_cases.items():
15729+
node = onnx.helper.make_node(
15730+
'Reshape',
15731+
inputs=['data', 'shape'],
15732+
outputs=['reshaped'],
15733+
allowzero=1, # if allowzero=1, final shape = (3, 4, 0)
15734+
# if allowzero=0, final shape = (3, 4, 4)
15735+
)
15736+
15737+
reshaped = reshape_reference_implementation(data, shape, allowzero=1)
15738+
15739+
expect(node, inputs=[data, shape], outputs=[reshaped],
15740+
name='test_reshape_' + test_name)
15741+
```
15742+
15743+
</details>
15744+
15745+
1571015746
<details>
1571115747
<summary>reshape</summary>
1571215748

docs/TestCoverage.md

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9631,7 +9631,33 @@ expect(node, inputs=[x], outputs=[y],
96319631

96329632

96339633
### Reshape
9634-
There are 1 test cases, listed as following:
9634+
There are 2 test cases, listed as following:
9635+
<details>
9636+
<summary>allowzero</summary>
9637+
9638+
```python
9639+
original_shape = [0, 3, 4]
9640+
test_cases = {
9641+
'allowzero_reordered': np.array([3, 4, 0], dtype=np.int64),
9642+
}
9643+
data = np.random.random_sample(original_shape).astype(np.float32)
9644+
9645+
for test_name, shape in test_cases.items():
9646+
node = onnx.helper.make_node(
9647+
'Reshape',
9648+
inputs=['data', 'shape'],
9649+
outputs=['reshaped'],
9650+
allowzero=1, # if allowzero=1, final shape = (3, 4, 0)
9651+
# if allowzero=0, final shape = (3, 4, 4)
9652+
)
9653+
9654+
reshaped = reshape_reference_implementation(data, shape, allowzero=1)
9655+
9656+
expect(node, inputs=[data, shape], outputs=[reshaped],
9657+
name='test_reshape_' + test_name)
9658+
```
9659+
9660+
</details>
96359661
<details>
96369662
<summary>reshape</summary>
96379663

onnx/backend/test/case/node/reshape.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,21 @@
1212
from . import expect
1313

1414

15-
def reshape_reference_implementation(data, shape): # type: (np.ndarray, np.ndarray) -> np.ndarray
15+
def reshape_reference_implementation(data, shape, allowzero=0): # type: (np.ndarray, np.ndarray, int) -> np.ndarray
1616
# replace zeros with corresponding dim size
17-
# we need to do this because np.reshape doesn't support 0
17+
# we need to do this because np.reshape doesn't support 0 by default unless 'allowzero' is set
1818
new_shape = np.copy(shape)
19-
zeros_index = np.where(shape == 0)
20-
new_shape[zeros_index] = np.array(data.shape)[zeros_index]
19+
if allowzero == 0:
20+
zeros_index = np.where(shape == 0)
21+
new_shape[zeros_index] = np.array(data.shape)[zeros_index]
2122
reshaped = np.reshape(data, new_shape)
2223
return reshaped
2324

2425

2526
class Reshape(Base):
2627

2728
@staticmethod
28-
def export(): # type: () -> None
29+
def export_reshape(): # type: () -> None
2930
original_shape = [2, 3, 4]
3031
test_cases = {
3132
'reordered_all_dims': np.array([4, 2, 3], dtype=np.int64),
@@ -51,3 +52,25 @@ def export(): # type: () -> None
5152

5253
expect(node, inputs=[data, shape], outputs=[reshaped],
5354
name='test_reshape_' + test_name)
55+
56+
@staticmethod
57+
def export_allowzero(): # type: () -> None
58+
original_shape = [0, 3, 4]
59+
test_cases = {
60+
'allowzero_reordered': np.array([3, 4, 0], dtype=np.int64),
61+
}
62+
data = np.random.random_sample(original_shape).astype(np.float32)
63+
64+
for test_name, shape in test_cases.items():
65+
node = onnx.helper.make_node(
66+
'Reshape',
67+
inputs=['data', 'shape'],
68+
outputs=['reshaped'],
69+
allowzero=1, # if allowzero=1, final shape = (3, 4, 0)
70+
# if allowzero=0, final shape = (3, 4, 4)
71+
)
72+
73+
reshaped = reshape_reference_implementation(data, shape, allowzero=1)
74+
75+
expect(node, inputs=[data, shape], outputs=[reshaped],
76+
name='test_reshape_' + test_name)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

onnx/defs/operator_sets.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,6 @@ class OpSet_Onnx_ver13 {
902902
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 13, Shape)>());
903903
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 13, Size)>());
904904
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 13, Concat)>());
905-
906905
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 13, Split)>());
907906
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 13, Slice)>());
908907
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 13, Transpose)>());
@@ -931,13 +930,15 @@ class OpSet_Onnx_ver13 {
931930
// Forward declarations for ai.onnx version 14
932931
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 14, CumSum);
933932
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 14, Relu);
933+
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 14, Reshape);
934934

935935
// Iterate over schema from ai.onnx version 14
936936
class OpSet_Onnx_ver14 {
937937
public:
938938
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
939939
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 14, CumSum)>());
940940
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 14, Relu)>());
941+
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 14, Reshape)>());
941942
}
942943
};
943944

onnx/defs/tensor/defs.cc

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,28 @@ ONNX_OPERATOR_SET_SCHEMA(
103103
}
104104
}));
105105

106-
static const char* Reshape_ver13_doc = R"DOC(
106+
static const char* Reshape_ver14_doc = R"DOC(
107107
Reshape the input tensor similar to numpy.reshape.
108108
First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor.
109109
At most one dimension of the new shape can be -1. In this case, the value is
110110
inferred from the size of the tensor and the remaining dimensions. A dimension
111111
could also be 0, in which case the actual dimension value is unchanged (i.e. taken
112-
from the input tensor).)DOC";
112+
from the input tensor). If 'allowzero' is set, and the new shape includes 0, the
113+
dimension will be set explicitly to zero (i.e. not taken from input tensor))DOC";
113114

114115
ONNX_OPERATOR_SET_SCHEMA(
115116
Reshape,
116-
13,
117+
14,
117118
OpSchema()
118-
.SetDoc(Reshape_ver13_doc)
119+
.SetDoc(Reshape_ver14_doc)
120+
.Attr(
121+
"allowzero",
122+
"(Optional) By default, when any value in the 'shape' input is equal to zero "
123+
"the corresponding dimension value is copied from the input tensor dynamically. "
124+
"allowzero=1 indicates that if any value in the 'shape' input is set to zero, "
125+
"the zero value is honored, similar to NumPy.",
126+
AttributeProto::INT,
127+
static_cast<int64_t>(0))
119128
.Input(0,
120129
"data",
121130
"An input tensor.",
@@ -152,6 +161,7 @@ ONNX_OPERATOR_SET_SCHEMA(
152161
if (!targetShapeInitializer) {
153162
return;
154163
}
164+
int allowzero = static_cast<int>(getAttribute(ctx, "allowzero", 0));
155165
// Make targetShape (0 -> same as originalShape, -1 -> inferred).
156166
// The targetShape vector represents the specified shape for output.
157167
std::vector<int64_t> targetShape;
@@ -167,9 +177,9 @@ ONNX_OPERATOR_SET_SCHEMA(
167177
}
168178

169179
// Iterate through targetShape, adding dimensions in the outputShape
170-
// TensorProto. If the targertShape dimension is -1, we do not set the
180+
// TensorProto. If the targetShape dimension is -1, we do not set the
171181
// dimension value in this iteration, but we record the Dimension. If
172-
// targertShape dimension is 0, we attempt to propagate the dimension
182+
// targetShape dimension is 0, we attempt to propagate the dimension
173183
// value/param. If the value cannot be inferred, we set the flag in
174184
// the unresolveZeros vector. If targetShape dimension is positive, we
175185
// set the dimension value in the outputShape. We track the product of
@@ -190,30 +200,39 @@ ONNX_OPERATOR_SET_SCHEMA(
190200
// this dimension to potentially be filled in later.
191201
if (negativeOneDim) {
192202
fail_shape_inference(
193-
"Target shape may not have multiple -1 dimensions");
203+
"Target shape may not have multiple -1 dimensions.");
194204
}
195205
negativeOneDim = new_dim;
196206
} else if (targetShape[i] == 0) {
197207
// Check if data input has a shape and if the index i is within
198208
// its bounds. If these conditions are satisfied, any dimension
199209
// value/param should be propogated. If dimension value cannot be
200210
// inferred, set the corresponding unresolvedZeros flag to true.
201-
unresolvedZeros[i] = true;
202-
if (dataInputTensorType.has_shape()) {
203-
if (i >= dataInputTensorType.shape().dim_size()) {
204-
fail_shape_inference("Invalid position of 0");
205-
}
206-
if (dataInputTensorType.shape().dim(i).has_dim_value()) {
207-
const auto& dim_value =
208-
dataInputTensorType.shape().dim(i).dim_value();
209-
new_dim->set_dim_value(dim_value);
210-
outputProduct *= dim_value;
211-
unresolvedZeros[i] = false;
212-
} else if (dataInputTensorType.shape().dim(i).has_dim_param()) {
213-
const auto& dim_param =
214-
dataInputTensorType.shape().dim(i).dim_param();
215-
new_dim->set_dim_param(dim_param);
211+
// If allowzero is set however, do not propagate values, since output
212+
// dimension is explicitly zero.
213+
if (allowzero == 0) {
214+
unresolvedZeros[i] = true;
215+
if (dataInputTensorType.has_shape()) {
216+
if (i >= dataInputTensorType.shape().dim_size()) {
217+
fail_shape_inference("Invalid position of 0.");
218+
}
219+
if (dataInputTensorType.shape().dim(i).has_dim_value()) {
220+
const auto& dim_value =
221+
dataInputTensorType.shape().dim(i).dim_value();
222+
new_dim->set_dim_value(dim_value);
223+
outputProduct *= dim_value;
224+
unresolvedZeros[i] = false;
225+
} else if (dataInputTensorType.shape()
226+
.dim(i)
227+
.has_dim_param()) {
228+
const auto& dim_param =
229+
dataInputTensorType.shape().dim(i).dim_param();
230+
new_dim->set_dim_param(dim_param);
231+
}
216232
}
233+
} else {
234+
new_dim->set_dim_value(targetShape[i]);
235+
outputProduct *= targetShape[i];
217236
}
218237
} else if (targetShape[i] > 0) {
219238
// Set the dimension value to targetShape[i]
@@ -234,7 +253,7 @@ ONNX_OPERATOR_SET_SCHEMA(
234253
// that are not marked by unresolvedZeros. If not possible, set the
235254
// inputProductValid flag to false.
236255
if (!outputProduct) {
237-
fail_shape_inference("Invalid Target shape product of 0");
256+
fail_shape_inference("Invalid Target shape product of 0. Product cannot be 0 in combination with -1");
238257
}
239258
int64_t inputProduct = 1;
240259
bool inputProductValid = true;

0 commit comments

Comments
 (0)