revert the whole family
titaiwangms committed Oct 23, 2024
1 parent 8295a77 commit a833dc6
Showing 1 changed file with 35 additions and 37 deletions.
onnxscript/function_libs/torch_lib/ops/core.py (72 changes: 35 additions & 37 deletions)
@@ -42,6 +42,7 @@
     TInt,
     TReal,
     TRealOrUInt8,
+    TRealUnlessFloat16OrInt8,
     TRealUnlessInt16OrInt8,
     TTensor,
     TTensor2,
@@ -540,7 +541,7 @@ def _integral_to_be_adjusted(dtype: int) -> bool:
 
 @torch_op("aten::arange", trace_only=True)
 def aten_arange(
-    end: float,
+    end: TRealUnlessFloat16OrInt8,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
@@ -549,10 +550,9 @@ def aten_arange(
     """arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
     if dtype == -1 or dtype is None:
-        if isinstance(end, int):
-            result = op.Range(0, end, 1)
-        else:
-            result = op.Range(0.0, end, 1.0)
+        zero = op.CastLike(0.0, end)
+        one = op.CastLike(1.0, end)
+        result = op.Range(zero, end, one)
     elif _range_supported(dtype):
         end = op.Cast(end, to=dtype)
         zero = op.Cast(0, to=dtype)
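
Note: in the new default-dtype branch, end may be a traced graph value rather
than a Python scalar, so the old isinstance dispatch cannot work; CastLike
instead casts the 0 and 1 constants to end's dtype, since ONNX Range requires
start, limit, and delta to share a single type. A minimal numpy sketch of the
resulting behavior (illustrative only, not the torch_lib API):

    import numpy as np

    def arange_default_dtype(end: np.ndarray) -> np.ndarray:
        zero = np.zeros((), dtype=end.dtype)  # mirrors op.CastLike(0.0, end)
        one = np.ones((), dtype=end.dtype)    # mirrors op.CastLike(1.0, end)
        return np.arange(zero, end, one)      # mirrors op.Range(zero, end, one)

    print(arange_default_dtype(np.asarray(5)))    # [0 1 2 3 4], int input stays int
    print(arange_default_dtype(np.asarray(2.5)))  # [0. 1. 2.], float input stays float
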
@@ -563,7 +563,7 @@
         # because the input dtype may be e.g. bfloat16 / int8 etc.
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = op.Constant(value_float=float(end))
+        end = op.Cast(end, to=FLOAT.dtype)
         zero = op.Constant(value_float=0.0)
         one = op.Constant(value_float=1.0)
         result = op.Cast(op.Range(zero, end, one), to=dtype)
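
Note: this fallback covers dtypes that ONNX Range does not accept (its type
constraint is roughly float, double, and 16/32/64-bit ints), such as bfloat16
or int8: compute Range in float32, then cast the result to the requested
dtype. A hedged numpy sketch of the pattern:

    import numpy as np

    def arange_via_float(end, target_dtype=np.int8):
        end_f = np.float32(end)  # op.Cast(end, to=FLOAT.dtype)
        out = np.arange(np.float32(0.0), end_f, np.float32(1.0))  # op.Range in float32
        return out.astype(target_dtype)  # op.Cast(op.Range(...), to=dtype)

    print(arange_via_float(5))  # [0 1 2 3 4] with dtype int8
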
@@ -573,8 +573,8 @@
 
 @torch_op("aten::arange.start", trace_only=True)
 def aten_arange_start(
-    start: TReal,
-    end: TReal,
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
@@ -604,57 +604,56 @@ def aten_arange_start(
 
 
 def _adjust_args_for_arange_int_dtype(
-    start: float,
-    end: float,
-    step: float,
-) -> Tuple[float, float, float]:
-    if start < 0:
-        start = math.ceil(start)
-    if step < 0:
-        start = math.floor(start)
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
+    step: TRealUnlessFloat16OrInt8,
+) -> Tuple[FLOAT, FLOAT, FLOAT]:
+    zero = op.Cast(0.0, to=FLOAT.dtype)
+    start = op.Cast(start, to=FLOAT.dtype)
+    end = op.Cast(end, to=FLOAT.dtype)
+    step = op.Cast(step, to=FLOAT.dtype)
 
-    return float(start), float(end), float(step)
+    start = op.Where(op.Less(start, zero), op.Ceil(start), start)
+    start = op.Where(op.Less(step, zero), op.Floor(start), start)
+
+    return (start, end, step)
 
 
 @torch_op("aten::arange.start_step", trace_only=True)
 def aten_arange_start_step(
-    start: float,
-    end: float,
-    step: float = 1.0,
+    start: TRealUnlessFloat16OrInt8,
+    end: TRealUnlessFloat16OrInt8,
+    step: TRealUnlessFloat16OrInt8 = 1.0,
     dtype: int = -1,
     layout: str = "",
     device: str = "",
     pin_memory: bool = False,
 ) -> TensorType:
     """arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
 
-    if dtype == -1 or dtype is None:
-        if isinstance(start, int) and isinstance(end, int):
-            result = op.Range(start, end, int(step))
-        else:
-            start = float(start)
-            end = float(end)
-            step = float(step)
-            result = op.Range(start, end, step)
+    if dtype == -1:
+        start = op.Cast(start, to=FLOAT.dtype)
+        end = op.Cast(end, to=FLOAT.dtype)
+        result = op.Range(start, end, step)
     elif _integral_to_be_adjusted(dtype):
         # PyTorch arange op handles these integral types differently from INT64,
         # so we have to adjust these arguments accordingly.
         # https://github.com/pytorch/pytorch/blob/121cfb60c0817816fcbe2190303b7f6d05c77cf3/torch/_refs/__init__.py#L4794
         start, end, step = _adjust_args_for_arange_int_dtype(start, end, step)
         result = op.Cast(op.Range(start, end, step), to=dtype)
     elif dtype == INT64.dtype:
-        end = int(end)
-        start = int(start)
-        step = int(step)
+        end = op.Cast(end, to=dtype)
+        start = op.Cast(start, to=dtype)
+        step = op.Cast(step, to=dtype)
         result = op.Range(start, end, step)
     else:
         # Cast input to float if dtype is not supported by Range,
         # because the input dtype may be e.g. bfloat16,
         # which Range does not support. The output type is ensured because the output
         # is casted to the specified dtype.
-        end = float(end)
-        start = float(start)
-        step = float(step)
+        end = op.Cast(end, to=FLOAT.dtype)
+        start = op.Cast(start, to=FLOAT.dtype)
+        step = op.Cast(step, to=FLOAT.dtype)
         result = op.Cast(op.Range(start, end, step), to=dtype)
 
     return result
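
Note: _adjust_args_for_arange_int_dtype now expresses the ceil/floor
adjustment as graph ops (Where/Less/Ceil/Floor) instead of Python-level
branching, because start and step are no longer guaranteed to be Python
scalars at trace time. A numpy sketch of the same logic (illustrative only):

    import numpy as np

    def adjust_args(start, end, step):
        start, end, step = (np.float32(v) for v in (start, end, step))
        zero = np.float32(0.0)
        # op.Where(op.Less(start, zero), op.Ceil(start), start)
        start = np.where(start < zero, np.ceil(start), start)
        # op.Where(op.Less(step, zero), op.Floor(start), start)
        start = np.where(step < zero, np.floor(start), start)
        return start, end, step

    print(adjust_args(-2.5, 3.0, 1.0))   # start is ceiled to -2.0
    print(adjust_args(2.5, -3.0, -1.0))  # negative step: start is floored to 2.0
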
@@ -4731,8 +4730,8 @@ def aten_linear_backward(
 
 @torch_op("aten::linspace", trace_only=True)
 def aten_linspace(
-    start: float,
-    end: float,
+    start: TFloat,
+    end: TFloat,
     steps: int,
     dtype: int = FLOAT.dtype,
     layout: str = "",
@@ -4750,7 +4749,6 @@
     if steps == 1:
         return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)
 
-    # TODO(justinchuby): Simplify the logic knowing start and end are floats
     rg = aten_arange_start(0, steps, dtype=dtype)
     start = op.Cast(start, to=dtype)
     end = op.Cast(end, to=dtype)
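
Note: aten_linspace is built on aten_arange_start: it generates
rg = [0, 1, ..., steps - 1] and maps it affinely onto [start, end]. The
interpolation expression itself sits outside this diff; a plain numpy sketch
of the standard formula (an assumption, not necessarily the exact code in
core.py):

    import numpy as np

    def linspace_sketch(start, end, steps, dtype=np.float32):
        if steps == 1:
            return np.full([1], start, dtype=dtype)  # the steps == 1 case above
        rg = np.arange(0, steps, dtype=dtype)  # aten_arange_start(0, steps, dtype=dtype)
        return (start + rg * (end - start) / (steps - 1)).astype(dtype)

    print(linspace_sketch(0.0, 1.0, 5))  # [0.   0.25 0.5  0.75 1.  ]
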
