Skip to content

Commit

Permalink
fix[venom]: fix branch eliminator cases in sccp (#4003)
Browse files Browse the repository at this point in the history
in sccp, when the operand of `jnz`, `djmp` or `assert` is already an
`IRLiteral` (this is most easily seen by disabling the IRnode branch
eliminator), the compiler will panic. this commit fixes the bug, and
refactor some commonly used code into helper functions.

---------

Co-authored-by: Harry Kalogirou <[email protected]>
  • Loading branch information
charles-cooper and harkal authored May 8, 2024
1 parent 93147be commit 35996f1
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 18 deletions.
72 changes: 71 additions & 1 deletion tests/unit/compiler/venom/test_sccp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

from vyper.exceptions import StaticAssertionException
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral, IRVariable
from vyper.venom.context import IRContext
from vyper.venom.passes.make_ssa import MakeSSA
from vyper.venom.passes.sccp import SCCP
Expand Down Expand Up @@ -28,6 +31,73 @@ def test_simple_case():
assert sccp.lattice[IRVariable("%4")].value == 96


def test_branch_eliminator_simple():
ctx = IRContext()
fn = ctx.create_function("_global")

bb1 = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
br1.append_instruction("stop")
br2 = IRBasicBlock(IRLabel("else"), fn)
br2.append_instruction("jmp", IRLabel("foo"))

fn.append_basic_block(br1)
fn.append_basic_block(br2)

bb1.append_instruction("jnz", IRLiteral(1), br1.label, br2.label)

bb2 = IRBasicBlock(IRLabel("foo"), fn)
bb2.append_instruction("jnz", IRLiteral(0), br1.label, br2.label)
fn.append_basic_block(bb2)

ac = IRAnalysesCache(fn)
MakeSSA(ac, fn).run_pass()
sccp = SCCP(ac, fn)
sccp.run_pass()

assert bb1.instructions[-1].opcode == "jmp"
assert bb1.instructions[-1].operands == [br1.label]
assert bb2.instructions[-1].opcode == "jmp"
assert bb2.instructions[-1].operands == [br2.label]


def test_assert_elimination():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

bb.append_instruction("assert", IRLiteral(1))
bb.append_instruction("assert_unreachable", IRLiteral(1))
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)
MakeSSA(ac, fn).run_pass()
sccp = SCCP(ac, fn)
sccp.run_pass()

for inst in bb.instructions[:-1]:
assert inst.opcode == "nop"


@pytest.mark.parametrize("asserter", ("assert", "assert_unreachable"))
def test_assert_false(asserter):
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

bb.append_instruction(asserter, IRLiteral(0))
bb.append_instruction("stop")

ac = IRAnalysesCache(fn)
MakeSSA(ac, fn).run_pass()
sccp = SCCP(ac, fn)
with pytest.raises(StaticAssertionException):
sccp.run_pass()


def test_cont_jump_case():
ctx = IRContext()
fn = ctx.create_function("_global")
Expand Down
54 changes: 37 additions & 17 deletions vyper/venom/passes/sccp/sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class FlowWorkItem:

WorkListItem = Union[FlowWorkItem, SSAWorkListItem]
LatticeItem = Union[LatticeEnum, IRLiteral]
Lattice = dict[IROperand, LatticeItem]
Lattice = dict[IRVariable, LatticeItem]


class SCCP(IRPass):
Expand Down Expand Up @@ -143,33 +143,50 @@ def _handle_SSA_work_item(self, work_item: SSAWorkListItem):
elif len(self.cfg_in_exec[work_item.inst.parent]) > 0:
self._visit_expr(work_item.inst)

def _lookup_from_lattice(self, op: IROperand) -> LatticeItem:
assert isinstance(op, IRVariable), "Can't get lattice for non-variable"
lat = self.lattice[op]
assert lat is not None, f"Got undefined var {op}"
return lat

def _set_lattice(self, op: IROperand, value: LatticeItem):
assert isinstance(op, IRVariable), "Can't set lattice for non-variable"
self.lattice[op] = value

def _eval_from_lattice(self, op: IROperand) -> IRLiteral | LatticeEnum:
if isinstance(op, IRLiteral):
return op

return self._lookup_from_lattice(op)

def _visit_phi(self, inst: IRInstruction):
assert inst.opcode == "phi", "Can't visit non phi instruction"
in_vars: list[LatticeItem] = []
for bb_label, var in inst.phi_operands:
bb = self.fn.get_basic_block(bb_label.name)
if bb not in self.cfg_in_exec[inst.parent]:
continue
in_vars.append(self.lattice[var])
in_vars.append(self._lookup_from_lattice(var))
value = reduce(_meet, in_vars, LatticeEnum.TOP) # type: ignore
assert inst.output in self.lattice, "Got undefined var for phi"
if value != self.lattice[inst.output]:
self.lattice[inst.output] = value

if value != self._lookup_from_lattice(inst.output):
self._set_lattice(inst.output, value)
self._add_ssa_work_items(inst)

def _visit_expr(self, inst: IRInstruction):
opcode = inst.opcode
if opcode in ["store", "alloca"]:
if isinstance(inst.operands[0], IRLiteral):
self.lattice[inst.output] = inst.operands[0] # type: ignore
else:
self.lattice[inst.output] = self.lattice[inst.operands[0]] # type: ignore
assert inst.output is not None, "Got store/alloca without output"
out = self._eval_from_lattice(inst.operands[0])
self._set_lattice(inst.output, out)
self._add_ssa_work_items(inst)
elif opcode == "jmp":
target = self.fn.get_basic_block(inst.operands[0].value)
self.work_list.append(FlowWorkItem(inst.parent, target))
elif opcode == "jnz":
lat = self.lattice[inst.operands[0]]
lat = self._eval_from_lattice(inst.operands[0])

assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}"
if lat == LatticeEnum.BOTTOM:
for out_bb in inst.parent.cfg_out:
Expand All @@ -182,7 +199,7 @@ def _visit_expr(self, inst: IRInstruction):
target = self.fn.get_basic_block(inst.operands[2].name)
self.work_list.append(FlowWorkItem(inst.parent, target))
elif opcode == "djmp":
lat = self.lattice[inst.operands[0]]
lat = self._eval_from_lattice(inst.operands[0])
assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}"
if lat == LatticeEnum.BOTTOM:
for op in inst.operands[1:]:
Expand All @@ -200,7 +217,7 @@ def _visit_expr(self, inst: IRInstruction):
self._eval(inst)
else:
if inst.output is not None:
self.lattice[inst.output] = LatticeEnum.BOTTOM
self._set_lattice(inst.output, LatticeEnum.BOTTOM)

def _eval(self, inst) -> LatticeItem:
"""
Expand Down Expand Up @@ -267,16 +284,17 @@ def _propagate_constants(self):
"""
for bb in self.dom.dfs_walk:
for inst in bb.instructions:
self._replace_constants(inst, self.lattice)
self._replace_constants(inst)

def _replace_constants(self, inst: IRInstruction, lattice: Lattice):
def _replace_constants(self, inst: IRInstruction):
"""
This method replaces constant values in the instruction with
their actual values. It also updates the instruction opcode in
case of jumps and asserts as needed.
"""
if inst.opcode == "jnz":
lat = lattice[inst.operands[0]]
lat = self._eval_from_lattice(inst.operands[0])

if isinstance(lat, IRLiteral):
if lat.value == 0:
target = inst.operands[2]
Expand All @@ -285,8 +303,10 @@ def _replace_constants(self, inst: IRInstruction, lattice: Lattice):
inst.opcode = "jmp"
inst.operands = [target]
self.cfg_dirty = True
elif inst.opcode == "assert":
lat = lattice[inst.operands[0]]

elif inst.opcode in ("assert", "assert_unreachable"):
lat = self._eval_from_lattice(inst.operands[0])

if isinstance(lat, IRLiteral):
if lat.value > 0:
inst.opcode = "nop"
Expand All @@ -303,7 +323,7 @@ def _replace_constants(self, inst: IRInstruction, lattice: Lattice):

for i, op in enumerate(inst.operands):
if isinstance(op, IRVariable):
lat = lattice[op]
lat = self.lattice[op]
if isinstance(lat, IRLiteral):
inst.operands[i] = lat

Expand Down

0 comments on commit 35996f1

Please sign in to comment.