From 35996f1bb1968960e8f284a852ccf6f3742a2b1e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 8 May 2024 11:42:25 -0400 Subject: [PATCH] fix[venom]: fix branch eliminator cases in sccp (#4003) in sccp, when the operand of `jnz`, `djmp` or `assert` is already an `IRLiteral` (this is most easily seen by disabling the IRnode branch eliminator), the compiler will panic. this commit fixes the bug, and refactors some commonly used code into helper functions. --------- Co-authored-by: Harry Kalogirou --- tests/unit/compiler/venom/test_sccp.py | 72 +++++++++++++++++++++++++- vyper/venom/passes/sccp/sccp.py | 54 +++++++++++++------ 2 files changed, 108 insertions(+), 18 deletions(-) diff --git a/tests/unit/compiler/venom/test_sccp.py b/tests/unit/compiler/venom/test_sccp.py index 37a8bc9000..e65839136e 100644 --- a/tests/unit/compiler/venom/test_sccp.py +++ b/tests/unit/compiler/venom/test_sccp.py @@ -1,5 +1,8 @@ +import pytest + +from vyper.exceptions import StaticAssertionException from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral, IRVariable from vyper.venom.context import IRContext from vyper.venom.passes.make_ssa import MakeSSA from vyper.venom.passes.sccp import SCCP @@ -28,6 +31,73 @@ def test_simple_case(): assert sccp.lattice[IRVariable("%4")].value == 96 +def test_branch_eliminator_simple(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb1 = fn.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), fn) + br1.append_instruction("stop") + br2 = IRBasicBlock(IRLabel("else"), fn) + br2.append_instruction("jmp", IRLabel("foo")) + + fn.append_basic_block(br1) + fn.append_basic_block(br2) + + bb1.append_instruction("jnz", IRLiteral(1), br1.label, br2.label) + + bb2 = IRBasicBlock(IRLabel("foo"), fn) + bb2.append_instruction("jnz", IRLiteral(0), br1.label, br2.label) + fn.append_basic_block(bb2) + + 
ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + sccp = SCCP(ac, fn) + sccp.run_pass() + + assert bb1.instructions[-1].opcode == "jmp" + assert bb1.instructions[-1].operands == [br1.label] + assert bb2.instructions[-1].opcode == "jmp" + assert bb2.instructions[-1].operands == [br2.label] + + +def test_assert_elimination(): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + bb.append_instruction("assert", IRLiteral(1)) + bb.append_instruction("assert_unreachable", IRLiteral(1)) + bb.append_instruction("stop") + + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + sccp = SCCP(ac, fn) + sccp.run_pass() + + for inst in bb.instructions[:-1]: + assert inst.opcode == "nop" + + +@pytest.mark.parametrize("asserter", ("assert", "assert_unreachable")) +def test_assert_false(asserter): + ctx = IRContext() + fn = ctx.create_function("_global") + + bb = fn.get_basic_block() + + bb.append_instruction(asserter, IRLiteral(0)) + bb.append_instruction("stop") + + ac = IRAnalysesCache(fn) + MakeSSA(ac, fn).run_pass() + sccp = SCCP(ac, fn) + with pytest.raises(StaticAssertionException): + sccp.run_pass() + + def test_cont_jump_case(): ctx = IRContext() fn = ctx.create_function("_global") diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py index 577030dea6..ced1f711c5 100644 --- a/vyper/venom/passes/sccp/sccp.py +++ b/vyper/venom/passes/sccp/sccp.py @@ -39,7 +39,7 @@ class FlowWorkItem: WorkListItem = Union[FlowWorkItem, SSAWorkListItem] LatticeItem = Union[LatticeEnum, IRLiteral] -Lattice = dict[IROperand, LatticeItem] +Lattice = dict[IRVariable, LatticeItem] class SCCP(IRPass): @@ -143,6 +143,22 @@ def _handle_SSA_work_item(self, work_item: SSAWorkListItem): elif len(self.cfg_in_exec[work_item.inst.parent]) > 0: self._visit_expr(work_item.inst) + def _lookup_from_lattice(self, op: IROperand) -> LatticeItem: + assert isinstance(op, IRVariable), "Can't get lattice for non-variable" + lat = self.lattice[op] + 
assert lat is not None, f"Got undefined var {op}" + return lat + + def _set_lattice(self, op: IROperand, value: LatticeItem): + assert isinstance(op, IRVariable), "Can't set lattice for non-variable" + self.lattice[op] = value + + def _eval_from_lattice(self, op: IROperand) -> IRLiteral | LatticeEnum: + if isinstance(op, IRLiteral): + return op + + return self._lookup_from_lattice(op) + def _visit_phi(self, inst: IRInstruction): assert inst.opcode == "phi", "Can't visit non phi instruction" in_vars: list[LatticeItem] = [] @@ -150,26 +166,27 @@ def _visit_phi(self, inst: IRInstruction): bb = self.fn.get_basic_block(bb_label.name) if bb not in self.cfg_in_exec[inst.parent]: continue - in_vars.append(self.lattice[var]) + in_vars.append(self._lookup_from_lattice(var)) value = reduce(_meet, in_vars, LatticeEnum.TOP) # type: ignore assert inst.output in self.lattice, "Got undefined var for phi" - if value != self.lattice[inst.output]: - self.lattice[inst.output] = value + + if value != self._lookup_from_lattice(inst.output): + self._set_lattice(inst.output, value) self._add_ssa_work_items(inst) def _visit_expr(self, inst: IRInstruction): opcode = inst.opcode if opcode in ["store", "alloca"]: - if isinstance(inst.operands[0], IRLiteral): - self.lattice[inst.output] = inst.operands[0] # type: ignore - else: - self.lattice[inst.output] = self.lattice[inst.operands[0]] # type: ignore + assert inst.output is not None, "Got store/alloca without output" + out = self._eval_from_lattice(inst.operands[0]) + self._set_lattice(inst.output, out) self._add_ssa_work_items(inst) elif opcode == "jmp": target = self.fn.get_basic_block(inst.operands[0].value) self.work_list.append(FlowWorkItem(inst.parent, target)) elif opcode == "jnz": - lat = self.lattice[inst.operands[0]] + lat = self._eval_from_lattice(inst.operands[0]) + assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}" if lat == LatticeEnum.BOTTOM: for out_bb in inst.parent.cfg_out: @@ -182,7 +199,7 @@ def 
_visit_expr(self, inst: IRInstruction): target = self.fn.get_basic_block(inst.operands[2].name) self.work_list.append(FlowWorkItem(inst.parent, target)) elif opcode == "djmp": - lat = self.lattice[inst.operands[0]] + lat = self._eval_from_lattice(inst.operands[0]) assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}" if lat == LatticeEnum.BOTTOM: for op in inst.operands[1:]: @@ -200,7 +217,7 @@ def _visit_expr(self, inst: IRInstruction): self._eval(inst) else: if inst.output is not None: - self.lattice[inst.output] = LatticeEnum.BOTTOM + self._set_lattice(inst.output, LatticeEnum.BOTTOM) def _eval(self, inst) -> LatticeItem: """ @@ -267,16 +284,17 @@ def _propagate_constants(self): """ for bb in self.dom.dfs_walk: for inst in bb.instructions: - self._replace_constants(inst, self.lattice) + self._replace_constants(inst) - def _replace_constants(self, inst: IRInstruction, lattice: Lattice): + def _replace_constants(self, inst: IRInstruction): """ This method replaces constant values in the instruction with their actual values. It also updates the instruction opcode in case of jumps and asserts as needed. """ if inst.opcode == "jnz": - lat = lattice[inst.operands[0]] + lat = self._eval_from_lattice(inst.operands[0]) + if isinstance(lat, IRLiteral): if lat.value == 0: target = inst.operands[2] @@ -285,8 +303,10 @@ def _replace_constants(self, inst: IRInstruction, lattice: Lattice): inst.opcode = "jmp" inst.operands = [target] self.cfg_dirty = True - elif inst.opcode == "assert": - lat = lattice[inst.operands[0]] + + elif inst.opcode in ("assert", "assert_unreachable"): + lat = self._eval_from_lattice(inst.operands[0]) + if isinstance(lat, IRLiteral): if lat.value > 0: inst.opcode = "nop" @@ -303,7 +323,7 @@ def _replace_constants(self, inst: IRInstruction, lattice: Lattice): for i, op in enumerate(inst.operands): if isinstance(op, IRVariable): - lat = lattice[op] + lat = self.lattice[op] if isinstance(lat, IRLiteral): inst.operands[i] = lat