Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -11299,26 +11299,7 @@
"Matcher": "ChangePrefixMatcher"
},
"torch.randint": {
"Matcher": "RandintMatcher",
"paddle_api": "paddle.randint",
"min_input_args": 2,
"args_list": [
"low",
"high",
"size",
"*",
"generator",
"out",
"dtype",
"layout",
"device",
"pin_memory",
"requires_grad"
],
"kwargs_change": {
"size": "shape",
"dtype": "dtype"
}
"Matcher": "ChangePrefixMatcher"
},
"torch.randint_like": {
"Matcher": "RandintLikeMatcher",
Expand Down
2 changes: 1 addition & 1 deletion paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def get_paddle_nodes(self, args, kwargs):
kwargs = self.parse_kwargs(kwargs, allow_none=True)

# temporary delete these unsupport args, which paddle does not support now
for k in ["layout", "generator", "memory_format", "sparse_grad"]:
for k in ["layout", "generator", "memory_format", "sparse_grad", "requires_grad", "pin_memory", "device"]:
if k in kwargs:
kwargs.pop(k)
code = f"{self.get_paddle_api()}({self.args_and_kwargs_to_str(args, kwargs)})"
Expand Down
215 changes: 210 additions & 5 deletions tests/test_randint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import textwrap

import paddle
import pytest
from apibase import APIBase

Expand Down Expand Up @@ -141,11 +140,11 @@ def test_case_11():
obj.run(pytorch_code, ["result"], check_value=False)


@pytest.mark.skipif(
condition=not paddle.device.is_compiled_with_cuda(),
reason="can only run on paddle with CUDA",
)
def test_case_12():
import torch

if not torch.cuda.is_available():
pytest.skip("pin_memory=True requires CUDA")
pytorch_code = textwrap.dedent(
"""
import torch
Expand All @@ -155,3 +154,209 @@ def test_case_12():
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


# Additional test cases for comprehensive coverage


def test_case_13():
"""Test with size keyword argument explicitly"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(0, 10, size=(3, 3))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_14():
"""Test with only high and size as keyword arguments"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(high=5, size=(2, 3))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_15():
"""Test with mixed positional and keyword: low positional, high and size as keyword"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(1, high=10, size=(2, 2))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_16():
"""Test with dtype=torch.int32 and size keyword"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(0, 100, size=(4, 4), dtype=torch.int32)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_17():
"""Test 1D tensor"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(low=0, high=10, size=(5,))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_18():
"""Test 3D tensor"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(0, 5, (2, 3, 4))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_19():
"""Test with out parameter and size keyword"""
pytorch_code = textwrap.dedent(
"""
import torch
out = torch.empty(3, 3, dtype=torch.int64)
result = torch.randint(0, 10, size=(3, 3), out=out)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_20():
"""Test with expression as high parameter"""
pytorch_code = textwrap.dedent(
"""
import torch
base = 5
result = torch.randint(0, base * 2, (2, 2))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_21():
"""Test with expression as size parameter"""
pytorch_code = textwrap.dedent(
"""
import torch
dim = 2
result = torch.randint(0, 10, (dim + 1, dim + 1))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_22():
"""Test with all keyword arguments in different order"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(size=(2, 2), high=10, low=0, dtype=torch.int64)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_23():
"""Test with negative low value"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(-10, 10, (3, 3))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_24():
"""Test with large range"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(0, 1000000, (2, 2), dtype=torch.int64)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_25():
"""Test with single element tensor"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(0, 10, (1,))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_26():
"""Test with 4D tensor"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(low=0, high=5, size=(2, 2, 2, 2))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_27():
"""Test with variable unpacking for size"""
pytorch_code = textwrap.dedent(
"""
import torch
shape = (3, 4)
result = torch.randint(0, 10, shape)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_28():
"""Test with only positional arguments: low, high, size"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(5, 15, (2, 3))
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_29():
"""Test with out parameter as keyword, other as positional"""
pytorch_code = textwrap.dedent(
"""
import torch
out = torch.empty(2, 2, dtype=torch.int64)
result = torch.randint(0, 10, (2, 2), out=out)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_30():
"""Test with dtype as keyword only"""
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.randint(10, (3, 3), dtype=torch.int32)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)