Skip to content

Commit 741d0f4

Browse files
suofacebook-github-bot
authored andcommitted
[package] split tests (pytorch#53749)
Summary: Pull Request resolved: pytorch#53749 Split up tests into cases that cover specific functionality. Goals: 1. Avoid the omnibus test file mess (see: test_jit.py) by imposing early structure and deliberately avoiding a generic TestPackage test case. 2. Encourage testing of individual APIs and components by example. 3. Hide the fake modules we created for these tests in their own folder. You can either run the test files individually, or still use test/test_package.py like before. Also this isort + black formats all the tests. Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D26958535 Pulled By: suo fbshipit-source-id: 8a63048b95ca71f4f1aa94e53c48442686076034
1 parent 4351d09 commit 741d0f4

16 files changed

+1224
-928
lines changed

test/module_a.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

test/package/common.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import os
2+
import sys
3+
from tempfile import NamedTemporaryFile
4+
5+
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
6+
7+
8+
class PackageTestCase(TestCase):
9+
def __init__(self, *args, **kwargs):
10+
super().__init__(*args, **kwargs)
11+
self._temporary_files = []
12+
13+
def temp(self):
14+
t = NamedTemporaryFile()
15+
name = t.name
16+
if IS_WINDOWS:
17+
t.close() # can't read an open file in windows
18+
else:
19+
self._temporary_files.append(t)
20+
return name
21+
22+
def setUp(self):
23+
"""Add test/package/ to module search path. This ensures that
24+
importing our fake packages via, e.g. `import package_a` will always
25+
work regardless of how we invoke the test.
26+
"""
27+
super().setUp()
28+
self.package_test_dir = os.path.dirname(os.path.realpath(__file__))
29+
self.orig_sys_path = sys.path.copy()
30+
sys.path.append(self.package_test_dir)
31+
32+
def tearDown(self):
33+
super().tearDown()
34+
sys.path = self.orig_sys_path
35+
36+
# remove any temporary files
37+
for t in self._temporary_files:
38+
t.close()
39+
self._temporary_files = []

test/package/module_a.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
result = "module_a"
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
result = 'package_a'
1+
result = "package_a"
2+
23

34
class PackageAObject:
4-
__slots__ = ['obj']
5+
__slots__ = ["obj"]
56

67
def __init__(self, obj):
78
self.obj = obj
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
result = 'package_a.subpackage'
1+
result = "package_a.subpackage"
2+
3+
24
class PackageASubpackageObject:
35
pass
46

7+
58
def leaf_function(a, b):
69
return a + b
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import torch
22
from torch.fx import wrap
33

4-
wrap('a_non_torch_leaf')
4+
wrap("a_non_torch_leaf")
5+
56

67
class SimpleTest(torch.nn.Module):
78
def forward(self, x):
89
x = a_non_torch_leaf(x, x)
910
return torch.relu(x + 3.0)
1011

12+
1113
def a_non_torch_leaf(a, b):
1214
return a + b
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from sys import version_info
2+
from textwrap import dedent
3+
from unittest import skipIf
4+
5+
from torch.package import (
6+
DeniedModuleError,
7+
EmptyMatchError,
8+
PackageExporter,
9+
PackageImporter,
10+
)
11+
from torch.testing._internal.common_utils import run_tests
12+
13+
try:
14+
from .common import PackageTestCase
15+
except ImportError:
16+
# Support the case where we run this file directly.
17+
from common import PackageTestCase # type: ignore
18+
19+
20+
class TestDependencyAPI(PackageTestCase):
21+
"""Dependency management API tests.
22+
- mock()
23+
- extern()
24+
- deny()
25+
"""
26+
27+
def test_extern(self):
28+
filename = self.temp()
29+
with PackageExporter(filename, verbose=False) as he:
30+
he.extern(["package_a.subpackage", "module_a"])
31+
he.require_module("package_a.subpackage")
32+
he.require_module("module_a")
33+
he.save_module("package_a")
34+
hi = PackageImporter(filename)
35+
import module_a
36+
import package_a.subpackage
37+
38+
module_a_im = hi.import_module("module_a")
39+
hi.import_module("package_a.subpackage")
40+
package_a_im = hi.import_module("package_a")
41+
42+
self.assertIs(module_a, module_a_im)
43+
self.assertIsNot(package_a, package_a_im)
44+
self.assertIs(package_a.subpackage, package_a_im.subpackage)
45+
46+
def test_extern_glob(self):
47+
filename = self.temp()
48+
with PackageExporter(filename, verbose=False) as he:
49+
he.extern(["package_a.*", "module_*"])
50+
he.save_module("package_a")
51+
he.save_source_string(
52+
"test_module",
53+
dedent(
54+
"""\
55+
import package_a.subpackage
56+
import module_a
57+
"""
58+
),
59+
)
60+
hi = PackageImporter(filename)
61+
import module_a
62+
import package_a.subpackage
63+
64+
module_a_im = hi.import_module("module_a")
65+
hi.import_module("package_a.subpackage")
66+
package_a_im = hi.import_module("package_a")
67+
68+
self.assertIs(module_a, module_a_im)
69+
self.assertIsNot(package_a, package_a_im)
70+
self.assertIs(package_a.subpackage, package_a_im.subpackage)
71+
72+
def test_extern_glob_allow_empty(self):
73+
"""
74+
Test that an error is thrown when a extern glob is specified with allow_empty=True
75+
and no matching module is required during packaging.
76+
"""
77+
filename = self.temp()
78+
with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
79+
with PackageExporter(filename, verbose=False) as exporter:
80+
exporter.extern(include=["package_a.*"], allow_empty=False)
81+
exporter.save_module("package_b.subpackage")
82+
83+
def test_deny(self):
84+
"""
85+
Test marking packages as "deny" during export.
86+
"""
87+
filename = self.temp()
88+
89+
with self.assertRaisesRegex(
90+
DeniedModuleError,
91+
"required during packaging but has been explicitly blocklisted",
92+
):
93+
with PackageExporter(filename, verbose=False) as exporter:
94+
exporter.deny(["package_a.subpackage", "module_a"])
95+
exporter.require_module("package_a.subpackage")
96+
97+
def test_deny_glob(self):
98+
"""
99+
Test marking packages as "deny" using globs instead of package names.
100+
"""
101+
filename = self.temp()
102+
with self.assertRaisesRegex(
103+
DeniedModuleError,
104+
"required during packaging but has been explicitly blocklisted",
105+
):
106+
with PackageExporter(filename, verbose=False) as exporter:
107+
exporter.deny(["package_a.*", "module_*"])
108+
exporter.save_source_string(
109+
"test_module",
110+
dedent(
111+
"""\
112+
import package_a.subpackage
113+
import module_a
114+
"""
115+
),
116+
)
117+
118+
@skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
119+
def test_mock(self):
120+
filename = self.temp()
121+
with PackageExporter(filename, verbose=False) as he:
122+
he.mock(["package_a.subpackage", "module_a"])
123+
he.save_module("package_a")
124+
he.require_module("package_a.subpackage")
125+
he.require_module("module_a")
126+
hi = PackageImporter(filename)
127+
import package_a.subpackage
128+
129+
_ = package_a.subpackage
130+
import module_a
131+
132+
_ = module_a
133+
134+
m = hi.import_module("package_a.subpackage")
135+
r = m.result
136+
with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
137+
r()
138+
139+
@skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
140+
def test_mock_glob(self):
141+
filename = self.temp()
142+
with PackageExporter(filename, verbose=False) as he:
143+
he.mock(["package_a.*", "module*"])
144+
he.save_module("package_a")
145+
he.save_source_string(
146+
"test_module",
147+
dedent(
148+
"""\
149+
import package_a.subpackage
150+
import module_a
151+
"""
152+
),
153+
)
154+
hi = PackageImporter(filename)
155+
import package_a.subpackage
156+
157+
_ = package_a.subpackage
158+
import module_a
159+
160+
_ = module_a
161+
162+
m = hi.import_module("package_a.subpackage")
163+
r = m.result
164+
with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
165+
r()
166+
167+
def test_mock_glob_allow_empty(self):
168+
"""
169+
Test that an error is thrown when a mock glob is specified with allow_empty=True
170+
and no matching module is required during packaging.
171+
"""
172+
filename = self.temp()
173+
with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
174+
with PackageExporter(filename, verbose=False) as exporter:
175+
exporter.mock(include=["package_a.*"], allow_empty=False)
176+
exporter.save_module("package_b.subpackage")
177+
178+
def test_module_glob(self):
179+
from torch.package.package_exporter import _GlobGroup
180+
181+
def check(include, exclude, should_match, should_not_match):
182+
x = _GlobGroup(include, exclude)
183+
for e in should_match:
184+
self.assertTrue(x.matches(e))
185+
for e in should_not_match:
186+
self.assertFalse(x.matches(e))
187+
188+
check(
189+
"torch.*",
190+
[],
191+
["torch.foo", "torch.bar"],
192+
["tor.foo", "torch.foo.bar", "torch"],
193+
)
194+
check(
195+
"torch.**",
196+
[],
197+
["torch.foo", "torch.bar", "torch.foo.bar", "torch"],
198+
["what.torch", "torchvision"],
199+
)
200+
check("torch.*.foo", [], ["torch.w.foo"], ["torch.hi.bar.baz"])
201+
check(
202+
"torch.**.foo", [], ["torch.w.foo", "torch.hi.bar.foo"], ["torch.f.foo.z"]
203+
)
204+
check("torch*", [], ["torch", "torchvision"], ["torch.f"])
205+
check(
206+
"torch.**",
207+
["torch.**.foo"],
208+
["torch", "torch.bar", "torch.barfoo"],
209+
["torch.foo", "torch.some.foo"],
210+
)
211+
check("**.torch", [], ["torch", "bar.torch"], ["visiontorch"])
212+
213+
@skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
214+
def test_pickle_mocked(self):
215+
import package_a.subpackage
216+
217+
obj = package_a.subpackage.PackageASubpackageObject()
218+
obj2 = package_a.PackageAObject(obj)
219+
220+
filename = self.temp()
221+
with PackageExporter(filename, verbose=False) as he:
222+
he.mock(include="package_a.subpackage")
223+
he.save_pickle("obj", "obj.pkl", obj2)
224+
225+
hi = PackageImporter(filename)
226+
with self.assertRaises(NotImplementedError):
227+
hi.load_pickle("obj", "obj.pkl")
228+
229+
230+
if __name__ == "__main__":
231+
run_tests()

0 commit comments

Comments
 (0)