This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Typing] Fix some typehint and add more assert #284

Open · wants to merge 7 commits into develop
93 changes: 67 additions & 26 deletions sot/opcode_translator/executor/opcode_executor.py
@@ -8,7 +8,7 @@
import traceback
import types
from itertools import chain
from typing import Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

from ...utils import (
BreakGraphError,
@@ -72,16 +72,22 @@
VariableFactory,
)

if TYPE_CHECKING:
from typing import TypeVar

VariableT = TypeVar("VariableT", bound=VariableBase)

GuardedFunction = Tuple[types.CodeType, Guard]
GuardedFunctions = List[GuardedFunction]
CacheGetter = Callable[
[types.FrameType, GuardedFunctions], Optional["CustomCode"]
]

CustomCode = collections.namedtuple(
"CustomCode", ["code", "disable_eval_frame"]
)


GuardedFunction = Tuple[types.CodeType, Guard]
GuardedFunctions = List[GuardedFunction]
CacheGetter = Callable[
[types.FrameType, GuardedFunctions], Optional[CustomCode]
]
dummy_guard: Guard = lambda frame: True

SUPPORT_COMPARE_OP = {
@@ -634,15 +640,17 @@ def indexof(self, instr: Instruction):
"""
return self._instructions.index(instr)

def pop(self) -> VariableBase:
def pop(self, *, var_type: type[VariableT] = VariableBase) -> VariableT:
"""
Pops the top value from the stack.

Returns:
The popped value.

"""
return self._stack.pop()
var = self._stack.pop()
assert isinstance(var, var_type)
return var

def peek(self) -> VariableBase:
"""
@@ -995,7 +1003,7 @@ def BUILD_CONST_KEY_MAP(self, instr: Instruction):
assert map_size + 1 <= len(
self._stack
), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems."
keys = self.pop().get_items()
keys = self.pop(var_type=ContainerVariable).get_items()
assert len(keys) == map_size
values = self.pop_n(map_size)
self.push(self.build_map(keys, values))
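The pop(var_type=ContainerVariable) call above relies on the VariableT type variable (bound to VariableBase) introduced earlier in this diff: the assert checks the runtime type, and the annotation lets static checkers narrow the return type. A minimal, self-contained sketch of the same pattern — the classes below are simplified stand-ins, not the real variable hierarchy:

```python
from __future__ import annotations

from typing import TypeVar


class VariableBase:  # stand-in for the real base class
    ...


class ContainerVariable(VariableBase):  # stand-in for a concrete subclass
    def get_items(self) -> list:
        return ["a", "b"]


VariableT = TypeVar("VariableT", bound=VariableBase)


class Stack:
    def __init__(self) -> None:
        self._stack: list[VariableBase] = []

    def push(self, var: VariableBase) -> None:
        self._stack.append(var)

    def pop(self, *, var_type: type[VariableT] = VariableBase) -> VariableT:
        var = self._stack.pop()
        # Fail fast if the runtime type is not what the caller expected;
        # the bound TypeVar narrows the static return type as well.
        assert isinstance(var, var_type)
        return var


stack = Stack()
stack.push(ContainerVariable())
keys = stack.pop(var_type=ContainerVariable).get_items()  # typed as ContainerVariable
```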
@@ -1241,10 +1249,14 @@ def GET_ITER(self, instr: Instruction):
)

def JUMP_FORWARD(self, instr):
self._lasti = self.indexof(instr.jump_to)
self._lasti = self.indexof(
instr.safe_getattr("jump_to", var_type=Instruction)
)

def JUMP_ABSOLUTE(self, instr: Instruction):
self._lasti = self.indexof(instr.jump_to)
self._lasti = self.indexof(
instr.safe_getattr("jump_to", var_type=Instruction)
)

def CONTAINS_OP(self, instr: Instruction):
# It will only be 0 or 1
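The jump opcodes above (and several more below) now resolve their targets via instr.safe_getattr("jump_to", var_type=Instruction) instead of reading instr.jump_to directly. The helper itself is not shown in this diff; a minimal sketch of what such a method might look like, assuming it only asserts that the attribute exists and has the expected type (the simplified Instruction class here is a stand-in for the real dataclass):

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T")


class Instruction:  # simplified stand-in for the real Instruction dataclass
    def __init__(self, opname: str, jump_to: Instruction | None = None) -> None:
        self.opname = opname
        self.jump_to = jump_to

    def safe_getattr(self, name: str, *, var_type: type[T]) -> T:
        # getattr plus a runtime type check, so callers (and type checkers)
        # can rely on a non-None, correctly typed attribute.
        value = getattr(self, name)
        assert isinstance(
            value, var_type
        ), f"{name} is expected to be {var_type.__name__}, got {type(value).__name__}"
        return value


target = Instruction("JUMP_FORWARD", jump_to=Instruction("NOP"))
assert target.safe_getattr("jump_to", var_type=Instruction).opname == "NOP"
```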
@@ -1264,7 +1276,9 @@ def JUMP_IF_FALSE_OR_POP(self, instr: Instruction):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = not bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
self._lasti = self.indexof(
instr.safe_getattr("jump_to", var_type=Instruction)
)
else:
self.pop()
return
@@ -1279,7 +1293,9 @@ def JUMP_IF_TRUE_OR_POP(self, instr: Instruction):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
self._lasti = self.indexof(
instr.safe_getattr("jump_to", var_type=Instruction)
)
else:
self.pop()
return
@@ -1294,7 +1310,9 @@ def POP_JUMP_IF_FALSE(self, instr: Instruction):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = not bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
self._lasti = self.indexof(
instr.safe_getattr("jump_to", var_type=Instruction)
)
return
raise NotImplementException(
"Currently don't support predicate a non-const / non-tensor obj."
@@ -1307,7 +1325,9 @@ def POP_JUMP_IF_TRUE(self, instr: Instruction):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
self._lasti = self.indexof(
instr.safe_getattr("jump_to", var_type=Instruction)
)
return
raise NotImplementException(
"Currently don't support predicate a non-const / non-tensor obj."
@@ -1389,7 +1409,8 @@ def DICT_UPDATE(self, instr: Instruction):
)

def DICT_MERGE(self, instr: Instruction):
dict_value = self.pop()
# TODO: self._stack[index] should be replaced?
dict_value = self.pop(var_type=DictVariable)
assert instr.arg > 0
for key in dict_value.get_wrapped_items().keys():
result = self._stack[-instr.arg].get_wrapped_items().get(key, None)
@@ -1416,7 +1437,10 @@ def LIST_EXTEND(self, instr: Instruction):
)

def LIST_TO_TUPLE(self, instr: Instruction):
list_value = self.pop()
# TODO(zrr1999): I think list_value should be a ListVariable instance,
# but the return value of ListVariable.get_wrapped_items() is a list instead of a tuple.
# list_value = self.pop(var_type=ListVariable)
list_value = self.pop(var_type=ContainerVariable)
self.push(
TupleVariable(
list_value.get_wrapped_items(),
Comment on lines +1440 to 1446
Member Author

The argument passed to TupleVariable here is usually a list, even though a tuple is actually expected. Should TupleVariable itself be changed to accept any iterable? That seems like it would have no real impact. Alternatively, we could just wrap the value in tuple() here.
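A minimal sketch of the two options mentioned above — relaxing the constructor to accept any iterable, or converting at the call site. The class below is a stand-in, not the real TupleVariable signature:

```python
from __future__ import annotations

from collections.abc import Iterable


class VariableBase:  # stand-in for the real base class
    ...


# Option 1: accept any iterable and normalize inside the constructor,
# so LIST_TO_TUPLE can keep passing the list from get_wrapped_items().
class TupleVariable(VariableBase):
    def __init__(self, value: Iterable) -> None:
        self.value = tuple(value)


# Option 2: keep the constructor strict and wrap at the call site instead:
#     self.push(TupleVariable(tuple(list_value.get_wrapped_items()), ...))

assert TupleVariable([1, 2, 3]).value == (1, 2, 3)
```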

@@ -1517,7 +1541,8 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
self.indexof(instr) + 1, stack_size
)
else_fn, else_inputs = self._create_resume_fn(
self.indexof(instr.jump_to), stack_size
self.indexof(instr.safe_getattr("jump_to", var_type=Instruction)),
stack_size,
)

# gen call static fn opcode
@@ -1718,11 +1743,18 @@ def _break_graph_in_for_loop(

pycode_gen = PyCodeGen(self._frame)
loop_body, loop_inputs = pycode_gen.gen_loop_body_between(
for_iter, loop_body_start_idx, self.indexof(for_iter.jump_to)
for_iter,
loop_body_start_idx,
self.indexof(
for_iter.safe_getattr("jump_to", var_type=Instruction)
),
)

after_loop_fn, fn_inputs = self._create_resume_fn(
self.indexof(for_iter.jump_to), len(self._stack)
self.indexof(
for_iter.safe_getattr("jump_to", var_type=Instruction)
),
len(self._stack),
)

total_inputs = OrderedSet(list(fn_inputs) + list(loop_inputs))
@@ -1822,7 +1854,9 @@ def _inline_call_for_loop(
origin_instrs = get_instructions(pycode_gen._origin_code)

start_idx = self.indexof(for_iter)
end_idx = self.indexof(for_iter.jump_to)
end_idx = self.indexof(
for_iter.safe_getattr("jump_to", var_type=Instruction)
)

inputs = list(
analysis_inputs_outputs(origin_instrs, start_idx, end_idx)
@@ -1850,7 +1884,10 @@ def _inline_call_for_loop(

if (
instr.jump_to in origin_instrs
and origin_instrs.index(instr.jump_to) >= end_idx
and origin_instrs.index(
instr.safe_getattr("jump_to", var_type=Instruction)
)
>= end_idx
):
instr.jump_to = nop_for_break

@@ -1899,19 +1936,19 @@ def STORE_ATTR(self, instr):
)

def FOR_ITER(self, instr):
iterator = self.pop()
iterator = self.pop(var_type=IterVariable)
backup_iter_idx = None

start = self.indexof(instr)
end = self.indexof(instr.jump_to)
end = self.indexof(instr.safe_getattr("jump_to", var_type=Instruction))
for i in range(start, end):
if self._instructions[i].opname == "RETURN_VALUE":
raise NotImplementException(
"Found RETURN_VALUE in for loop body."
)

self._graph.add_global_guarded_variable(iterator)
# TODO need support TensorIterVariable.next
# TODO: need support TensorIterVariable.next

try:
if not isinstance(
@@ -1923,8 +1960,12 @@ def FOR_ITER(self, instr):
backup_iter_idx = iterator.idx

self._inline_call_for_loop(iterator, instr)
self._lasti = self.indexof(instr.jump_to)
self._lasti = self.indexof(
instr.safe_getattr("jump_to", var_type=Instruction)
)
except BreakGraphError as e:
# TODO: backup_iter_idx is not None?
Member Author

I think the case where backup_iter_idx == 0 should also take this branch, so the check should be if backup_iter_idx is not None.
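A quick illustration of the pitfall described above — with a plain truthiness check, a saved index of 0 is silently treated as if there were no backup:

```python
backup_iter_idx = 0  # a legitimate saved position that happens to be falsy

if backup_iter_idx:              # not taken: 0 is falsy
    print("restored via truthiness check")

if backup_iter_idx is not None:  # taken: 0 is a valid index
    print("restored via is-not-None check")
```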

# TODO: idx is not a member of IterVariable
if backup_iter_idx:
iterator.idx = backup_iter_idx
self._graph.remove_global_guarded_variable(iterator)
10 changes: 7 additions & 3 deletions sot/opcode_translator/instruction_utils/instruction_utils.py
@@ -3,12 +3,13 @@
import dataclasses
import dis
import sys
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from .opcode_info import ABS_JUMP, ALL_JUMP, REL_JUMP

if TYPE_CHECKING:
import types
from typing import Any


@dataclasses.dataclass
@@ -290,12 +291,15 @@ def calc_offset_from_bytecode_offset(bytecode_offset: int) -> int:
return bytecode_offset // 2


def replace_instr(instructions, instr, new_instr):
def replace_instr(
instructions: list[Instruction], instr: Instruction, new_instr
):
idx = instructions.index(instr)
# TODO: maybe new_instr is a list?
instructions[idx : idx + 1] = new_instr
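On the TODO above: slice assignment only works when new_instr is itself a list (or another iterable) of instructions — assigning a single, non-iterable Instruction raises TypeError. A quick illustration with plain strings standing in for Instruction objects:

```python
instructions = ["LOAD_FAST", "CALL_FUNCTION", "RETURN_VALUE"]

# Assigning a list to the one-element slice splices the new items in place.
instructions[1:2] = ["LOAD_CONST", "BINARY_ADD"]
assert instructions == ["LOAD_FAST", "LOAD_CONST", "BINARY_ADD", "RETURN_VALUE"]

# Assigning a non-iterable fails, which suggests new_instr should be
# annotated as list[Instruction] rather than a single Instruction.
try:
    instructions[1:2] = 42
except TypeError as exc:
    print(exc)  # can only assign an iterable
```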


def instrs_info(instrs, mark=None, range=None):
def instrs_info(instrs: dict[str, Instruction], mark=None, range=None):
ret = []
start = -1
end = 1000000
5 changes: 3 additions & 2 deletions sot/utils/utils.py
@@ -239,8 +239,8 @@ def __init__(self):
def clear(self):
self.graph_num = 0
self.op_num = 0
self.graphs: list = []
self.ops: list = []
self.graphs = []
self.ops = []

def get_graph_num(self):
return self.graph_num
@@ -257,6 +257,7 @@ def add_subgraph(self, program: Program):
for op in block.ops:
self.op_num += 1
sub_op.append(op)
# TODO: self.ops is a list, and sub_op is a list?
Member Author

The logic here seems a bit off. I'm not sure whether this should be changed to self.ops.extend(sub_op), or whether ops should really be a list[list[paddle.fluid.framework.Operator]].

Member

It should be list[list[paddle.fluid.framework.Operator]]; I hadn't read this code carefully before and added the wrong annotation 😂. I need to tweak this logic a bit soon anyway, so a follow-up PR will fix it along the way.
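Following the reply above, a minimal sketch of the agreed-upon annotation — the class name is a stand-in for the real statistics helper in sot/utils/utils.py, and the paddle import is used for type checking only:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from paddle.fluid.framework import Operator


class SubGraphStats:  # stand-in name for the real helper class
    def __init__(self) -> None:
        # One inner list per sub-graph, matching `self.ops.append(sub_op)` above.
        self.ops: list[list[Operator]] = []

    def add_subgraph_ops(self, sub_op: list[Operator]) -> None:
        self.ops.append(sub_op)
```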

self.ops.append(sub_op)

def add_subgprah_info(self, strs):