diff --git a/vyper/venom/analysis/cfg.py b/vyper/venom/analysis/cfg.py index 2a521ab131..6bd7e538e9 100644 --- a/vyper/venom/analysis/cfg.py +++ b/vyper/venom/analysis/cfg.py @@ -10,12 +10,12 @@ class CFGAnalysis(IRAnalysis): def analyze(self) -> None: fn = self.function - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): bb.cfg_in = OrderedSet() bb.cfg_out = OrderedSet() bb.out_vars = OrderedSet() - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): assert len(bb.instructions) > 0, "Basic block should not be empty" last_inst = bb.instructions[-1] assert ( @@ -29,7 +29,7 @@ def analyze(self) -> None: fn.get_basic_block(op.value).add_cfg_in(bb) # Fill in the "out" set for each basic block - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): for in_bb in bb.cfg_in: in_bb.add_cfg_out(bb) diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py index 8b113e74bc..dc7076d5de 100644 --- a/vyper/venom/analysis/dfg.py +++ b/vyper/venom/analysis/dfg.py @@ -33,7 +33,7 @@ def analyze(self): # %16 = iszero %15 # dfg_outputs of %15 is (%15 = add %13 %14) # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): for inst in bb.instructions: operands = inst.get_inputs() res = inst.get_outputs() diff --git a/vyper/venom/analysis/dominators.py b/vyper/venom/analysis/dominators.py index c0b149d880..129d1d0f22 100644 --- a/vyper/venom/analysis/dominators.py +++ b/vyper/venom/analysis/dominators.py @@ -153,7 +153,7 @@ def as_graph(self) -> str: Generate a graphviz representation of the dominator tree. """ lines = ["digraph dominator_tree {"] - for bb in self.fn.basic_blocks: + for bb in self.fn.get_basic_blocks(): if bb == self.entry_block: continue idom = self.immediate_dominator(bb) diff --git a/vyper/venom/analysis/dup_requirements.py b/vyper/venom/analysis/dup_requirements.py index 015c7c5871..3452bc2e0f 100644 --- a/vyper/venom/analysis/dup_requirements.py +++ b/vyper/venom/analysis/dup_requirements.py @@ -4,7 +4,7 @@ class DupRequirementsAnalysis(IRAnalysis): def analyze(self): - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): last_liveness = bb.out_vars for inst in reversed(bb.instructions): inst.dup_requirements = OrderedSet() diff --git a/vyper/venom/analysis/liveness.py b/vyper/venom/analysis/liveness.py index 95853e57aa..5e78aa4ff3 100644 --- a/vyper/venom/analysis/liveness.py +++ b/vyper/venom/analysis/liveness.py @@ -15,7 +15,7 @@ def analyze(self): self._reset_liveness() while True: changed = False - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): changed |= self._calculate_out_vars(bb) changed |= self._calculate_liveness(bb) @@ -23,7 +23,7 @@ def analyze(self): break def _reset_liveness(self) -> None: - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): bb.out_vars = OrderedSet() for inst in bb.instructions: inst.liveness = OrderedSet() diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 556be28246..eace17af0d 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -13,57 +13,42 @@ class IRFunction: name: IRLabel # symbol name ctx: "IRContext" # type: ignore # noqa: F821 args: list - basic_blocks: list[IRBasicBlock] last_label: int last_variable: int + _basic_block_dict: dict[str, IRBasicBlock] # Used during code generation _ast_source_stack: list[IRnode] _error_msg_stack: list[str] - _bb_index: dict[str, int] def __init__(self, name: IRLabel, ctx: "IRContext" = None) -> None: # type: ignore # noqa: F821 self.ctx = ctx self.name = name self.args = [] - self.basic_blocks = [] + self._basic_block_dict = {} self.last_variable = 0 self._ast_source_stack = [] self._error_msg_stack = [] - self._bb_index = {} self.append_basic_block(IRBasicBlock(name, self)) @property def entry(self) -> IRBasicBlock: - return self.basic_blocks[0] + return next(self.get_basic_blocks()) - def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: + def append_basic_block(self, bb: IRBasicBlock): """ Append basic block to function. """ - assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" - self.basic_blocks.append(bb) - - return self.basic_blocks[-1] - - def _get_basicblock_index(self, label: str): - # perf: keep an "index" of labels to block indices to - # perform fast lookup. - # TODO: maybe better just to throw basic blocks in an ordered - # dict of some kind. - ix = self._bb_index.get(label, -1) - if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label: - return ix - # do a reindex - self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks)) - # sanity check - no duplicate labels - assert len(self._bb_index) == len( - self.basic_blocks - ), f"Duplicate labels in function '{self.name}' {self._bb_index} {self.basic_blocks}" - return self._bb_index[label] + assert isinstance(bb, IRBasicBlock), bb + assert bb.label.name not in self._basic_block_dict + self._basic_block_dict[bb.label.name] = bb + + def remove_basic_block(self, bb: IRBasicBlock): + assert isinstance(bb, IRBasicBlock), bb + del self._basic_block_dict[bb.label.name] def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ @@ -71,33 +56,31 @@ def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: If label is None, return the last basic block. """ if label is None: - return self.basic_blocks[-1] - ix = self._get_basicblock_index(label) - return self.basic_blocks[ix] + return next(reversed(self._basic_block_dict.values())) + + return self._basic_block_dict[label] + + def clear_basic_blocks(self): + self._basic_block_dict.clear() - def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: + def get_basic_blocks(self) -> Iterator[IRBasicBlock]: """ - Get basic block after label. + Get an iterator over this function's basic blocks """ - ix = self._get_basicblock_index(label.value) - if 0 <= ix < len(self.basic_blocks) - 1: - return self.basic_blocks[ix + 1] - raise AssertionError(f"Basic block after '{label}' not found") + return iter(self._basic_block_dict.values()) + + @property + def num_basic_blocks(self) -> int: + return len(self._basic_block_dict) def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]: """ Get basic blocks that are terminal. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): if bb.is_terminal: yield bb - def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: - """ - Get basic blocks that point to the given basic block - """ - return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] - def get_next_variable(self) -> IRVariable: self.last_variable += 1 return IRVariable(f"%{self.last_variable}") @@ -109,15 +92,14 @@ def remove_unreachable_blocks(self) -> int: self._compute_reachability() removed = [] - new_basic_blocks = [] # Remove unreachable basic blocks - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): if not bb.is_reachable: removed.append(bb) - else: - new_basic_blocks.append(bb) - self.basic_blocks = new_basic_blocks + + for bb in removed: + self.remove_basic_block(bb) # Remove phi instructions that reference removed basic blocks for bb in removed: @@ -142,7 +124,7 @@ def _compute_reachability(self) -> None: """ Compute reachability of basic blocks. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): bb.reachable = OrderedSet() bb.is_reachable = False @@ -172,7 +154,7 @@ def normalized(self) -> bool: Having a normalized CFG makes calculation of stack layout easier when emitting assembly. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): # Ignore if there are no multiple predecessors if len(bb.cfg_in) <= 1: continue @@ -211,22 +193,23 @@ def chain_basic_blocks(self) -> None: Otherwise, append a stop instruction. This is necessary for the IR to be valid, and is done after the IR is generated. """ - for i, bb in enumerate(self.basic_blocks): + bbs = list(self.get_basic_blocks()) + for i, bb in enumerate(bbs): if not bb.is_terminated: - if len(self.basic_blocks) - 1 > i: + if i < len(bbs) - 1: # TODO: revisit this. When contructor calls internal functions they # are linked to the last ctor block. Should separate them before this # so we don't have to handle this here - if self.basic_blocks[i + 1].label.value.startswith("internal"): + if bbs[i + 1].label.value.startswith("internal"): bb.append_instruction("stop") else: - bb.append_instruction("jmp", self.basic_blocks[i + 1].label) + bb.append_instruction("jmp", bbs[i + 1].label) else: bb.append_instruction("exit") def copy(self): new = IRFunction(self.name) - new.basic_blocks = self.basic_blocks.copy() + new._basic_block_dict = self._basic_block_dict.copy() new.last_label = self.last_label new.last_variable = self.last_variable return new @@ -246,11 +229,11 @@ def _make_label(bb): ret = "digraph G {\n" - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): for out_bb in bb.cfg_out: ret += f' "{bb.label.value}" -> "{out_bb.label.value}"\n' - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): ret += f' "{bb.label.value}" [shape=plaintext, ' ret += f'label={_make_label(bb)}, fontname="Courier" fontsize="8"]\n' @@ -259,6 +242,6 @@ def _make_label(bb): def __repr__(self) -> str: str = f"IRFunction: {self.name}\n" - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): str += f"{bb}\n" return str.strip() diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index b4465e9f7b..61b3c081ff 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -135,10 +135,9 @@ def _append_jmp(fn: IRFunction, label: IRLabel) -> None: bb.append_instruction("jmp", label) -def _new_block(fn: IRFunction) -> IRBasicBlock: +def _new_block(fn: IRFunction) -> None: bb = IRBasicBlock(fn.ctx.get_next_label(), fn) - bb = fn.append_basic_block(bb) - return bb + fn.append_basic_block(bb) def _append_return_args(fn: IRFunction, ofst: int = 0, size: int = 0): @@ -328,7 +327,7 @@ def _convert_ir_bb(fn, ir, symbols): # exit bb exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), fn) - exit_bb = fn.append_basic_block(exit_bb) + fn.append_basic_block(exit_bb) if_ret = fn.get_next_variable() if then_ret_val is not None and else_ret_val is not None: diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index e4e27ed813..06366e4336 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -74,8 +74,8 @@ def run_pass(self) -> None: self.fence_id = 0 self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() - basic_blocks = self.function.basic_blocks + basic_blocks = list(self.function.get_basic_blocks()) - self.function.basic_blocks = [] + self.function.clear_basic_blocks() for bb in basic_blocks: self._process_basic_block(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 83c565b1be..cf44c3cf89 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -58,7 +58,7 @@ def _run_pass(self) -> int: self.analyses_cache.request_analysis(CFGAnalysis) # Split blocks that need splitting - for bb in fn.basic_blocks: + for bb in list(fn.get_basic_blocks()): if len(bb.cfg_in) > 1: self._split_basic_block(bb) @@ -71,7 +71,7 @@ def _run_pass(self) -> int: def run_pass(self): fn = self.function - for _ in range(len(fn.basic_blocks) * 2): + for _ in range(fn.num_basic_blocks * 2): if self._run_pass() == 0: break else: diff --git a/vyper/venom/passes/remove_unused_variables.py b/vyper/venom/passes/remove_unused_variables.py index b7fb3abbf0..a4cd737e98 100644 --- a/vyper/venom/passes/remove_unused_variables.py +++ b/vyper/venom/passes/remove_unused_variables.py @@ -9,7 +9,7 @@ def run_pass(self): self.analyses_cache.request_analysis(LivenessAnalysis) - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): for i, inst in enumerate(bb.instructions[:-1]): if inst.volatile: continue diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 7f3fc7e03e..577030dea6 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -87,7 +87,7 @@ def _calculate_sccp(self, entry: IRBasicBlock): and the work list. The `_propagate_constants()` method is responsible for updating the IR with the constant values. """ - self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.basic_blocks} + self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.get_basic_blocks()} dummy = IRBasicBlock(IRLabel("__dummy_start"), self.fn) self.work_list.append(FlowWorkItem(dummy, entry)) diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index bb5233eba0..08582fee96 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -30,7 +30,7 @@ def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): break inst.operands[inst.operands.index(b.label)] = a.label - self.function.basic_blocks.remove(b) + self.function.remove_basic_block(b) def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb = b.cfg_out.first() @@ -44,7 +44,7 @@ def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) - self.function.basic_blocks.remove(b) + self.function.remove_basic_block(b) def _collapse_chained_blocks_r(self, bb: IRBasicBlock): """ @@ -87,31 +87,32 @@ def _optimize_empty_basicblocks(self) -> int: Remove empty basic blocks. """ fn = self.function - count = 0 - i = 0 - while i < len(fn.basic_blocks): - bb = fn.basic_blocks[i] + worklist = list(fn.get_basic_blocks()) + i = count = 0 + while i < len(worklist): + bb = worklist[i] i += 1 + if len(bb.instructions) > 0: continue + next_bb = worklist[i] + replaced_label = bb.label - replacement_label = fn.basic_blocks[i].label if i < len(fn.basic_blocks) else None - if replacement_label is None: - continue + replacement_label = next_bb.label # Try to preserve symbol labels if replaced_label.is_symbol: replaced_label, replacement_label = replacement_label, replaced_label - fn.basic_blocks[i].label = replacement_label + next_bb.label = replacement_label - for bb2 in fn.basic_blocks: + for bb2 in fn.get_basic_blocks(): for inst in bb2.instructions: for op in inst.operands: if isinstance(op, IRLabel) and op.value == replaced_label.value: op.value = replacement_label.value - fn.basic_blocks.remove(bb) + fn.remove_basic_block(bb) i -= 1 count += 1 @@ -121,7 +122,7 @@ def run_pass(self): fn = self.function entry = fn.entry - for _ in range(len(fn.basic_blocks)): + for _ in range(fn.num_basic_blocks): changes = self._optimize_empty_basicblocks() changes += fn.remove_unreachable_blocks() if changes == 0: @@ -131,7 +132,7 @@ def run_pass(self): self.analyses_cache.force_analysis(CFGAnalysis) - for _ in range(len(fn.basic_blocks)): # essentially `while True` + for _ in range(fn.num_basic_blocks): # essentially `while True` self._collapse_chained_blocks(entry) if fn.remove_unreachable_blocks() == 0: break