Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

update eval_frame_callback after move part of callback to c++ #419

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sot/opcode_translator/skip_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,13 @@ def need_skip(frame):
if isinstance(_filename, str):
filename = _filename
return need_skip_path(filename)


with_graph_codes = (
paddle.nn.Layer.__call__.__code__,
paddle.nn.Layer._dygraph_call_func.__code__,
)

paddle.framework.core.eval_frame_no_skip_codes(tuple(no_skip_code))
paddle.framework.core.eval_frame_skip_file_prefix(tuple(skip_file_names))
paddle.framework.core.sot_setup_codes_with_graph(with_graph_codes)
62 changes: 14 additions & 48 deletions sot/opcode_translator/transform.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from __future__ import annotations

import dis
import sys
from functools import partial

from ..profiler import EventGuard
from ..utils import CodeStatus, log, log_do
from ..utils import log, log_do
from .custom_code import CustomCode
from .executor.executor_cache import OpcodeExecutorCache
from .skip_files import need_skip


def print_locals(frame):
Expand Down Expand Up @@ -41,58 +39,26 @@ def eval_frame_callback(frame, **kwargs) -> CustomCode:
with EventGuard(
f"eval_frame_callback: {frame.f_code.co_name}", event_level=2
):
# is generator
if frame.f_code.co_flags & 0x20 > 0:
return CustomCode(None, True)
log(2, f"[eval_frame_callback] start to translate: {frame.f_code}\n")
log_do(4, partial(print_locals, frame))

# NOTE(SigureMo): Temporary fallback when code has exception handling.
if sys.version_info >= (3, 11) and frame.f_code.co_exceptiontable:
log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n")
log_do(3, lambda: dis.dis(frame.f_code))

custom_code = OpcodeExecutorCache()(frame, **kwargs)

if custom_code.code is None:
log(
3,
f"[eval_frame_callback] {frame.f_code} has co_exceptiontable\n",
"[transform] NewCode (same as origin code): "
+ frame.f_code.co_name
+ "\n",
)
return CustomCode(None, False)

if need_skip(frame):
log(3, f"[eval_frame_callback] skip {frame.f_code}\n")
custom_code = CustomCode(None, False)
new_code = frame.f_code
else:
log(
2, f"[eval_frame_callback] start to translate: {frame.f_code}\n"
)
log_do(4, partial(print_locals, frame))

log(3, f"[transform] OriginCode: {frame.f_code.co_name}\n")
log_do(3, lambda: dis.dis(frame.f_code))

custom_code = OpcodeExecutorCache()(frame, **kwargs)

if custom_code.code is None:
log(
3,
"[transform] NewCode (same as origin code): "
+ frame.f_code.co_name
+ "\n",
)
new_code = frame.f_code
else:
log(
3,
"[transform] NewCode: " + custom_code.code.co_name + "\n",
)
log_do(3, lambda: dis.dis(custom_code.code))
new_code = custom_code.code

# just check those codes which need open eval_frame
if (
custom_code.disable_eval_frame is False
and CodeStatus().is_code_without_graph(new_code)
):
log(
3,
"[eval_frame_callback] Code has no graph, block it.\n",
"[transform] NewCode: " + custom_code.code.co_name + "\n",
)
return CustomCode(None, True)
log_do(3, lambda: dis.dis(custom_code.code))

return custom_code
20 changes: 11 additions & 9 deletions sot/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING

import paddle

from ..profiler import EventGuard
from ..utils import (
Cache,
CodeStatus,
GraphLogger,
Singleton,
StepInfoManager,
log_do,
)
from ..utils import Cache, GraphLogger, Singleton, StepInfoManager, log_do
from .interpreter import compile_sir

if TYPE_CHECKING:
Expand All @@ -24,6 +18,14 @@ def clear_eager_tensor_name(output_tensors):
output_tensor.name = ""


def trace_back_frames():
frame = inspect.currentframe()
while frame.f_back is not None:
frame = frame.f_back
code = frame.f_code
paddle.framework.core.sot_set_with_graph(code)


class FallbackWrapper:
"""
Used to store and call static graph methods generated by paddle.jit.to_static
Expand All @@ -38,7 +40,7 @@ def __init__(self, compiled_fn, SIR):
def __call__(self, *args, **kwargs):
with EventGuard(f"FallbackWrapper: {self.SIR.name}"):
if StepInfoManager().need_back_trace:
CodeStatus().trace_back_frames()
trace_back_frames()

log_do(
2,
Expand Down
1 change: 0 additions & 1 deletion sot/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .code_status import CodeStatus # noqa: F401
from .exceptions import ( # noqa: F401
BreakGraphError,
FallbackError,
Expand Down
76 changes: 0 additions & 76 deletions sot/utils/code_status.py

This file was deleted.

Loading
Loading