From 93147be821189a42a4157c6aca7b54810fbee519 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 8 May 2024 10:42:20 -0400 Subject: [PATCH] feat[venom]: optimize `get_basic_block()` (#4002) `get_basic_block()` is a hotspot in venom (up to 35% of total compilation time!). this optimizes `get_basic_block()`, on a large contract near the 24kb limit this reduces time spent in venom from 3s to 1s (total time from 6s to 4s). note on the same contract, time spent in the IRnode optimizer pipeline is 2s - so time in venom is now smaller than time in legacy optimizer(!) notes: - refactor to use dict for basic_blocks - clean up basic blocks API hide basic blocks behind `get_basic_blocks()` iterator and `num_basic_blocks`. --- vyper/venom/analysis/cfg.py | 6 +- vyper/venom/analysis/dfg.py | 2 +- vyper/venom/analysis/dominators.py | 2 +- vyper/venom/analysis/dup_requirements.py | 2 +- vyper/venom/analysis/liveness.py | 4 +- vyper/venom/function.py | 97 ++++++++----------- vyper/venom/ir_node_to_venom.py | 7 +- vyper/venom/passes/dft.py | 4 +- vyper/venom/passes/normalization.py | 4 +- vyper/venom/passes/remove_unused_variables.py | 2 +- vyper/venom/passes/sccp/sccp.py | 2 +- vyper/venom/passes/simplify_cfg.py | 29 +++--- 12 files changed, 72 insertions(+), 89 deletions(-) diff --git a/vyper/venom/analysis/cfg.py b/vyper/venom/analysis/cfg.py index 2a521ab131..6bd7e538e9 100644 --- a/vyper/venom/analysis/cfg.py +++ b/vyper/venom/analysis/cfg.py @@ -10,12 +10,12 @@ class CFGAnalysis(IRAnalysis): def analyze(self) -> None: fn = self.function - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): bb.cfg_in = OrderedSet() bb.cfg_out = OrderedSet() bb.out_vars = OrderedSet() - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): assert len(bb.instructions) > 0, "Basic block should not be empty" last_inst = bb.instructions[-1] assert ( @@ -29,7 +29,7 @@ def analyze(self) -> None: fn.get_basic_block(op.value).add_cfg_in(bb) # Fill in the "out" set for each basic block - for bb in fn.basic_blocks: + for bb in fn.get_basic_blocks(): for in_bb in bb.cfg_in: in_bb.add_cfg_out(bb) diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py index 8b113e74bc..dc7076d5de 100644 --- a/vyper/venom/analysis/dfg.py +++ b/vyper/venom/analysis/dfg.py @@ -33,7 +33,7 @@ def analyze(self): # %16 = iszero %15 # dfg_outputs of %15 is (%15 = add %13 %14) # dfg_inputs of %15 is all the instructions which *use* %15, ex. [(%16 = iszero %15), ...] - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): for inst in bb.instructions: operands = inst.get_inputs() res = inst.get_outputs() diff --git a/vyper/venom/analysis/dominators.py b/vyper/venom/analysis/dominators.py index c0b149d880..129d1d0f22 100644 --- a/vyper/venom/analysis/dominators.py +++ b/vyper/venom/analysis/dominators.py @@ -153,7 +153,7 @@ def as_graph(self) -> str: Generate a graphviz representation of the dominator tree. """ lines = ["digraph dominator_tree {"] - for bb in self.fn.basic_blocks: + for bb in self.fn.get_basic_blocks(): if bb == self.entry_block: continue idom = self.immediate_dominator(bb) diff --git a/vyper/venom/analysis/dup_requirements.py b/vyper/venom/analysis/dup_requirements.py index 015c7c5871..3452bc2e0f 100644 --- a/vyper/venom/analysis/dup_requirements.py +++ b/vyper/venom/analysis/dup_requirements.py @@ -4,7 +4,7 @@ class DupRequirementsAnalysis(IRAnalysis): def analyze(self): - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): last_liveness = bb.out_vars for inst in reversed(bb.instructions): inst.dup_requirements = OrderedSet() diff --git a/vyper/venom/analysis/liveness.py b/vyper/venom/analysis/liveness.py index 95853e57aa..5e78aa4ff3 100644 --- a/vyper/venom/analysis/liveness.py +++ b/vyper/venom/analysis/liveness.py @@ -15,7 +15,7 @@ def analyze(self): self._reset_liveness() while True: changed = False - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): changed |= self._calculate_out_vars(bb) changed |= self._calculate_liveness(bb) @@ -23,7 +23,7 @@ def analyze(self): break def _reset_liveness(self) -> None: - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): bb.out_vars = OrderedSet() for inst in bb.instructions: inst.liveness = OrderedSet() diff --git a/vyper/venom/function.py b/vyper/venom/function.py index 556be28246..eace17af0d 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -13,57 +13,42 @@ class IRFunction: name: IRLabel # symbol name ctx: "IRContext" # type: ignore # noqa: F821 args: list - basic_blocks: list[IRBasicBlock] last_label: int last_variable: int + _basic_block_dict: dict[str, IRBasicBlock] # Used during code generation _ast_source_stack: list[IRnode] _error_msg_stack: list[str] - _bb_index: dict[str, int] def __init__(self, name: IRLabel, ctx: "IRContext" = None) -> None: # type: ignore # noqa: F821 self.ctx = ctx self.name = name self.args = [] - self.basic_blocks = [] + self._basic_block_dict = {} self.last_variable = 0 self._ast_source_stack = [] self._error_msg_stack = [] - self._bb_index = {} self.append_basic_block(IRBasicBlock(name, self)) @property def entry(self) -> IRBasicBlock: - return self.basic_blocks[0] + return next(self.get_basic_blocks()) - def append_basic_block(self, bb: IRBasicBlock) -> IRBasicBlock: + def append_basic_block(self, bb: IRBasicBlock): """ Append basic block to function. """ - assert isinstance(bb, IRBasicBlock), f"append_basic_block takes IRBasicBlock, got '{bb}'" - self.basic_blocks.append(bb) - - return self.basic_blocks[-1] - - def _get_basicblock_index(self, label: str): - # perf: keep an "index" of labels to block indices to - # perform fast lookup. - # TODO: maybe better just to throw basic blocks in an ordered - # dict of some kind. - ix = self._bb_index.get(label, -1) - if 0 <= ix < len(self.basic_blocks) and self.basic_blocks[ix].label == label: - return ix - # do a reindex - self._bb_index = dict((bb.label.name, ix) for ix, bb in enumerate(self.basic_blocks)) - # sanity check - no duplicate labels - assert len(self._bb_index) == len( - self.basic_blocks - ), f"Duplicate labels in function '{self.name}' {self._bb_index} {self.basic_blocks}" - return self._bb_index[label] + assert isinstance(bb, IRBasicBlock), bb + assert bb.label.name not in self._basic_block_dict + self._basic_block_dict[bb.label.name] = bb + + def remove_basic_block(self, bb: IRBasicBlock): + assert isinstance(bb, IRBasicBlock), bb + del self._basic_block_dict[bb.label.name] def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: """ @@ -71,33 +56,31 @@ def get_basic_block(self, label: Optional[str] = None) -> IRBasicBlock: If label is None, return the last basic block. """ if label is None: - return self.basic_blocks[-1] - ix = self._get_basicblock_index(label) - return self.basic_blocks[ix] + return next(reversed(self._basic_block_dict.values())) + + return self._basic_block_dict[label] + + def clear_basic_blocks(self): + self._basic_block_dict.clear() - def get_basic_block_after(self, label: IRLabel) -> IRBasicBlock: + def get_basic_blocks(self) -> Iterator[IRBasicBlock]: """ - Get basic block after label. + Get an iterator over this function's basic blocks """ - ix = self._get_basicblock_index(label.value) - if 0 <= ix < len(self.basic_blocks) - 1: - return self.basic_blocks[ix + 1] - raise AssertionError(f"Basic block after '{label}' not found") + return iter(self._basic_block_dict.values()) + + @property + def num_basic_blocks(self) -> int: + return len(self._basic_block_dict) def get_terminal_basicblocks(self) -> Iterator[IRBasicBlock]: """ Get basic blocks that are terminal. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): if bb.is_terminal: yield bb - def get_basicblocks_in(self, basic_block: IRBasicBlock) -> list[IRBasicBlock]: - """ - Get basic blocks that point to the given basic block - """ - return [bb for bb in self.basic_blocks if basic_block.label in bb.cfg_in] - def get_next_variable(self) -> IRVariable: self.last_variable += 1 return IRVariable(f"%{self.last_variable}") @@ -109,15 +92,14 @@ def remove_unreachable_blocks(self) -> int: self._compute_reachability() removed = [] - new_basic_blocks = [] # Remove unreachable basic blocks - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): if not bb.is_reachable: removed.append(bb) - else: - new_basic_blocks.append(bb) - self.basic_blocks = new_basic_blocks + + for bb in removed: + self.remove_basic_block(bb) # Remove phi instructions that reference removed basic blocks for bb in removed: @@ -142,7 +124,7 @@ def _compute_reachability(self) -> None: """ Compute reachability of basic blocks. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): bb.reachable = OrderedSet() bb.is_reachable = False @@ -172,7 +154,7 @@ def normalized(self) -> bool: Having a normalized CFG makes calculation of stack layout easier when emitting assembly. """ - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): # Ignore if there are no multiple predecessors if len(bb.cfg_in) <= 1: continue @@ -211,22 +193,23 @@ def chain_basic_blocks(self) -> None: Otherwise, append a stop instruction. This is necessary for the IR to be valid, and is done after the IR is generated. """ - for i, bb in enumerate(self.basic_blocks): + bbs = list(self.get_basic_blocks()) + for i, bb in enumerate(bbs): if not bb.is_terminated: - if len(self.basic_blocks) - 1 > i: + if i < len(bbs) - 1: # TODO: revisit this. When contructor calls internal functions they # are linked to the last ctor block. Should separate them before this # so we don't have to handle this here - if self.basic_blocks[i + 1].label.value.startswith("internal"): + if bbs[i + 1].label.value.startswith("internal"): bb.append_instruction("stop") else: - bb.append_instruction("jmp", self.basic_blocks[i + 1].label) + bb.append_instruction("jmp", bbs[i + 1].label) else: bb.append_instruction("exit") def copy(self): new = IRFunction(self.name) - new.basic_blocks = self.basic_blocks.copy() + new._basic_block_dict = self._basic_block_dict.copy() new.last_label = self.last_label new.last_variable = self.last_variable return new @@ -246,11 +229,11 @@ def _make_label(bb): ret = "digraph G {\n" - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): for out_bb in bb.cfg_out: ret += f' "{bb.label.value}" -> "{out_bb.label.value}"\n' - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): ret += f' "{bb.label.value}" [shape=plaintext, ' ret += f'label={_make_label(bb)}, fontname="Courier" fontsize="8"]\n' @@ -259,6 +242,6 @@ def _make_label(bb): def __repr__(self) -> str: str = f"IRFunction: {self.name}\n" - for bb in self.basic_blocks: + for bb in self.get_basic_blocks(): str += f"{bb}\n" return str.strip() diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index b4465e9f7b..61b3c081ff 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -135,10 +135,9 @@ def _append_jmp(fn: IRFunction, label: IRLabel) -> None: bb.append_instruction("jmp", label) -def _new_block(fn: IRFunction) -> IRBasicBlock: +def _new_block(fn: IRFunction) -> None: bb = IRBasicBlock(fn.ctx.get_next_label(), fn) - bb = fn.append_basic_block(bb) - return bb + fn.append_basic_block(bb) def _append_return_args(fn: IRFunction, ofst: int = 0, size: int = 0): @@ -328,7 +327,7 @@ def _convert_ir_bb(fn, ir, symbols): # exit bb exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), fn) - exit_bb = fn.append_basic_block(exit_bb) + fn.append_basic_block(exit_bb) if_ret = fn.get_next_variable() if then_ret_val is not None and else_ret_val is not None: diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index e4e27ed813..06366e4336 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -74,8 +74,8 @@ def run_pass(self) -> None: self.fence_id = 0 self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() - basic_blocks = self.function.basic_blocks + basic_blocks = list(self.function.get_basic_blocks()) - self.function.basic_blocks = [] + self.function.clear_basic_blocks() for bb in basic_blocks: self._process_basic_block(bb) diff --git a/vyper/venom/passes/normalization.py b/vyper/venom/passes/normalization.py index 83c565b1be..cf44c3cf89 100644 --- a/vyper/venom/passes/normalization.py +++ b/vyper/venom/passes/normalization.py @@ -58,7 +58,7 @@ def _run_pass(self) -> int: self.analyses_cache.request_analysis(CFGAnalysis) # Split blocks that need splitting - for bb in fn.basic_blocks: + for bb in list(fn.get_basic_blocks()): if len(bb.cfg_in) > 1: self._split_basic_block(bb) @@ -71,7 +71,7 @@ def _run_pass(self) -> int: def run_pass(self): fn = self.function - for _ in range(len(fn.basic_blocks) * 2): + for _ in range(fn.num_basic_blocks * 2): if self._run_pass() == 0: break else: diff --git a/vyper/venom/passes/remove_unused_variables.py b/vyper/venom/passes/remove_unused_variables.py index b7fb3abbf0..a4cd737e98 100644 --- a/vyper/venom/passes/remove_unused_variables.py +++ b/vyper/venom/passes/remove_unused_variables.py @@ -9,7 +9,7 @@ def run_pass(self): self.analyses_cache.request_analysis(LivenessAnalysis) - for bb in self.function.basic_blocks: + for bb in self.function.get_basic_blocks(): for i, inst in enumerate(bb.instructions[:-1]): if inst.volatile: continue diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 7f3fc7e03e..577030dea6 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -87,7 +87,7 @@ def _calculate_sccp(self, entry: IRBasicBlock): and the work list. The `_propagate_constants()` method is responsible for updating the IR with the constant values. """ - self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.basic_blocks} + self.cfg_in_exec = {bb: OrderedSet() for bb in self.fn.get_basic_blocks()} dummy = IRBasicBlock(IRLabel("__dummy_start"), self.fn) self.work_list.append(FlowWorkItem(dummy, entry)) diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index bb5233eba0..08582fee96 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -30,7 +30,7 @@ def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): break inst.operands[inst.operands.index(b.label)] = a.label - self.function.basic_blocks.remove(b) + self.function.remove_basic_block(b) def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb = b.cfg_out.first() @@ -44,7 +44,7 @@ def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) - self.function.basic_blocks.remove(b) + self.function.remove_basic_block(b) def _collapse_chained_blocks_r(self, bb: IRBasicBlock): """ @@ -87,31 +87,32 @@ def _optimize_empty_basicblocks(self) -> int: Remove empty basic blocks. """ fn = self.function - count = 0 - i = 0 - while i < len(fn.basic_blocks): - bb = fn.basic_blocks[i] + worklist = list(fn.get_basic_blocks()) + i = count = 0 + while i < len(worklist): + bb = worklist[i] i += 1 + if len(bb.instructions) > 0: continue + next_bb = worklist[i] + replaced_label = bb.label - replacement_label = fn.basic_blocks[i].label if i < len(fn.basic_blocks) else None - if replacement_label is None: - continue + replacement_label = next_bb.label # Try to preserve symbol labels if replaced_label.is_symbol: replaced_label, replacement_label = replacement_label, replaced_label - fn.basic_blocks[i].label = replacement_label + next_bb.label = replacement_label - for bb2 in fn.basic_blocks: + for bb2 in fn.get_basic_blocks(): for inst in bb2.instructions: for op in inst.operands: if isinstance(op, IRLabel) and op.value == replaced_label.value: op.value = replacement_label.value - fn.basic_blocks.remove(bb) + fn.remove_basic_block(bb) i -= 1 count += 1 @@ -121,7 +122,7 @@ def run_pass(self): fn = self.function entry = fn.entry - for _ in range(len(fn.basic_blocks)): + for _ in range(fn.num_basic_blocks): changes = self._optimize_empty_basicblocks() changes += fn.remove_unreachable_blocks() if changes == 0: @@ -131,7 +132,7 @@ def run_pass(self): self.analyses_cache.force_analysis(CFGAnalysis) - for _ in range(len(fn.basic_blocks)): # essentially `while True` + for _ in range(fn.num_basic_blocks): # essentially `while True` self._collapse_chained_blocks(entry) if fn.remove_unreachable_blocks() == 0: break