From 500ed627194eece692818e8dded9007021562ab4 Mon Sep 17 00:00:00 2001 From: bhack Date: Tue, 24 Sep 2024 23:02:31 +0200 Subject: [PATCH] [torchlib] Support mod and eq on SymInt (#1879) Add some missing operators from https://github.com/pytorch/pytorch/issues/136524 /cc @justinchuby --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 44c6c0a87..253026d80 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3269,7 +3269,7 @@ def aten_empty_strided( return op.Expand(zero, size) -@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar")) +@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq")) def aten_eq(self: TTensor, other: TTensor) -> BOOL: """eq.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -7085,7 +7085,7 @@ def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrB return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar")) +@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod")) def aten_remainder_int(self: TInt, other: TInt) -> TInt: """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""