diff --git a/tests/unit/compiler/venom/test_algebraic_optimizer.py b/tests/unit/compiler/venom/test_algebraic_optimizer.py
new file mode 100644
index 0000000000..e0368d4197
--- /dev/null
+++ b/tests/unit/compiler/venom/test_algebraic_optimizer.py
@@ -0,0 +1,129 @@
+import pytest
+
+from vyper.venom.analysis.analysis import IRAnalysesCache
+from vyper.venom.basicblock import IRBasicBlock, IRLabel
+from vyper.venom.context import IRContext
+from vyper.venom.passes.algebraic_optimization import AlgebraicOptimizationPass
+from vyper.venom.passes.make_ssa import MakeSSA
+from vyper.venom.passes.remove_unused_variables import RemoveUnusedVariablesPass
+
+
+@pytest.mark.parametrize("iszero_count", range(5))
+def test_simple_jump_case(iszero_count):
+    ctx = IRContext()
+    fn = ctx.create_function("_global")
+
+    bb = fn.get_basic_block()
+
+    br1 = IRBasicBlock(IRLabel("then"), fn)
+    fn.append_basic_block(br1)
+    br2 = IRBasicBlock(IRLabel("else"), fn)
+    fn.append_basic_block(br2)
+
+    p1 = bb.append_instruction("param")
+    op1 = bb.append_instruction("store", p1)
+    op2 = bb.append_instruction("store", 64)
+    op3 = bb.append_instruction("add", op1, op2)
+    jnz_input = op3
+
+    for _ in range(iszero_count):
+        jnz_input = bb.append_instruction("iszero", jnz_input)
+
+    bb.append_instruction("jnz", jnz_input, br1.label, br2.label)
+
+    br1.append_instruction("add", op3, 10)
+    br1.append_instruction("stop")
+    br2.append_instruction("add", op3, p1)
+    br2.append_instruction("stop")
+
+    ac = IRAnalysesCache(fn)
+    MakeSSA(ac, fn).run_pass()
+    AlgebraicOptimizationPass(ac, fn).run_pass()
+    RemoveUnusedVariablesPass(ac, fn).run_pass()
+
+    iszeros = [inst for inst in bb.instructions if inst.opcode == "iszero"]
+    removed_iszeros = iszero_count - len(iszeros)
+
+    assert removed_iszeros % 2 == 0
+    assert len(iszeros) == iszero_count % 2
+
+
+@pytest.mark.parametrize("iszero_count", range(1, 5))
+def test_simple_bool_cast_case(iszero_count):
+    ctx = IRContext()
+    fn = ctx.create_function("_global")
+
+    bb = fn.get_basic_block()
+
+    br1 = IRBasicBlock(IRLabel("then"), fn)
+    fn.append_basic_block(br1)
+
+    p1 = bb.append_instruction("param")
+    op1 = bb.append_instruction("store", p1)
+    op2 = bb.append_instruction("store", 64)
+    op3 = bb.append_instruction("add", op1, op2)
+    jnz_input = op3
+
+    for _ in range(iszero_count):
+        jnz_input = bb.append_instruction("iszero", jnz_input)
+
+    bb.append_instruction("mstore", jnz_input, p1)
+    bb.append_instruction("jmp", br1.label)
+
+    br1.append_instruction("add", op3, 10)
+    br1.append_instruction("stop")
+
+    ac = IRAnalysesCache(fn)
+    MakeSSA(ac, fn).run_pass()
+    AlgebraicOptimizationPass(ac, fn).run_pass()
+    RemoveUnusedVariablesPass(ac, fn).run_pass()
+
+    iszeros = [inst for inst in bb.instructions if inst.opcode == "iszero"]
+    removed_iszeros = iszero_count - len(iszeros)
+
+    assert removed_iszeros % 2 == 0
+    assert len(iszeros) in [1, 2]
+    assert len(iszeros) % 2 == iszero_count % 2
+
+
+@pytest.mark.parametrize("interleave_point", range(1, 5))
+def test_interleaved_case(interleave_point):
+    iszeros_after_interleave_point = interleave_point // 2
+    ctx = IRContext()
+    fn = ctx.create_function("_global")
+
+    bb = fn.get_basic_block()
+
+    br1 = IRBasicBlock(IRLabel("then"), fn)
+    fn.append_basic_block(br1)
+    br2 = IRBasicBlock(IRLabel("else"), fn)
+    fn.append_basic_block(br2)
+
+    p1 = bb.append_instruction("param")
+    op1 = bb.append_instruction("store", p1)
+    op2 = bb.append_instruction("store", 64)
+    op3 = bb.append_instruction("add", op1, op2)
+    op3_inv = bb.append_instruction("iszero", op3)
+    jnz_input = op3_inv
+    for _ in range(interleave_point):
+        jnz_input = bb.append_instruction("iszero", jnz_input)
+    bb.append_instruction("mstore", jnz_input, p1)
+    for _ in range(iszeros_after_interleave_point):
+        jnz_input = bb.append_instruction("iszero", jnz_input)
+    bb.append_instruction("jnz", jnz_input, br1.label, br2.label)
+
+    br1.append_instruction("add", op3, 10)
+    br1.append_instruction("stop")
+    br2.append_instruction("add", op3, p1)
+    br2.append_instruction("stop")
+
+    ac = IRAnalysesCache(fn)
+    MakeSSA(ac, fn).run_pass()
+    AlgebraicOptimizationPass(ac, fn).run_pass()
+    RemoveUnusedVariablesPass(ac, fn).run_pass()
+
+    assert bb.instructions[-1].opcode == "jnz"
+    if (interleave_point + iszeros_after_interleave_point) % 2 == 0:
+        assert bb.instructions[-1].operands[0] == op3_inv
+    else:
+        assert bb.instructions[-1].operands[0] == op3
diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py
index 82901126bc..cd981cd462 100644
--- a/vyper/venom/__init__.py
+++ b/vyper/venom/__init__.py
@@ -9,6 +9,7 @@
 from vyper.venom.context import IRContext
 from vyper.venom.function import IRFunction
 from vyper.venom.ir_node_to_venom import ir_node_to_venom
+from vyper.venom.passes.algebraic_optimization import AlgebraicOptimizationPass
 from vyper.venom.passes.branch_optimization import BranchOptimizationPass
 from vyper.venom.passes.dft import DFTPass
 from vyper.venom.passes.make_ssa import MakeSSA
@@ -50,6 +51,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel) -> None:
     SCCP(ac, fn).run_pass()
     StoreElimination(ac, fn).run_pass()
     SimplifyCFGPass(ac, fn).run_pass()
+    AlgebraicOptimizationPass(ac, fn).run_pass()
     BranchOptimizationPass(ac, fn).run_pass()
     RemoveUnusedVariablesPass(ac, fn).run_pass()
     DFTPass(ac, fn).run_pass()
diff --git a/vyper/venom/passes/algebraic_optimization.py b/vyper/venom/passes/algebraic_optimization.py
new file mode 100644
index 0000000000..4094219a6d
--- /dev/null
+++ b/vyper/venom/passes/algebraic_optimization.py
@@ -0,0 +1,67 @@
+from vyper.venom.analysis.dfg import DFGAnalysis
+from vyper.venom.analysis.liveness import LivenessAnalysis
+from vyper.venom.basicblock import IRInstruction, IROperand
+from vyper.venom.passes.base_pass import IRPass
+
+
+class AlgebraicOptimizationPass(IRPass):
+    """
+    This pass reduces algebraically evaluable expressions.
+
+    It currently optimizes:
+        * iszero chains
+    """
+
+    def _optimize_iszero_chains(self) -> None:
+        fn = self.function
+        for bb in fn.get_basic_blocks():
+            for inst in bb.instructions:
+                if inst.opcode != "iszero":
+                    continue
+
+                iszero_chain = self._get_iszero_chain(inst.operands[0])
+                iszero_count = len(iszero_chain)
+                if iszero_count == 0:
+                    continue
+
+                for use_inst in self.dfg.get_uses(inst.output):
+                    opcode = use_inst.opcode
+
+                    if opcode == "iszero":
+                        # We keep iszero instructions as is
+                        continue
+                    if opcode in ("jnz", "assert"):
+                        # instructions that accept a truthy value as input:
+                        # we can remove up to all the iszero instructions
+                        keep_count = 1 - iszero_count % 2
+                    else:
+                        # all other instructions:
+                        # we need to keep one or two iszero instructions (parity is preserved)
+                        keep_count = 1 + iszero_count % 2
+
+                    if keep_count >= iszero_count:
+                        continue
+
+                    out_var = iszero_chain[keep_count].operands[0]
+                    use_inst.replace_operands({inst.output: out_var})
+
+    def _get_iszero_chain(self, op: IROperand) -> list[IRInstruction]:
+        chain: list[IRInstruction] = []
+
+        while True:
+            inst = self.dfg.get_producing_instruction(op)
+            if inst is None or inst.opcode != "iszero":
+                break
+            op = inst.operands[0]
+            chain.append(inst)
+
+        chain.reverse()
+        return chain
+
+    def run_pass(self):
+        self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
+
+        self._optimize_iszero_chains()
+
+        self.analyses_cache.invalidate_analysis(DFGAnalysis)
+        self.analyses_cache.invalidate_analysis(LivenessAnalysis)