fix onnxscript export for irfft #2770
Open: simonbyrne wants to merge 2 commits into microsoft:main from simonbyrne:sbyrne/irfft
First changed file:

```diff
@@ -7,6 +7,7 @@
 # --------------------------------------------------------------------------
 # pylint: disable=W0221,W0222,R0901,W0237
 # mypy: disable-error-code=override
+# ruff: noqa: N801,E741,RUF036,D214,D402,D405,D411,D412,D416,D417
 # --------------------------------------------------------------------------

 from __future__ import annotations
@@ -39,11 +40,16 @@
 from onnxscript.onnx_opset._impl.opset22 import Opset22
 from onnxscript.onnx_opset._impl.opset23 import Opset23
 from onnxscript.onnx_opset._impl.opset24 import Opset24
+from onnxscript.onnx_opset._impl.opset25 import Opset25
+from onnxscript.onnx_opset._impl.opset26 import Opset26
 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml1 import Opset_ai_onnx_ml1
 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml2 import Opset_ai_onnx_ml2
 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml3 import Opset_ai_onnx_ml3
 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml4 import Opset_ai_onnx_ml4
 from onnxscript.onnx_opset._impl.opset_ai_onnx_ml5 import Opset_ai_onnx_ml5
+from onnxscript.onnx_opset._impl.opset_ai_onnx_preview_training1 import (
+    Opset_ai_onnx_preview_training1,
+)
 from onnxscript.values import Opset

 __all__ = [
@@ -72,11 +78,14 @@
     "opset22",
     "opset23",
     "opset24",
+    "opset25",
+    "opset26",
     "opset_ai_onnx_ml1",
     "opset_ai_onnx_ml2",
     "opset_ai_onnx_ml3",
     "opset_ai_onnx_ml4",
     "opset_ai_onnx_ml5",
+    "opset_ai_onnx_preview_training1",
 ]
@@ -110,11 +119,14 @@
 opset22 = Opset22()
 opset23 = Opset23()
 opset24 = Opset24()
+opset25 = Opset25()
+opset26 = Opset26()
 opset_ai_onnx_ml1 = Opset_ai_onnx_ml1()
 opset_ai_onnx_ml2 = Opset_ai_onnx_ml2()
 opset_ai_onnx_ml3 = Opset_ai_onnx_ml3()
 opset_ai_onnx_ml4 = Opset_ai_onnx_ml4()
 opset_ai_onnx_ml5 = Opset_ai_onnx_ml5()
+opset_ai_onnx_preview_training1 = Opset_ai_onnx_preview_training1()
 all_opsets: Mapping[Tuple[str, int], Opset] = {
     (
         "",
@@ -212,6 +224,14 @@
         "",
         24,
     ): opset24,
+    (
+        "",
+        25,
+    ): opset25,
+    (
+        "",
+        26,
+    ): opset26,
     (
         "ai.onnx.ml",
         1,
@@ -232,4 +252,8 @@
         "ai.onnx.ml",
         5,
     ): opset_ai_onnx_ml5,
+    (
+        "ai.onnx.preview.training",
+        1,
+    ): opset_ai_onnx_preview_training1,
 }
```

Code scanning / lintrunner warning (RUFF/RUF100) on the added `# ruff: noqa` line: Unused noqa directive (unused: N801, E741, RUF036, D214, D402, D405, D411, D412, D416, D417). See https://docs.astral.sh/ruff/rules/unused-noqa
Second changed file:

```diff
@@ -7,7 +7,7 @@
 # --------------------------------------------------------------------------
 # pylint: disable=W0221,W0222,R0901,W0237
 # mypy: disable-error-code=override
-# ruff: noqa: D214, D402, D405, D411, D416, D417
+# ruff: noqa: N801,E741,RUF036,D214,D402,D405,D411,D412,D416,D417
 # --------------------------------------------------------------------------

 from __future__ import annotations
```

Code scanning / lintrunner warning (RUFF/RUF100) on the added `# ruff: noqa` line: Unused noqa directive (unused: N801, E741, D412; non-enabled: RUF036). See https://docs.astral.sh/ruff/rules/unused-noqa

```diff
@@ -397,18 +397,7 @@
     )

     T2_Cast: TypeAlias = Union[
-        BOOL,
-        DOUBLE,
-        FLOAT,
-        FLOAT16,
-        INT16,
-        INT32,
-        INT64,
-        INT8,
-        UINT16,
-        UINT32,
-        UINT64,
-        UINT8,
+        BOOL, DOUBLE, FLOAT, FLOAT16, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8
     ]

     def Cast(self, input: T1_Cast, *, to: str) -> T2_Cast:
@@ -847,11 +836,7 @@
     T_Elu = TypeVar("T_Elu", DOUBLE, FLOAT, FLOAT16)

     def Elu(
-        self,
-        X: T_Elu,
-        *,
-        alpha: float = 1.0,
-        consumed_inputs: Optional[Sequence[int]] = None,
+        self, X: T_Elu, *, alpha: float = 1.0, consumed_inputs: Optional[Sequence[int]] = None
     ) -> T_Elu:
         r"""[🌐 Elu(1)](https://onnx.ai/onnx/operators/onnx__Elu.html#elu-1 "Online Documentation")
@@ -873,9 +858,7 @@
         schema = get_schema("Elu", 1, "")
         op = Op(self, "Elu", schema)
         return op(
-            *self._prepare_inputs(schema, X),
-            alpha=alpha,
-            consumed_inputs=consumed_inputs,
+            *self._prepare_inputs(schema, X), alpha=alpha, consumed_inputs=consumed_inputs
         )

     T_Equal = TypeVar("T_Equal", BOOL, INT32, INT64)
@@ -1354,12 +1337,7 @@
     T1_Greater: TypeAlias = BOOL

     def Greater(
-        self,
-        A: T_Greater,
-        B: T_Greater,
-        *,
-        axis: Optional[int] = None,
-        broadcast: int = 0,
+        self, A: T_Greater, B: T_Greater, *, axis: Optional[int] = None, broadcast: int = 0
     ) -> T1_Greater:
         r"""[🌐 Greater(1)](https://onnx.ai/onnx/operators/onnx__Greater.html#greater-1 "Online Documentation")
@@ -1624,11 +1602,7 @@
         schema = get_schema("LRN", 1, "")
         op = Op(self, "LRN", schema)
         return op(
-            *self._prepare_inputs(schema, X),
-            alpha=alpha,
-            beta=beta,
-            bias=bias,
-            size=size,
+            *self._prepare_inputs(schema, X), alpha=alpha, beta=beta, bias=bias, size=size
         )

     T_LSTM = TypeVar("T_LSTM", DOUBLE, FLOAT, FLOAT16)
@@ -1847,9 +1821,7 @@
         schema = get_schema("LeakyRelu", 1, "")
         op = Op(self, "LeakyRelu", schema)
         return op(
-            *self._prepare_inputs(schema, X),
-            alpha=alpha,
-            consumed_inputs=consumed_inputs,
+            *self._prepare_inputs(schema, X), alpha=alpha, consumed_inputs=consumed_inputs
         )

     T_Less = TypeVar("T_Less", DOUBLE, FLOAT, FLOAT16)
@@ -1962,11 +1934,7 @@
     )

     def Loop(
-        self,
-        M: Optional[I_Loop],
-        cond: Optional[B_Loop],
-        *v_initial: V_Loop,
-        body: GraphProto,
+        self, M: Optional[I_Loop], cond: Optional[B_Loop], *v_initial: V_Loop, body: GraphProto
     ) -> V_Loop:
         r"""[🌐 Loop(1)](https://onnx.ai/onnx/operators/onnx__Loop.html#loop-1 "Online Documentation")
@@ -2524,11 +2492,7 @@
     T_PRelu = TypeVar("T_PRelu", DOUBLE, FLOAT, FLOAT16)

     def PRelu(
-        self,
-        X: T_PRelu,
-        slope: T_PRelu,
-        *,
-        consumed_inputs: Optional[Sequence[int]] = None,
+        self, X: T_PRelu, slope: T_PRelu, *, consumed_inputs: Optional[Sequence[int]] = None
     ) -> T_PRelu:
         r"""[🌐 PRelu(1)](https://onnx.ai/onnx/operators/onnx__PRelu.html#prelu-1 "Online Documentation")
@@ -2602,10 +2566,7 @@
         schema = get_schema("Pad", 1, "")
         op = Op(self, "Pad", schema)
         return op(
-            *self._prepare_inputs(schema, data),
-            mode=mode,
-            paddings=paddings,
-            value=value,
+            *self._prepare_inputs(schema, data), mode=mode, paddings=paddings, value=value
         )

     T_Pow = TypeVar("T_Pow", DOUBLE, FLOAT, FLOAT16)
@@ -3013,11 +2974,7 @@
         schema = get_schema("RandomUniformLike", 1, "")
         op = Op(self, "RandomUniformLike", schema)
         return op(
-            *self._prepare_inputs(schema, input),
-            dtype=dtype,
-            high=high,
-            low=low,
-            seed=seed,
+            *self._prepare_inputs(schema, input), dtype=dtype, high=high, low=low, seed=seed
         )

     T_Reciprocal = TypeVar("T_Reciprocal", DOUBLE, FLOAT, FLOAT16)
@@ -3046,11 +3003,7 @@
     T_ReduceL1 = TypeVar("T_ReduceL1", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)

     def ReduceL1(
-        self,
-        data: T_ReduceL1,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceL1, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceL1:
         r"""[🌐 ReduceL1(1)](https://onnx.ai/onnx/operators/onnx__ReduceL1.html#reducel1-1 "Online Documentation")
@@ -3080,11 +3033,7 @@
     T_ReduceL2 = TypeVar("T_ReduceL2", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)

     def ReduceL2(
-        self,
-        data: T_ReduceL2,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceL2, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceL2:
         r"""[🌐 ReduceL2(1)](https://onnx.ai/onnx/operators/onnx__ReduceL2.html#reducel2-1 "Online Documentation")
@@ -3116,11 +3065,7 @@
     )

     def ReduceLogSum(
-        self,
-        data: T_ReduceLogSum,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceLogSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceLogSum:
         r"""[🌐 ReduceLogSum(1)](https://onnx.ai/onnx/operators/onnx__ReduceLogSum.html#reducelogsum-1 "Online Documentation")
@@ -3186,11 +3131,7 @@
     T_ReduceMax = TypeVar("T_ReduceMax", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)

     def ReduceMax(
-        self,
-        data: T_ReduceMax,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceMax, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceMax:
         r"""[🌐 ReduceMax(1)](https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax-1 "Online Documentation")
@@ -3222,11 +3163,7 @@
     )

     def ReduceMean(
-        self,
-        data: T_ReduceMean,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceMean, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceMean:
         r"""[🌐 ReduceMean(1)](https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-1 "Online Documentation")
@@ -3256,11 +3193,7 @@
     T_ReduceMin = TypeVar("T_ReduceMin", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)

     def ReduceMin(
-        self,
-        data: T_ReduceMin,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceMin, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceMin:
         r"""[🌐 ReduceMin(1)](https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin-1 "Online Documentation")
@@ -3292,11 +3225,7 @@
     )

     def ReduceProd(
-        self,
-        data: T_ReduceProd,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceProd, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceProd:
         r"""[🌐 ReduceProd(1)](https://onnx.ai/onnx/operators/onnx__ReduceProd.html#reduceprod-1 "Online Documentation")
@@ -3326,11 +3255,7 @@
     T_ReduceSum = TypeVar("T_ReduceSum", DOUBLE, FLOAT, FLOAT16, INT32, INT64, UINT32, UINT64)

     def ReduceSum(
-        self,
-        data: T_ReduceSum,
-        *,
-        axes: Optional[Sequence[int]] = None,
-        keepdims: int = 1,
+        self, data: T_ReduceSum, *, axes: Optional[Sequence[int]] = None, keepdims: int = 1
     ) -> T_ReduceSum:
         r"""[🌐 ReduceSum(1)](https://onnx.ai/onnx/operators/onnx__ReduceSum.html#reducesum-1 "Online Documentation")
@@ -3445,9 +3370,7 @@
         schema = get_schema("Reshape", 1, "")
         op = Op(self, "Reshape", schema)
         return op(
-            *self._prepare_inputs(schema, data),
-            consumed_inputs=consumed_inputs,
-            shape=shape,
+            *self._prepare_inputs(schema, data), consumed_inputs=consumed_inputs, shape=shape
        )

     T_Selu = TypeVar("T_Selu", DOUBLE, FLOAT, FLOAT16)
@@ -4036,9 +3959,16 @@
         r"""[🌐 Transpose(1)](https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-1 "Online Documentation")

-        Transpose the input tensor similar to numpy.transpose. For example, when
-        perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape
-        will be (2, 1, 3).
+        Returns a transpose of the input tensor. (Similar to `numpy.transpose`).
+        The optional attribute `perm` must be a permutation of the dimensions of
+        the input tensor. Axis `i` of the output tensor corresponds to the axis
+        `perm[i]` of the input tensor.
+        For example, when perm=(1, 0, 2), given an input tensor of shape (1, 2, 3),
+        the output shape will be (2, 1, 3).
+        When perm=(1, 2, 0), given an input tensor of shape (1, 2, 3),
+        the output shape will be (2, 3, 1).
+        If the attribute `perm` is omitted, its default value is `(n-1, ..., 0)`,
+        where `n` is the rank of the input tensor.

         Args:
@@ -4095,12 +4025,7 @@
     T_Upsample = TypeVar("T_Upsample", BOOL, DOUBLE, FLOAT, FLOAT16, INT32, INT64)

     def Upsample(
-        self,
-        X: T_Upsample,
-        *,
-        height_scale: float,
-        mode: str = "nearest",
-        width_scale: float,
+        self, X: T_Upsample, *, height_scale: float, mode: str = "nearest", width_scale: float
     ) -> T_Upsample:
         r"""[🌐 Upsample(1)](https://onnx.ai/onnx/operators/onnx__Upsample.html#upsample-1 "Online Documentation")
```
Third changed file:

```diff
@@ -7,7 +7,7 @@
 # --------------------------------------------------------------------------
 # pylint: disable=W0221,W0222,R0901,W0237
 # mypy: disable-error-code=override
-# ruff: noqa: D402
+# ruff: noqa: N801,E741,RUF036,D214,D402,D405,D411,D412,D416,D417
 # --------------------------------------------------------------------------

 from __future__ import annotations
```

Code scanning / lintrunner warning (RUFF/RUF100) on the added `# ruff: noqa` line: Unused noqa directive (unused: N801, E741, D214, D405, D411, D412, D416, D417; non-enabled: RUF036). See https://docs.astral.sh/ruff/rules/unused-noqa
Is there a way to implement this with how the DFT op is currently defined? I think the current implementation has issues with normalization, but I couldn't figure out what went wrong. Your expertise is appreciated!

Given that the ONNX update is not merged and implemented yet, we are not able to incorporate the change into torchlib at the moment, as we treat it as undefined behavior (unless the current behavior is actually correct in runtimes?).
I think the current implementation may be UB itself: from what I can tell, it treats the `n//2+1`-length matrix as if it were actually length `n`. Most of the time this means you will be accessing unset memory and getting small values (e.g. `1e-300`), so it is roughly equivalent to scaling most of the factors by 1/2 (except for the 0th frequency, so the offset remains the same).
Actually, I take that back: since it passes the FFT length, it won't be UB, but it will still zero-pad (and hence give the wrong answer).
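To illustrate the zero-padding issue (a NumPy sketch for this discussion, not code from the PR): handing the `n//2+1`-length onesided spectrum to a full-length inverse DFT zero-pads the missing bins instead of reconstructing them, so the result no longer matches the true inverse real FFT.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(8)
half = np.fft.rfft(x)  # onesided spectrum, length n//2 + 1 == 5

# Zero-padding the half-spectrum to length n and taking a plain
# inverse DFT does NOT recover x: the conjugate half of each
# non-DC, non-Nyquist bin's contribution is silently dropped.
padded = np.concatenate([half, np.zeros(3)])
wrong = np.real(np.fft.ifft(padded))

right = np.fft.irfft(half, n=8)  # the correct inverse real FFT

assert np.allclose(right, x)
assert not np.allclose(wrong, x)
```

Only the DC bin (and, for even lengths, the Nyquist bin) survives intact; every other frequency term effectively loses half its amplitude, which matches the scaling behavior described above.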
It would be possible to implement the correct behavior with what we have currently (you would need to extend the array and reverse/conjugate the values), but it would be suboptimal. If the patch to ONNX is accepted, what would be the timeline for getting it updated here?
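The extend and reverse/conjugate workaround can be sketched in NumPy (illustrative only; the helper name is made up for this example). The missing bins are rebuilt from Hermitian symmetry, X[n-k] = conj(X[k]), after which a plain full-length inverse DFT suffices.

```python
import numpy as np

def irfft_from_half_spectrum(half, n):
    """Inverse real FFT built from a full-length inverse DFT.

    `half` is the n//2 + 1 onesided spectrum. The missing bins are
    reconstructed via Hermitian symmetry, X[n - k] = conj(X[k]),
    then a length-n inverse DFT is applied and the real part taken.
    """
    # Reverse and conjugate bins 1 .. ceil(n/2) - 1; for even n the
    # Nyquist bin half[n//2] appears exactly once and is not mirrored.
    tail = np.conj(half[1:(n + 1) // 2][::-1])
    full = np.concatenate([half, tail])
    return np.real(np.fft.ifft(full, n=n))

rng = np.random.default_rng(1)
for n in (7, 8):  # check both odd and even lengths
    x = rng.standard_normal(n)
    assert np.allclose(irfft_from_half_spectrum(np.fft.rfft(x), n), x)
```

This round-trips against `np.fft.rfft` for both parities, but as noted it is suboptimal: it materializes the full spectrum and runs a complex FFT of twice the necessary size.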
When the patch is accepted, depending on consensus it may or may not need to go into a new opset. If it lands in a new opset, we need to update the PyTorch code. It will be accepted here directly if it is just a clarification of the old opset.