Skip to content

Commit

Permalink
feat[venom]: add algebraic optimization pass (#4054)
Browse files Browse the repository at this point in the history
Add a new venom pass to do algebraic optimizations. 

Currently optimizes `iszero` chains.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
harkal and charles-cooper authored May 29, 2024
1 parent 1b335c5 commit dcec257
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 0 deletions.
129 changes: 129 additions & 0 deletions tests/unit/compiler/venom/test_algebraic_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import pytest

from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.basicblock import IRBasicBlock, IRLabel
from vyper.venom.context import IRContext
from vyper.venom.passes.algebraic_optimization import AlgebraicOptimizationPass
from vyper.venom.passes.make_ssa import MakeSSA
from vyper.venom.passes.remove_unused_variables import RemoveUnusedVariablesPass


@pytest.mark.parametrize("iszero_count", range(5))
def test_simple_jump_case(iszero_count):
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
br2 = IRBasicBlock(IRLabel("else"), fn)
fn.append_basic_block(br2)

p1 = bb.append_instruction("param")
op1 = bb.append_instruction("store", p1)
op2 = bb.append_instruction("store", 64)
op3 = bb.append_instruction("add", op1, op2)
jnz_input = op3

for _ in range(iszero_count):
jnz_input = bb.append_instruction("iszero", jnz_input)

bb.append_instruction("jnz", jnz_input, br1.label, br2.label)

br1.append_instruction("add", op3, 10)
br1.append_instruction("stop")
br2.append_instruction("add", op3, p1)
br2.append_instruction("stop")

ac = IRAnalysesCache(fn)
MakeSSA(ac, fn).run_pass()
AlgebraicOptimizationPass(ac, fn).run_pass()
RemoveUnusedVariablesPass(ac, fn).run_pass()

iszeros = [inst for inst in bb.instructions if inst.opcode == "iszero"]
removed_iszeros = iszero_count - len(iszeros)

assert removed_iszeros % 2 == 0
assert len(iszeros) == iszero_count % 2


@pytest.mark.parametrize("iszero_count", range(1, 5))
def test_simple_bool_cast_case(iszero_count):
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)

p1 = bb.append_instruction("param")
op1 = bb.append_instruction("store", p1)
op2 = bb.append_instruction("store", 64)
op3 = bb.append_instruction("add", op1, op2)
jnz_input = op3

for _ in range(iszero_count):
jnz_input = bb.append_instruction("iszero", jnz_input)

bb.append_instruction("mstore", jnz_input, p1)
bb.append_instruction("jmp", br1.label)

br1.append_instruction("add", op3, 10)
br1.append_instruction("stop")

ac = IRAnalysesCache(fn)
MakeSSA(ac, fn).run_pass()
AlgebraicOptimizationPass(ac, fn).run_pass()
RemoveUnusedVariablesPass(ac, fn).run_pass()

iszeros = [inst for inst in bb.instructions if inst.opcode == "iszero"]
removed_iszeros = iszero_count - len(iszeros)

assert removed_iszeros % 2 == 0
assert len(iszeros) in [1, 2]
assert len(iszeros) % 2 == iszero_count % 2


@pytest.mark.parametrize("interleave_point", range(1, 5))
def test_interleaved_case(interleave_point):
iszeros_after_interleave_point = interleave_point // 2
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
br2 = IRBasicBlock(IRLabel("else"), fn)
fn.append_basic_block(br2)

p1 = bb.append_instruction("param")
op1 = bb.append_instruction("store", p1)
op2 = bb.append_instruction("store", 64)
op3 = bb.append_instruction("add", op1, op2)
op3_inv = bb.append_instruction("iszero", op3)
jnz_input = op3_inv
for _ in range(interleave_point):
jnz_input = bb.append_instruction("iszero", jnz_input)
bb.append_instruction("mstore", jnz_input, p1)
for _ in range(iszeros_after_interleave_point):
jnz_input = bb.append_instruction("iszero", jnz_input)
bb.append_instruction("jnz", jnz_input, br1.label, br2.label)

br1.append_instruction("add", op3, 10)
br1.append_instruction("stop")
br2.append_instruction("add", op3, p1)
br2.append_instruction("stop")

ac = IRAnalysesCache(fn)
MakeSSA(ac, fn).run_pass()
AlgebraicOptimizationPass(ac, fn).run_pass()
RemoveUnusedVariablesPass(ac, fn).run_pass()

assert bb.instructions[-1].opcode == "jnz"
if (interleave_point + iszeros_after_interleave_point) % 2 == 0:
assert bb.instructions[-1].operands[0] == op3_inv
else:
assert bb.instructions[-1].operands[0] == op3
2 changes: 2 additions & 0 deletions vyper/venom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vyper.venom.context import IRContext
from vyper.venom.function import IRFunction
from vyper.venom.ir_node_to_venom import ir_node_to_venom
from vyper.venom.passes.algebraic_optimization import AlgebraicOptimizationPass
from vyper.venom.passes.branch_optimization import BranchOptimizationPass
from vyper.venom.passes.dft import DFTPass
from vyper.venom.passes.make_ssa import MakeSSA
Expand Down Expand Up @@ -50,6 +51,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None:
SCCP(ac, fn).run_pass()
StoreElimination(ac, fn).run_pass()
SimplifyCFGPass(ac, fn).run_pass()
AlgebraicOptimizationPass(ac, fn).run_pass()
BranchOptimizationPass(ac, fn).run_pass()
RemoveUnusedVariablesPass(ac, fn).run_pass()
DFTPass(ac, fn).run_pass()
Expand Down
67 changes: 67 additions & 0 deletions vyper/venom/passes/algebraic_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from vyper.venom.analysis.dfg import DFGAnalysis
from vyper.venom.analysis.liveness import LivenessAnalysis
from vyper.venom.basicblock import IRInstruction, IROperand
from vyper.venom.passes.base_pass import IRPass


class AlgebraicOptimizationPass(IRPass):
"""
This pass reduces algebraic evaluatable expressions.
It currently optimizes:
* iszero chains
"""

def _optimize_iszero_chains(self) -> None:
fn = self.function
for bb in fn.get_basic_blocks():
for inst in bb.instructions:
if inst.opcode != "iszero":
continue

iszero_chain = self._get_iszero_chain(inst.operands[0])
iszero_count = len(iszero_chain)
if iszero_count == 0:
continue

for use_inst in self.dfg.get_uses(inst.output):
opcode = use_inst.opcode

if opcode == "iszero":
# We keep iszero instuctions as is
continue
if opcode in ("jnz", "assert"):
# instructions that accept a truthy value as input:
# we can remove up to all the iszero instructions
keep_count = 1 - iszero_count % 2
else:
# all other instructions:
# we need to keep at least one or two iszero instructions
keep_count = 1 + iszero_count % 2

if keep_count >= iszero_count:
continue

out_var = iszero_chain[keep_count].operands[0]
use_inst.replace_operands({inst.output: out_var})

def _get_iszero_chain(self, op: IROperand) -> list[IRInstruction]:
chain: list[IRInstruction] = []

while True:
inst = self.dfg.get_producing_instruction(op)
if inst is None or inst.opcode != "iszero":
break
op = inst.operands[0]
chain.append(inst)

chain.reverse()
return chain

def run_pass(self):
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)

self._optimize_iszero_chains()

self.analyses_cache.invalidate_analysis(DFGAnalysis)
self.analyses_cache.invalidate_analysis(LivenessAnalysis)

0 comments on commit dcec257

Please sign in to comment.