Commit 049c08e
Revert "[dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation. (pytorch#160934)"
This reverts commit 8f31aa9. Reverted pytorch#160934 on behalf of https://github.com/anijain2305 because it causes a memory leak leading to OOMs (see the comment on pytorch#160934).
1 parent: affd071 · commit: 049c08e
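For context on the revert reason: the reverted change kept every nested function created under torch._dynamo.disable in a process-global, id-keyed cache that was only emptied by an explicit torch._dynamo.reset(), so the cached FunctionType objects, and anything their closures captured (tensors, modules, hooks), could accumulate for the lifetime of the process. The snippet below is a minimal sketch of that failure mode, not the PyTorch implementation; NestedFnCache and make_adder are made-up names.

import types
from typing import Optional


class NestedFnCache:
    # Stand-in for an id-keyed cache that never evicts entries.
    cache: dict[str, types.FunctionType] = {}

    @classmethod
    def get(cls, key: str) -> Optional[types.FunctionType]:
        return cls.cache.get(key)

    @classmethod
    def set(cls, key: str, value: types.FunctionType) -> None:
        cls.cache[key] = value  # entry survives until an explicit clear()


def make_adder(big_buffer):
    # The closure captures big_buffer, so caching the returned function
    # keeps the buffer alive even after the caller drops it.
    def add(x):
        return x + big_buffer[0]

    key = str(id(add.__code__)) + str(id(add.__closure__))
    if NestedFnCache.get(key) is None:
        NestedFnCache.set(key, add)
    return add


for step in range(5):
    make_adder([0.0] * 1_000_000)
    # The cache grows on every call and pins each 1M-element list.
    print(step, "cached functions:", len(NestedFnCache.cache))

Because add.__closure__ is a fresh object on every call, each iteration produces a new key, so the per-call buffers are never freed; scaled up inside a training loop, this is the kind of growth that ends in an OOM.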

5 files changed: +7 additions, −91 deletions

test/dynamo/test_misc.py

Lines changed: 5 additions & 54 deletions
@@ -8647,64 +8647,15 @@ def global_context_capture_fn(frame_summary):
         self.assertEqual(seen_frames[1].name, "uwu_inline_me")
         self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
 
-    def test_recompile_on_disable_1(self):
-        # fix https://github.com/pytorch/pytorch/issues/157399
+    def test_error_on_recompile(self):
         @torch.compile(backend="eager")
-        def fn(x):
-            @torch._dynamo.disable
-            def inner(x):
-                return x + 10
-
-            return inner(x) + 1
-
-        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
-            try:
-                for i in range(5):
-                    fn(torch.rand(2, 3))
-            except torch._dynamo.exc.RecompileError as e:
-                self.fail("RecompileError raised unexpectedly: " + str(e))
-
-    def test_recompile_on_disable_2(self):
-        def outer(x, cond):
-            @torch._dynamo.disable()
-            def fn0(y):
-                return y + 1
-
-            @torch._dynamo.disable()
-            def fn1(y):
-                return y + 2
-
-            if cond:
-                f = fn0
-            else:
-                f = fn1
-
-            torch._dynamo.graph_break()
-            # there will be a resume function here
-            return f(x)
-
-        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
-            with self.assertRaises(torch._dynamo.exc.RecompileError):
-                x = torch.rand(2, 3)
-                self.assertEqual(outer(x, True), torch.compile(outer)(x, True))
-                self.assertEqual(outer(x, False), torch.compile(outer)(x, False))
-
-    def test_create_nested_fn_cache_clear(self):
-        def outer(x):
-            @torch._dynamo.disable()
-            def f(y):
-                return y + 2
-
-            return f(x) + 1
+        def fn(a, b):
+            return a + b
 
-        outer = torch.compile(outer)
         with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
             with self.assertRaises(torch._dynamo.exc.RecompileError):
-                outer(torch.randn(3, 3))
-                from torch._dynamo.utils import create_nested_fn_cache
-
-                create_nested_fn_cache.clear()
-                outer(torch.randn(3, 3))
+                fn(torch.rand(2, 3), torch.rand(2, 3))
+                fn(torch.rand(2, 3), (1, 2, 3))
 
     def test_guards_strip_function_call(self):
         from torch._dynamo.guards import strip_function_call
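The two removed test_recompile_on_disable tests guarded the behavior this revert gives up (tracked in https://github.com/pytorch/pytorch/issues/157399): a function decorated with torch._dynamo.disable inside a compiled region is a brand-new object on every call, which can defeat identity-based guards and trigger recompilation. A minimal way to observe this outside the test harness is to count backend invocations with a custom Dynamo backend instead of flipping error_on_recompile; the sketch below is adapted from the removed test_recompile_on_disable_1, with counting_backend as a made-up helper rather than anything in the test suite.

import torch

compile_count = 0


def counting_backend(gm, example_inputs):
    # Custom Dynamo backend: record each (re)compilation, then run the graph eagerly.
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend)
def fn(x):
    @torch._dynamo.disable
    def inner(x):
        return x + 10

    return inner(x) + 1


for _ in range(5):
    fn(torch.rand(2, 3))

# With the reverted caching, the count stayed flat after the first call; without it,
# the count may keep climbing because inner is recreated on every call to fn.
print("backend invocations:", compile_count)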

test/test_autograd.py

Lines changed: 0 additions & 2 deletions
@@ -614,8 +614,6 @@ def unpack(x):
 
         with disable_gc():
             unpack_hook_ref = scope()
-            if torch._dynamo.is_compiling():
-                torch._dynamo.reset()
         self.assertIsNone(unpack_hook_ref())
 
     def test_will_engine_execute_node(self):
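The two deleted lines were a workaround inside the saved-tensor-hooks leak test: under compilation they reset Dynamo before checking that the unpack hook had been garbage collected. The check itself is the standard weakref pattern, sketched below with placeholder names (Hook, make_hook_ref) rather than the actual test fixtures:

import gc
import weakref


class Hook:
    """Placeholder for the unpack hook object whose lifetime is under test."""


def make_hook_ref():
    hook = Hook()
    # Only a weak reference escapes this scope; nothing should keep hook alive.
    return weakref.ref(hook)


ref = make_hook_ref()
gc.collect()
# If some process-global cache still held the hook, ref() would return it here.
assert ref() is None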

torch/_dynamo/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -51,7 +51,6 @@
 from .pgo import reset_code_state
 from .symbolic_convert import TensorifyState
 from .utils import (
-    create_nested_fn_cache,
     graph_break_reasons,
     guard_failures,
     orig_code_map,
@@ -145,7 +144,6 @@ def reset() -> None:
     torch._dynamo.utils.warn_once_cache.clear()
     torch._dynamo.utils.user_obj_id_to_weakref.clear()
     torch._C._autograd._saved_tensors_hooks_set_tracing(False)
-    create_nested_fn_cache.clear()
 
 
 def reset_code_caches() -> None:

torch/_dynamo/utils.py

Lines changed: 0 additions & 19 deletions
@@ -4849,22 +4849,3 @@ def get_traced_code() -> Optional[list[CodeType]]:
     from torch._guards import TracingContext
 
     return TracingContext.get_traced_code()
-
-
-class CreateNestedFnCache:
-    cache: dict[str, types.FunctionType] = {}
-
-    @classmethod
-    def get(cls, key: str) -> Optional[types.FunctionType]:
-        return cls.cache.get(key, None)
-
-    @classmethod
-    def set(cls, key: str, value: types.FunctionType) -> None:
-        cls.cache[key] = value
-
-    @classmethod
-    def clear(cls: type[CreateNestedFnCache]) -> None:
-        cls.cache.clear()
-
-
-create_nested_fn_cache: CreateNestedFnCache = CreateNestedFnCache()
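The deleted class is a plain class-level dict: entries persist until clear() is called explicitly, which is what let cached functions (and whatever their closures hold) outlive the code that created them. Purely as a point of comparison, and not something this commit or the original PR implements, a size-bounded variant (BoundedFnCache below is a hypothetical name) would at least cap the growth when clear() never runs:

import types
from collections import OrderedDict
from typing import Optional


class BoundedFnCache:
    """Hypothetical LRU-bounded variant of the reverted CreateNestedFnCache."""

    max_entries: int = 256
    cache: "OrderedDict[str, types.FunctionType]" = OrderedDict()

    @classmethod
    def get(cls, key: str) -> Optional[types.FunctionType]:
        func = cls.cache.get(key)
        if func is not None:
            cls.cache.move_to_end(key)  # mark the entry as recently used
        return func

    @classmethod
    def set(cls, key: str, value: types.FunctionType) -> None:
        cls.cache[key] = value
        cls.cache.move_to_end(key)
        if len(cls.cache) > cls.max_entries:
            cls.cache.popitem(last=False)  # evict the least recently used entry

    @classmethod
    def clear(cls) -> None:
        cls.cache.clear()

Eviction only bounds the footprint, though; an entry dropped too early would reintroduce exactly the recompilation the original PR tried to avoid, which is presumably why a straight revert was chosen here.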

torch/_dynamo/variables/functions.py

Lines changed: 2 additions & 14 deletions
@@ -69,7 +69,6 @@
     check_unspec_or_constant_args,
     cmp_name_to_op_mapping,
     counters,
-    create_nested_fn_cache,
     identity,
     is_function,
     is_wrapper_or_member_descriptor,
@@ -277,11 +276,6 @@ def _create_nested_fn(
 ):
     from types import FunctionType
 
-    # Add caching for the actual IDs of user functions so that we can use them in the ID_MATCH guard.
-    cache_key = str(id(code)) + str(id(closure)) + str(id(f_globals))
-    if create_nested_fn_cache.get(cache_key):
-        return create_nested_fn_cache.get(cache_key)
-
     func = FunctionType(code, f_globals, name, defaults, closure)
     func.__kwdefaults__ = kwdefaults
 
@@ -293,7 +287,7 @@ def _create_nested_fn(
     # TypeError: __annotations__ must be set to a dict object
     assert annotations is None or isinstance(annotations, dict)
     func.__annotations__ = annotations
-    create_nested_fn_cache.set(cache_key, func)
+
     return func
 
 
@@ -1472,13 +1466,7 @@ def as_python_constant(self):
 
     @classmethod
     def create_with_source(cls, value, source):
-        if inspect.getattr_static(value, "_torchdynamo_orig_callable", False):
-            install_guard(
-                AttrSource(source, "_torchdynamo_orig_callable").make_guard(
-                    GuardBuilder.FUNCTION_MATCH
-                )
-            )
-        elif not is_wrapper_or_member_descriptor(value):
+        if not is_wrapper_or_member_descriptor(value):
             # These descriptors are not guaranteed to return the same object on
             # attribute lookup. They are unlikely to be changed, so we can skip
             # guarding them.
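The restored branch in create_with_source goes back to guarding the function object directly, and the removed hunk in _create_nested_fn explains the original motivation: the ID_MATCH-style guard records the identity of a nested function, but executing a def statement builds a fresh function object on every call, so that identity never matches again. Plain Python shows the mechanism, no Dynamo required:

def outer():
    def inner(x):
        return x + 1

    return inner


a = outer()
b = outer()

print(a.__code__ is b.__code__)  # True: the code object is compiled once and shared
print(a is b)                    # False: each call to outer() creates a new function object

A guard recorded against a therefore fails as soon as the next call hands back b, which is what the reverted cache papered over by returning the same cached FunctionType for the same (code, closure, globals) ids.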
