|
| 1 | +from subprocess import check_output, STDOUT, CalledProcessError |
| 2 | +import sys |
| 3 | +import pytest |
| 4 | +import glob |
| 5 | + |
| 6 | + |
| 7 | +PYTHON_CODE_DIR = "python_code" |
| 8 | +ALL_FILES = glob.glob(PYTHON_CODE_DIR + "/*.py") |
| 9 | + |
| 10 | + |
| 11 | +@pytest.mark.parametrize('file_path', ALL_FILES) |
| 12 | +def test_run_file(file_path): |
| 13 | + if 'nvidia' in file_path: |
| 14 | + # FIXME: NVIDIA models checkoints are on cuda |
| 15 | + pytest.skip("temporarily disabled") |
| 16 | + if 'pytorch_fairseq_translation' in file_path: |
| 17 | + pytest.skip("temporarily disabled") |
| 18 | + if 'ultralytics_yolov5' in file_path: |
| 19 | + # FIXME torch.nn.modules.module.ModuleAttributeError: 'autoShape' object has no attribute 'fuse |
| 20 | + pytest.skip("temporarily disabled") |
| 21 | + if 'huggingface_pytorch-transformers' in file_path: |
| 22 | + # FIXME torch.nn.modules.module.ModuleAttributeError: 'autoShape' object has no attribute 'fuse |
| 23 | + pytest.skip("temporarily disabled") |
| 24 | + if 'pytorch_fairseq_roberta' in file_path: |
| 25 | + pytest.skip("temporarily disabled") |
| 26 | + |
| 27 | + # We just run the python files in a separate sub-process. We really want a |
| 28 | + # subprocess here because otherwise we might run into package versions |
| 29 | + # issues: imagine script A that needs torchvivion 0.9 and script B that |
| 30 | + # needs torchvision 0.10. If script A is run prior to script B in the same |
| 31 | + # process, script B will still be run with torchvision 0.9 because the only |
| 32 | + # "import torchvision" statement that counts is the first one, and even |
| 33 | + # torchub sys.path shenanigans can do nothing about this. By creating |
| 34 | + # subprocesses we're sure that all file executions are fully independent. |
| 35 | + try: |
| 36 | + # This is inspired (and heavily simplified) from |
| 37 | + # https://github.com/cloudpipe/cloudpickle/blob/343da119685f622da2d1658ef7b3e2516a01817f/tests/testutils.py#L177 |
| 38 | + out = check_output([sys.executable, file_path], stderr=STDOUT) |
| 39 | + print(out.decode()) |
| 40 | + except CalledProcessError as e: |
| 41 | + raise RuntimeError(f"Script {file_path} errored with output:\n{e.output.decode()}") |
0 commit comments