Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Split implement exception into #203

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def start_compile(self, *ret_vars: VariableBase):
for ret_var in ret_vars
for ret_item in ret_var.flatten_items()
]
self.pycode_gen.gen_disable_eval_frame()
tensor_items = self._find_tensor_outputs(ret_items)
compiled_fn, statment_ir = self.sir_ctx.compile_fn(
[Symbol(tensor_var.var_name) for tensor_var in tensor_items]
Expand Down
53 changes: 23 additions & 30 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
BreakGraphError,
InnerError,
NotImplementException,
NotImplementFatal,
NotImplementInsignificant,
Singleton,
is_strict_mode,
log,
Expand Down Expand Up @@ -240,7 +242,7 @@ def inner(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception as e:
raise NotImplementException(
raise NotImplementFatal(
f'An exception occurred when processing break graph, fallback to dygraph, error message is: \n{type(e)} : {e}\n'
)

Expand Down Expand Up @@ -339,9 +341,7 @@ def step(self, instr: Instruction):
if instr.starts_line is not None:
self._current_line = instr.starts_line
if not hasattr(self, instr.opname):
raise NotImplementException(
f"opcode: {instr.opname} is not supported."
)
raise NotImplementFatal(f"opcode: {instr.opname} is not supported.")
log(
3,
f"[Translate {self._name}]: (line {self._current_line:>3}) {instr.opname:<12} {instr.argval}, stack is {self._stack}\n",
Expand Down Expand Up @@ -679,7 +679,7 @@ def CALL_FUNCTION(self, instr):
kwargs = {}
fn = self.pop()
if not isinstance(fn, CallableVariable):
raise NotImplementException(f"CALL_FUNCTION: {fn} is not callable")
raise NotImplementFatal(f"CALL_FUNCTION: {fn} is not callable")
ret = fn(*args, **kwargs)
self.push(ret)

Expand All @@ -703,9 +703,7 @@ def CALL_FUNCTION_KW(self, instr):

fn = self.pop()
if not isinstance(fn, CallableVariable):
raise NotImplementException(
f"CALL_FUNCTION_KW: {fn} is not callable."
)
raise NotImplementFatal(f"CALL_FUNCTION_KW: {fn} is not callable.")
ret = fn(*args, **kwargs)
self.push(ret)

Expand All @@ -724,9 +722,7 @@ def CALL_FUNCTION_EX(self, instr):

fn = self.pop()
if not isinstance(fn, CallableVariable):
raise NotImplementException(
f"CALL_FUNCTION_EX: {fn} is not callable."
)
raise NotImplementFatal(f"CALL_FUNCTION_EX: {fn} is not callable.")
ret = fn(*args, **kwargs)
self.push(ret)

Expand All @@ -753,7 +749,7 @@ def COMPARE_OP(self, instr):
)
return
except Exception as e:
raise NotImplementException(
raise NotImplementFatal(
f"{instr} is not support between {left} and {right}. may be not a supported compare op."
)

Expand Down Expand Up @@ -790,7 +786,7 @@ def MAKE_FUNCTION(self, instr):
related_list.append(self.pop())

if flag & MF.MF_HAS_KWDEFAULTS:
raise NotImplementException(
raise NotImplementFatal(
"Found need func_kwdefaults when MAKE_FUNCTION."
)

Expand Down Expand Up @@ -888,7 +884,7 @@ def JUMP_IF_FALSE_OR_POP(self, instr):
else:
self.pop()
return
raise NotImplementException(
raise NotImplementFatal(
"Currently don't support predicate a non-const / non-tensor obj."
)

Expand All @@ -903,7 +899,7 @@ def JUMP_IF_TRUE_OR_POP(self, instr):
else:
self.pop()
return
raise NotImplementException(
raise NotImplementFatal(
"Currently don't support predicate a non-const / non-tensor obj."
)

Expand All @@ -916,7 +912,7 @@ def POP_JUMP_IF_FALSE(self, instr):
if is_jump:
self._lasti = self.indexof(instr.jump_to)
return
raise NotImplementException(
raise NotImplementFatal(
"Currently don't support predicate a non-const / non-tensor obj."
)

Expand All @@ -929,7 +925,7 @@ def POP_JUMP_IF_TRUE(self, instr):
if is_jump:
self._lasti = self.indexof(instr.jump_to)
return
raise NotImplementException(
raise NotImplementFatal(
"Currently don't support predicate a non-const / non-tensor obj."
)

Expand All @@ -945,13 +941,11 @@ def UNPACK_SEQUENCE(self, instr):
'''
if isinstance(sequence, TensorVariable):
# TODO: If need to unpack a Tensor, should have different logic.
raise NotImplementException("Unpack a iterator is not implemented.")
raise NotImplementFatal("Unpack a tensor is not implemented.")
elif isinstance(sequence, (ListVariable, TupleVariable)):
seq = sequence.value
else:
raise NotImplementException(
f"Unpack {sequence} is not implemented."
)
raise NotImplementFatal(f"Unpack {sequence} is not implemented.")

assert (
len(seq) == instr.arg
Expand Down Expand Up @@ -1002,9 +996,7 @@ def FORMAT_VALUE(self, instr):
)
)
else:
raise NotImplementException(
f"Do not support format {type(value)} now"
)
raise NotImplementFatal(f"Do not support format {type(value)} now")

# NOTE: This operation will generate SideEffects, and the mechanism has not been completed yet
def DICT_UPDATE(self, instr):
Expand Down Expand Up @@ -1118,7 +1110,7 @@ def _break_graph_in_jump(self, result, instr):
for name in if_inputs:
self.get_var(name).reconstruct(self._graph.pycode_gen)
self._graph.pycode_gen.gen_call_function(
argc=if_fn.__code__.co_argcount
argc=if_fn.__code__.co_argcount, enable_evalframe=True
)
self._graph.pycode_gen.gen_return()
else:
Expand All @@ -1135,7 +1127,7 @@ def _break_graph_in_jump(self, result, instr):
for name in else_inputs:
self.get_var(name).reconstruct(self._graph.pycode_gen)
self._graph.pycode_gen.gen_call_function(
argc=else_fn.__code__.co_argcount
argc=else_fn.__code__.co_argcount, enable_evalframe=True
)
self._graph.pycode_gen.gen_return()
else:
Expand Down Expand Up @@ -1197,7 +1189,7 @@ def _break_graph_in_call(self, origin_stack, instr, push_n):
for name in resume_input_name:
self._locals[name].reconstruct(self._graph.pycode_gen)
self._graph.pycode_gen.gen_call_function(
argc=resume_fn.__code__.co_argcount
argc=resume_fn.__code__.co_argcount, enable_evalframe=True
)

# gen RETURN_VALUE
Expand Down Expand Up @@ -1299,7 +1291,8 @@ def update_locals(name, variable):

# 5.4 call loop body
self._graph.pycode_gen.gen_call_function(
argc=loop_body.__code__.co_argcount
argc=loop_body.__code__.co_argcount,
enable_evalframe=True,
)

# 5.5 unpack and store retval, keep break_flag in stack
Expand Down Expand Up @@ -1328,7 +1321,7 @@ def update_locals(name, variable):
self._graph.pycode_gen.gen_load_fast(name)

self._graph.pycode_gen.gen_call_function(
argc=after_loop_fn.__code__.co_argcount
argc=after_loop_fn.__code__.co_argcount, enable_evalframe=True
)

self._graph.pycode_gen.gen_return()
Expand Down Expand Up @@ -1359,7 +1352,7 @@ def FOR_ITER(self, instr):
end = self.indexof(instr.jump_to)
for i in range(start, end):
if self._instructions[i].opname == "RETURN_VALUE":
raise NotImplementException(
raise NotImplementInsignificant(
"Found RETURN_VALUE in for loop body."
)

Expand Down
52 changes: 51 additions & 1 deletion sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import opcode

import paddle

from ...utils import (
ResumeFnNameFactory,
list_contain_by_id,
Expand Down Expand Up @@ -352,6 +354,52 @@ def gen_load_const(self, value):
idx = list_find_index_by_id(self._code_options["co_consts"], value)
self._add_instr("LOAD_CONST", arg=idx, argval=value)

def gen_disable_eval_frame(self):
self.gen_load_object(
paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_function"
)
self.gen_load_const(None)
self.gen_call_function(1)
self.gen_store_fast("paddle_old_eval_frame_fn")

def gen_enable_eval_frame(self):
self.gen_load_object(
paddle.fluid.core.set_eval_frame, "paddle_set_eval_frame_function"
)
self.gen_load_fast("paddle_old_eval_frame_fn")
self.gen_call_function(1)
self.gen_pop_top()

def gen_print_log(self, message):
"""print a log :"""
self.gen_disable_eval_frame()
self.gen_load_global("print")
self.gen_load_const(message)
self.gen_call_function(1)
self.gen_enable_eval_frame()

def gen_dbg_function(self, dbg_fun):
"""debug bytecode helper function.
Usage like:
def dbg_func():
import inspect
import dis
print("dbg here.")
print(locals())
dis.dis(inspect.currentframe().f_back.f_code)
frame = inspect.currentframe().f_back
code = (inspect.currentframe().f_back.f_code)
breakpoint()
print(inspect.currentframe().f_back.f_locals['y'])

self.pycode_gen.gen_dbg_function(dbg_func)
"""
self.gen_disable_eval_frame()
self.gen_load_object(dbg_fun, "dbg1")
self.gen_call_function(0)
self.gen_pop_top()
self.gen_enable_eval_frame()

def gen_load_global(self, name):
if name not in self._code_options["co_names"]:
self._code_options["co_names"].append(name)
Expand Down Expand Up @@ -424,7 +472,9 @@ def gen_build_map(self, count):
def gen_unpack_sequence(self, count):
self._add_instr("UNPACK_SEQUENCE", arg=count, argval=count)

def gen_call_function(self, argc=0):
def gen_call_function(self, argc=0, enable_evalframe=False):
if enable_evalframe:
self.gen_enable_eval_frame()
self._add_instr("CALL_FUNCTION", arg=argc, argval=argc)

def gen_pop_top(self):
Expand Down
10 changes: 5 additions & 5 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import paddle

from ...utils import BreakGraphError, NotImplementException
from ...utils import BreakGraphError, InnerError, NotImplementInsignificant
from ...utils.magic_methods import (
BINARY_OPS,
UNARY_OPS,
Expand Down Expand Up @@ -250,8 +250,8 @@ def tensor_mod_dispatcher(
raise BreakGraphError(
"(ConstantVariable % TensorVariable) raise a callback. "
)
raise NotImplementException(
"Tensor doesn't support __rmod__"
raise InnerError(
"TypeError: unsupported operand type(s) for %: 'int' and 'Tensor'"
)

else:
Expand All @@ -275,7 +275,7 @@ def tensor_mod_dispatcher(

@Dispatcher.register_decorator(unary_fn)
def numpy_unary_dispatcher(var: NumpyVariable):
raise NotImplementException(
raise NotImplementInsignificant(
'Numpy operator need fallback to dygraph'
)

Expand All @@ -285,6 +285,6 @@ def numpy_unary_dispatcher(var: NumpyVariable):

@Dispatcher.register_decorator(binary_fn)
def numpy_binary_dispatcher(var: NumpyVariable, other: NumpyVariable):
raise NotImplementException(
raise NotImplementInsignificant(
'Numpy operator need fallback to dygraph'
)
6 changes: 3 additions & 3 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import paddle

from ....utils import NameGenerator, get_unbound_method, log, log_do
from ....utils.exceptions import InnerError, NotImplementException
from ....utils.exceptions import InnerError, NotImplementFatal
from ..guard import StringifyExpression, union_free_vars
from ..pycode_generator import PyCodeGen
from ..tracker import DummyTracker, GetAttrTracker, GetItemTracker, Tracker
Expand Down Expand Up @@ -224,7 +224,7 @@ def reconstruct(self, codegen: PyCodeGen):
self._reconstruct(codegen)

def _reconstruct(self, codegen: PyCodeGen):
raise NotImplementException()
raise NotImplementFatal("Not implement reconstruct.")

def flatten_items(self) -> list[VariableBase]:
from .container import ContainerVariable
Expand Down Expand Up @@ -281,7 +281,7 @@ def getattr(self, name: str):
)

def __setitem__(self, key, value):
raise NotImplementException(f"{self} is not support setitem.")
raise NotImplementFatal(f"{self} is not support setitem.")

def __repr__(self):
info = {**self.main_info, **self.debug_info}
Expand Down
4 changes: 2 additions & 2 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ....utils import (
BreakGraphError,
NameGenerator,
NotImplementException,
NotImplementInsignificant,
log_do,
paddle_tensor_methods,
)
Expand Down Expand Up @@ -453,7 +453,7 @@ def format_number(number: np.number):
union_free_vars(frame_value_tracer.free_vars, {"np": np}),
)
else:
raise NotImplementException(
raise NotImplementInsignificant(
"We can not stringify numpy variable when value is np.ndarray"
)

Expand Down
10 changes: 6 additions & 4 deletions sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any

from ....utils import log_do
from ....utils.exceptions import InnerError, NotImplementException
from ....utils.exceptions import InnerError, NotImplementFatal
from ..guard import StringifyExpression
from ..pycode_generator import PyCodeGen
from ..tracker import (
Expand All @@ -24,10 +24,12 @@

class ContainerVariable(VariableBase):
def get_items(self) -> list[VariableBase]:
raise NotImplementException()
raise NotImplementFatal(
"Not implement get_items for container variable."
)

def __len__(self):
raise NotImplementException()
raise NotImplementFatal("Not implement __len__ for container variable.")

def len(self):
return VariableFactory.from_value(
Expand Down Expand Up @@ -401,7 +403,7 @@ def getattr(self, name):
builtin_fn, self.graph, DanglingTracker()
).bind(self, name)
else:
raise NotImplementException(
raise NotImplementFatal(
f"attribute {name} for dict is not implemented"
)

Expand Down
Loading