Skip to content

Commit bf2f97d

Browse files
Add sequence type for identity op (onnx#3170)
* add sequence type for identity op Signed-off-by: BowenBao <[email protected]> * update recursive call Signed-off-by: BowenBao <[email protected]> Co-authored-by: G. Ramalingam <[email protected]>
1 parent 8561a9a commit bf2f97d

File tree

10 files changed

+130
-26
lines changed

10 files changed

+130
-26
lines changed

docs/Changelog.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16069,22 +16069,22 @@ This version of the operator has been available since version 13 of the default
1606916069
#### Inputs
1607016070

1607116071
<dl>
16072-
<dt><tt>input</tt> (differentiable) : T</dt>
16072+
<dt><tt>input</tt> (differentiable) : V</dt>
1607316073
<dd>Input tensor</dd>
1607416074
</dl>
1607516075

1607616076
#### Outputs
1607716077

1607816078
<dl>
16079-
<dt><tt>output</tt> (differentiable) : T</dt>
16079+
<dt><tt>output</tt> (differentiable) : V</dt>
1608016080
<dd>Tensor to copy input into.</dd>
1608116081
</dl>
1608216082

1608316083
#### Type Constraints
1608416084

1608516085
<dl>
16086-
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
16087-
<dd>Constrain input and output types to all tensor types.</dd>
16086+
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
16087+
<dd>Constrain input and output types to all tensor and sequence types.</dd>
1608816088
</dl>
1608916089

1609016090
### <a name="If-13"></a>**If-13**</a>

docs/Operators.md

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7473,22 +7473,22 @@ Other versions of this operator: <a href="Changelog.md#Identity-1">1</a>
74737473
#### Inputs
74747474

74757475
<dl>
7476-
<dt><tt>input</tt> (differentiable) : T</dt>
7476+
<dt><tt>input</tt> (differentiable) : V</dt>
74777477
<dd>Input tensor</dd>
74787478
</dl>
74797479

74807480
#### Outputs
74817481

74827482
<dl>
7483-
<dt><tt>output</tt> (differentiable) : T</dt>
7483+
<dt><tt>output</tt> (differentiable) : V</dt>
74847484
<dd>Tensor to copy input into.</dd>
74857485
</dl>
74867486

74877487
#### Type Constraints
74887488

74897489
<dl>
7490-
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
7491-
<dd>Constrain input and output types to all tensor types.</dd>
7490+
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
7491+
<dd>Constrain input and output types to all tensor and sequence types.</dd>
74927492
</dl>
74937493

74947494

@@ -7516,6 +7516,33 @@ expect(node, inputs=[data], outputs=[data],
75167516
</details>
75177517

75187518

7519+
<details>
7520+
<summary>sequence</summary>
7521+
7522+
```python
7523+
node = onnx.helper.make_node(
7524+
'Identity',
7525+
inputs=['x'],
7526+
outputs=['y'],
7527+
)
7528+
7529+
data = [
7530+
np.array([[[
7531+
[1, 2],
7532+
[3, 4],
7533+
]]], dtype=np.float32),
7534+
np.array([[[
7535+
[2, 3],
7536+
[1, 5],
7537+
]]], dtype=np.float32)]
7538+
7539+
expect(node, inputs=[data], outputs=[data], name='test_identity_sequence',
7540+
opset_imports=[onnx.helper.make_opsetid("", 13)])
7541+
```
7542+
7543+
</details>
7544+
7545+
75197546
### <a name="If"></a><a name="if">**If**</a>
75207547

75217548
If conditional

docs/TestCoverage.md

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4595,7 +4595,7 @@ expect(node, inputs=[x], outputs=[y],
45954595

45964596

45974597
### Identity
4598-
There are 1 test cases, listed as following:
4598+
There are 2 test cases, listed as following:
45994599
<details>
46004600
<summary>identity</summary>
46014601

@@ -4615,6 +4615,31 @@ expect(node, inputs=[data], outputs=[data],
46154615
name='test_identity')
46164616
```
46174617

4618+
</details>
4619+
<details>
4620+
<summary>sequence</summary>
4621+
4622+
```python
4623+
node = onnx.helper.make_node(
4624+
'Identity',
4625+
inputs=['x'],
4626+
outputs=['y'],
4627+
)
4628+
4629+
data = [
4630+
np.array([[[
4631+
[1, 2],
4632+
[3, 4],
4633+
]]], dtype=np.float32),
4634+
np.array([[[
4635+
[2, 3],
4636+
[1, 5],
4637+
]]], dtype=np.float32)]
4638+
4639+
expect(node, inputs=[data], outputs=[data], name='test_identity_sequence',
4640+
opset_imports=[onnx.helper.make_opsetid("", 13)])
4641+
```
4642+
46184643
</details>
46194644

46204645

onnx/backend/test/case/node/identity.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,24 @@ def export(): # type: () -> None
2929

3030
expect(node, inputs=[data], outputs=[data],
3131
name='test_identity')
32+
33+
@staticmethod
34+
def export_sequence(): # type: () -> None
35+
node = onnx.helper.make_node(
36+
'Identity',
37+
inputs=['x'],
38+
outputs=['y'],
39+
)
40+
41+
data = [
42+
np.array([[[
43+
[1, 2],
44+
[3, 4],
45+
]]], dtype=np.float32),
46+
np.array([[[
47+
[2, 3],
48+
[1, 5],
49+
]]], dtype=np.float32)]
50+
51+
expect(node, inputs=[data], outputs=[data], name='test_identity_sequence',
52+
opset_imports=[onnx.helper.make_opsetid("", 13)])
96 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.

onnx/defs/shape_inference.h

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -397,25 +397,35 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType(
397397
*dim = input_type->tensor_type().shape().dim(static_cast<int>(fromDimIndex));
398398
}
399399

400+
inline void propagateShape(const TypeProto* from_type, TypeProto* to_type) {
401+
if (TypeProto::kTensorType == from_type->value_case() &&
402+
TypeProto::kTensorType == to_type->value_case()) {
403+
// If input shape is "uknown", the corresponding should be "unknown" too.
404+
// The way to make output shape unknown is not to assign it any value.
405+
if (hasShape(*from_type)) {
406+
*to_type->mutable_tensor_type()->mutable_shape() =
407+
from_type->tensor_type().shape();
408+
}
409+
} else if (TypeProto::kSequenceType == from_type->value_case() &&
410+
TypeProto::kSequenceType == to_type->value_case()) {
411+
propagateShape(&from_type->sequence_type().elem_type(), to_type->mutable_sequence_type()->mutable_elem_type());
412+
} else {
413+
fail_shape_inference(
414+
"Mismatch between source and target type. Source=",
415+
from_type->value_case(),
416+
" Target=",
417+
to_type->value_case());
418+
}
419+
}
420+
400421
inline void propagateShapeFromInputToOutput(
401422
InferenceContext& ctx,
402423
size_t inputIndex,
403424
size_t outputIndex) {
404425
auto output_type = ctx.getOutputType(outputIndex);
405426
auto input_type = ctx.getInputType(inputIndex);
406427

407-
if (TypeProto::kTensorType != input_type->value_case() ||
408-
TypeProto::kTensorType != output_type->value_case()) {
409-
fail_shape_inference(ONNX_NAMESPACE::to_string(
410-
ctx.getInputType(inputIndex)->tensor_type().shape().dim_size()));
411-
}
412-
413-
// If input shape is "uknown", the corresponding should be "unknown" too.
414-
// The way to make output shape unknown is not to assign it any value.
415-
if (hasShape(*input_type)) {
416-
*output_type->mutable_tensor_type()->mutable_shape() =
417-
input_type->tensor_type().shape();
418-
}
428+
propagateShape(input_type, output_type);
419429
}
420430

421431
inline void propagateShapeAndTypeFromFirstInput(InferenceContext& ctx) {

onnx/defs/tensor/defs.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,7 +2195,7 @@ ONNX_OPERATOR_SET_SCHEMA(
21952195
0,
21962196
"input",
21972197
"Input tensor",
2198-
"T",
2198+
"V",
21992199
OpSchema::Single,
22002200
true,
22012201
1,
@@ -2204,15 +2204,20 @@ ONNX_OPERATOR_SET_SCHEMA(
22042204
0,
22052205
"output",
22062206
"Tensor to copy input into.",
2207-
"T",
2207+
"V",
22082208
OpSchema::Single,
22092209
true,
22102210
1,
22112211
OpSchema::Differentiable)
22122212
.TypeConstraint(
2213-
"T",
2214-
OpSchema::all_tensor_types_with_bfloat(),
2215-
"Constrain input and output types to all tensor types.")
2213+
"V",
2214+
[](){
2215+
auto t = OpSchema::all_tensor_types_with_bfloat();
2216+
auto s = OpSchema::all_tensor_sequence_types();
2217+
t.insert(t.end(), s.begin(), s.end());
2218+
return t;
2219+
}(),
2220+
"Constrain input and output types to all tensor and sequence types.")
22162221
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
22172222

22182223
static const char* Compress_ver11_doc = R"DOC(

onnx/test/shape_inference_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,22 @@ def test_average_pool_auto_pads(self): # type: () -> None
782782
def test_relu(self): # type: () -> None
783783
self._identity_prop('Relu')
784784

785+
def test_identity(self): # type: () -> None
786+
self._identity_prop('Identity')
787+
788+
def test_identity_sequence(self): # type: () -> None
789+
graph = self._make_graph(
790+
[('input1', TensorProto.FLOAT, (2, 3, 4)),
791+
('input2', TensorProto.FLOAT, (2, 3, 4)),
792+
('input3', TensorProto.FLOAT, (2, 5, 4))],
793+
[make_node('SequenceConstruct', ['input1', 'input2', 'input3'], ['in_sequence']),
794+
make_node('Identity', ['in_sequence'], ['output_sequence'])],
795+
[])
796+
self._assert_inferred(
797+
graph,
798+
[make_sequence_value_info('in_sequence', TensorProto.FLOAT, (2, None, 4)), # type: ignore
799+
make_sequence_value_info('output_sequence', TensorProto.FLOAT, (2, None, 4))]) # type: ignore
800+
785801
def test_add(self): # type: () -> None
786802
graph = self._make_graph(
787803
[('x', TensorProto.FLOAT, (30, 4, 5)),

0 commit comments

Comments
 (0)