diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py index 3f5ac3480..dfbc3966c 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/ast_processor.py @@ -154,7 +154,7 @@ def _simplify_reaching_conditions(self) -> None: This helps to remove unnecessary conditions for finding switches. """ for node in self.asforest.post_order(self.asforest.current_root): - node.simplify_reaching_condition(self.asforest.condition_handler.get_z3_condition_map()) + node.simplify_reaching_condition(self.asforest.condition_handler) def _combine_nodes_with_same_reaching_conditions(self) -> None: """ diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py index 1a4063840..774165f81 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py @@ -268,7 +268,9 @@ def _update_reaching_condition_of(self, case_node: CaseNode, considered_conditio else: exception_condition &= ~literal case_node.reaching_condition = ( - LogicCondition.disjunction_of(literals_of_case_node) if literals_of_case_node else self.condition_handler.get_false_value() + LogicCondition.disjunction_of(list(literals_of_case_node)) + if literals_of_case_node + else self.condition_handler.get_false_value() ) if not exception_condition.is_true: case_node.child.reaching_condition = case_node.child.reaching_condition & exception_condition diff --git a/decompiler/structures/ast/ast_nodes.py b/decompiler/structures/ast/ast_nodes.py index 3ac356ff9..566a2f50e 100644 --- a/decompiler/structures/ast/ast_nodes.py +++ b/decompiler/structures/ast/ast_nodes.py @@ -5,6 +5,7 @@ from enum import Enum from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple, TypeVar, Union +from decompiler.structures.ast.condition_symbol import ConditionHandler from decompiler.structures.ast.reachability_graph import CaseDependencyGraph, SiblingReachability from decompiler.structures.graphs.interface import GraphNodeInterface from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition @@ -163,10 +164,10 @@ def clean(self) -> None: """Makes clean ups, depending on the node. This helps to standardize the AST.""" pass - def simplify_reaching_condition(self, z3_condition_of: Dict[LogicCondition, PseudoLogicCondition]): + def simplify_reaching_condition(self, condition_handler: ConditionHandler): """Simplify the reaching condition. If it is false we remove the subtree of this node.""" if not self.reaching_condition.is_true: - self.reaching_condition.remove_redundancy(z3_condition_of) + self.reaching_condition.remove_redundancy(condition_handler) if self.reaching_condition.is_false: logging.warning(f"The CFG node {self} has reaching condition false, therefore, we remove it.") self._ast.remove_subtree(self) @@ -534,16 +535,16 @@ def get_possible_case_candidate_condition(self) -> Optional[LogicCondition]: return self.reaching_condition & self.condition return None - def simplify_reaching_condition(self, z3_condition_of: Dict[LogicCondition, PseudoLogicCondition]): + def simplify_reaching_condition(self, condition_handler: ConditionHandler): """ Add the reaching condition to the condition of the condition node if the false-branch does not exist. Otherwise, only simplify it. """ self.clean() if self.false_branch is None and not self.reaching_condition.is_true: self.condition &= self.reaching_condition - self.condition.remove_redundancy(z3_condition_of) + self.condition.remove_redundancy(condition_handler) self.reaching_condition = LogicCondition.initialize_true(self.reaching_condition.context) - super().simplify_reaching_condition(z3_condition_of) + super().simplify_reaching_condition(condition_handler) def switch_branches(self): """Switch the true-branch and false-branch, this includes negating the condition.""" diff --git a/decompiler/structures/ast/syntaxforest.py b/decompiler/structures/ast/syntaxforest.py index 207f7eb35..9e30acd8b 100644 --- a/decompiler/structures/ast/syntaxforest.py +++ b/decompiler/structures/ast/syntaxforest.py @@ -395,7 +395,7 @@ def combine_cascading_single_branch_conditions(self, root: Optional[AbstractSynt & condition_node.reaching_condition & nested_condition_node.reaching_condition ) - condition_node.condition = new_condition.remove_redundancy(self.condition_handler.get_z3_condition_map()) + condition_node.condition = new_condition.remove_redundancy(self.condition_handler) condition_node.reaching_condition = self.condition_handler.get_true_value() self.replace_condition_node_by_single_branch(nested_condition_node) diff --git a/decompiler/structures/logic/custom_logic.py b/decompiler/structures/logic/custom_logic.py index fcd085c32..1754ee233 100644 --- a/decompiler/structures/logic/custom_logic.py +++ b/decompiler/structures/logic/custom_logic.py @@ -1,10 +1,562 @@ -from decompiler.structures.logic.logic_interface import ConditionInterface +from __future__ import annotations + +import logging +from itertools import product +from typing import TYPE_CHECKING, Dict, Generic, Iterator, List, Sequence, Set, Tuple, TypeVar + +import decompiler.structures.pseudo as pseudo +from decompiler.structures.logic.logic_interface import ConditionInterface, PseudoLogicInterface +from decompiler.structures.pseudo import Condition +from simplifier.operations import BitwiseAnd, BitwiseNegate, BitwiseOr +from simplifier.visitor import ToCnfVisitor, ToDnfVisitor +from simplifier.visitor.serialize_visitor import SerializeVisitor +from simplifier.world.nodes import BaseVariable, BitVector, Constant, Operation, TmpVariable, Variable, WorldObject from simplifier.world.world import World +if TYPE_CHECKING: + from decompiler.structures.ast.condition_symbol import ConditionHandler + +LOGICCLASS = TypeVar("LOGICCLASS", bound="CustomLogicCondition") +PseudoLOGICCLASS = TypeVar("PseudoLOGICCLASS", bound="PseudoCustomLogicCondition") -class CustomLogicCondition(ConditionInterface, World): + +class CustomLogicCondition(ConditionInterface, Generic[LOGICCLASS]): """Class in charge of implementing generic logic operations using costume logic.""" - pass + def __init__(self, condition: WorldObject, tmp: bool = False): + if isinstance(condition, Variable): + self._variable = condition + else: + self._variable: BaseVariable = condition.world.new_variable(condition.size, tmp) + self.context.define(self._variable, condition) + + @classmethod + def generate_new_context(cls) -> World: + """Generate a context for z3-conditions.""" + return World() + + @property + def _condition(self) -> WorldObject: + if term := self.context.get_definition(self._variable): + return term + return self._variable + + def __len__(self) -> int: + """Return the length of a formula, which corresponds to its complexity.""" + if isinstance(self._condition, Variable): + return 1 + count = 0 + for node in self.context.iter_postorder(self._condition): + if not isinstance(node, Operation): + continue + count += sum(1 for op in node.operands if isinstance(op, Variable)) + return count + + def __str__(self) -> str: + """Return a string representation.""" + condition = self._condition + if isinstance(condition, Constant) and condition.size == 1: + return "false" if condition.unsigned == 0 else "true" + return str(condition) + + def copy(self) -> LOGICCLASS: + """Copy an instance of the Z3ConditionInterface.""" + return self.__class__(self._condition) + + @classmethod + def initialize_symbol(cls, name: str, context: World) -> LOGICCLASS: + """Create a symbol.""" + return cls(context.variable(name, 1)) + + @classmethod + def initialize_true(cls, context: World) -> LOGICCLASS: + """Return condition tag that represents True.""" + return cls(context.constant(1, 1)) + + @classmethod + def initialize_false(cls, context: World) -> LOGICCLASS: + """Return condition tag that represents False.""" + return cls(context.constant(0, 1)) + + @classmethod + def disjunction_of(cls, clauses: Sequence[LOGICCLASS]) -> LOGICCLASS: + """Create a disjunction for the list of given clauses.""" + world = clauses[0].context + return cls(world.bitwise_or(*(clause._condition for clause in clauses))) + + @classmethod + def conjunction_of(cls, clauses: Sequence[LOGICCLASS]) -> LOGICCLASS: + """Create a conjunction for the list of given clauses.""" + world = clauses[0].context + return cls(world.bitwise_and(*(clause._condition for clause in clauses))) + + def __and__(self, other: LOGICCLASS) -> LOGICCLASS: + """Logical and of two condition tag interfaces.""" + return self.__class__(self.context.bitwise_and(self._condition, other._condition)) + + def __or__(self, other: LOGICCLASS) -> LOGICCLASS: + """Logical or of two condition tag interfaces.""" + return self.__class__(self.context.bitwise_or(self._condition, other._condition)) + + def __invert__(self) -> LOGICCLASS: + """Logical negate of two condition tag interfaces.""" + return self.__class__(self._custom_negate(self._condition)) + + def _custom_negate(self, condition: WorldObject) -> WorldObject: + """Negate the given world object.""" + if isinstance(condition, BitwiseNegate): + return condition.operand + return self.context.bitwise_negate(condition) + + @property + def context(self) -> World: + """Return context of logic condition.""" + return self._variable.world + + @property + def is_true(self) -> bool: + """Check whether the tag is the 'true-symbol'.""" + return isinstance(self._condition, Constant) and self._condition.unsigned != 0 + + @property + def is_false(self) -> bool: + """Check whether the tag is the 'false-symbol'.""" + return isinstance(self._condition, Constant) and self._condition.unsigned == 0 + + @property + def is_disjunction(self) -> bool: + """Check whether the condition is a disjunction of conditions, i.e. A v B v C.""" + return isinstance(self._condition, BitwiseOr) + + @property + def is_conjunction(self) -> bool: + """Check whether the condition is a conjunction of conditions, i.e. A ^ B ^ C.""" + return isinstance(self._condition, BitwiseAnd) + + @property + def is_negation(self) -> bool: + """Check whether the condition is a negation of conditions, i.e. !A.""" + return isinstance(self._condition, BitwiseNegate) + + @property + def operands(self) -> List[LOGICCLASS]: + """Return all operands of the condition.""" + return self._get_operands() + + def _get_operands(self, tmp: bool = False) -> List[LOGICCLASS]: + """Get operands.""" + condition = self._condition + if isinstance(condition, BitVector): + return [] + assert isinstance(condition, Operation), f"The condition must be an operation." + return [self.__class__(operand, tmp) for operand in condition.operands] + + @property + def is_symbol(self) -> bool: + """Check whether the object is a symbol.""" + return self._is_symbol(self._condition) + + @property + def is_literal(self) -> bool: + """Check whether the object is a literal, i.e., a symbol or a negated symbol""" + return self._is_literal(self._condition) + + @property + def is_disjunction_of_literals(self) -> bool: + """ + Check whether the given condition is a disjunction of literals, i.e., whether it is + - a symbol, + - the negation of a symbol or + - a disjunction of symbols or negation of symbols. + """ + return self._is_disjunction_of_literals(self._condition) + + @property + def is_cnf_form(self) -> bool: + """Check whether the condition is already in cnf-form.""" + if self.is_true or self.is_false or self.is_disjunction_of_literals: + return True + return self.is_conjunction and all(self._is_disjunction_of_literals(clause) for clause in self._condition.operands) + + def is_equal_to(self, other: LOGICCLASS) -> bool: + """Check whether the conditions are equal, i.e., have the same from except the ordering.""" + return World.compare(self._condition, other._condition) + + def does_imply(self, other: LOGICCLASS) -> bool: + """Check whether the condition implies the given condition.""" + tmp_condition = self.__class__(self.context.bitwise_or(self._custom_negate(self._condition), other._condition)) + self.context.free_world_condition(tmp_condition._variable) + tmp_condition._variable.simplify() + does_imply_value = tmp_condition.is_true + self.context.cleanup([tmp_condition._variable]) + return does_imply_value + + def to_cnf(self) -> LOGICCLASS: + """Bring the condition tag into cnf-form.""" + if self.is_cnf_form: + return self + self.context.free_world_condition(self._variable) + ToCnfVisitor(self._variable) + return self + + def to_dnf(self) -> LOGICCLASS: + """Bring the condition tag into dnf-form.""" + dnf_form = self.copy() + self.context.free_world_condition(dnf_form._variable) + ToDnfVisitor(dnf_form._variable) + return dnf_form + + def simplify(self) -> LOGICCLASS: + """Simplify the given condition. Make sure that it does not destroy cnf-form.""" + if isinstance(self._variable, Variable): + self.context.free_world_condition(self._variable) + self._variable.simplify() + else: + new_var = self.context.variable(f"Simplify", 1) + self.context.define(new_var, self._condition) + self.context.free_world_condition(new_var) + new_var.simplify() + self._variable = self.context.new_variable(1, tmp=True) + self.context.substitute(new_var, self._variable) + return self + + def get_symbols(self) -> Iterator[LOGICCLASS]: + """Return all symbols used by the condition.""" + for symbol in self._get_symbols(self._condition): + yield self.__class__(symbol) + + def get_symbols_as_string(self) -> Iterator[str]: + """Return all symbols as strings.""" + for symbol in self._get_symbols(self._condition): + yield str(symbol) + + def get_literals(self) -> Iterator[LOGICCLASS]: + """Return all literals used by the condition.""" + for literal in self._get_literals(self._condition): + yield self.__class__(literal) + + def substitute_by_true(self, condition: LOGICCLASS) -> LOGICCLASS: + """ + Substitutes the given condition by true. + + Example: substituting in the expression (a∨b)∧c the condition (a∨b) by true results in the condition c, + and substituting the condition c by true in the condition (a∨b) + """ + assert self.context == condition.context, f"The condition must be contained in the same graph." + if not self.is_true and (self.is_equal_to(condition) or condition.does_imply(self)): + self._replace_condition_by_true() + return self + + self.to_cnf() + if self.is_true or self.is_false or self.is_negation or self.is_symbol: + return self + + condition_operands: List[LOGICCLASS] = condition._get_operands() + operands: List[LOGICCLASS] = self._get_operands() + numb_of_arg_expr: int = len(operands) if self.is_conjunction else 1 + numb_of_arg_cond: int = len(condition_operands) if condition.is_conjunction else 1 + + if numb_of_arg_expr <= numb_of_arg_cond: + self.context.cleanup() + return self + + subexpressions: List[LOGICCLASS] = [condition] if numb_of_arg_cond == 1 else condition_operands + self._replace_subexpressions_by_true(subexpressions) + to_remove = [cond._variable for cond in condition_operands + operands if cond._variable != cond._condition] + self.context.cleanup(to_remove) + return self + + def _replace_subexpressions_by_true(self, subexpressions: List[LOGICCLASS]): + """Replace each clause of the Custom-Condition by True, if it is contained in the list of given subexpressions.""" + for sub_expr_1, sub_expr_2 in product(subexpressions, self.operands): + if sub_expr_1.is_equivalent_to(sub_expr_2): + relations = self.context.get_relation(self._condition, sub_expr_2._condition) + for relation in relations: + self.context.remove_operand(self._condition, relation.sink) + + def _replace_condition_by_true(self) -> None: + """Replace the Custom Logic condition by True.""" + if self.is_symbol: + self._variable: BaseVariable = self.context.new_variable(self._condition.size) + self.context.define(self._variable, self.context.constant(1, 1)) + else: + self.context.replace(self._condition, self.context.constant(1, 1)) + self.context.cleanup() + + def remove_redundancy(self, condition_handler: ConditionHandler) -> LOGICCLASS: + """ + Simplify conditions by considering the pseudo-conditions (more advanced simplification). + + - The given formula is simplified using the given dictionary that maps to each symbol a pseudo-condition. + - This helps, for example for finding switch cases, because it simplifies the condition + 'x1 & x2' if 'x1 = var < 10' and 'x2 = var == 5' to the condition 'x2'. + """ + if self.is_literal or self.is_true or self.is_false: + return self + assert isinstance(self._condition, Operation), "We only remove redundancy for operations" + + real_condition, compared_expressions = self._replace_symbols_by_real_conditions(condition_handler) + + self.context.free_world_condition(real_condition._variable) + real_condition.simplify() + + self._replace_real_conditions_by_symbols(real_condition, compared_expressions, condition_handler) + + self.context.replace(self._condition, real_condition._condition) + self.context.cleanup() + return self + + def _replace_real_conditions_by_symbols( + self, + real_condition: PseudoCustomLogicCondition, + compared_expressions: Dict[Variable, pseudo.Expression], + condition_handler: ConditionHandler, + ): + """Replace all clauses of the given real-condition by symbols.""" + non_logic_operands = { + node + for node in self.context.iter_postorder(real_condition._variable) + if isinstance(node, Operation) and not isinstance(node, (BitwiseOr, BitwiseAnd, BitwiseNegate)) + } + replacement_dict = { + real_cond._condition: symbol._condition + for symbol, real_cond in condition_handler.get_z3_condition_map().items() + if any(operand in compared_expressions for operand in real_cond._condition.operands) + } + for operand in non_logic_operands: + negated_operand = operand.copy_tree().negate() + for condition, symbol in replacement_dict.items(): + if World.compare(condition, operand): + self.context.replace(operand, symbol) + break + if World.compare(condition, negated_operand): + self.context.replace(operand, self.context.bitwise_negate(symbol)) + break + else: + new_operands = list() + for op in operand.operands: + if op in compared_expressions: + new_operands.append(compared_expressions[op]) + else: + assert isinstance(op, Constant), f"The operand must be a Constant" + new_operands.append(pseudo.Constant(op.signed, pseudo.Integer(op.size, signed=True))) + condition_symbol = condition_handler.add_condition(Condition(self.OPERAND_MAPPING[operand.SYMBOL], new_operands)) + self.context.replace(operand, condition_symbol.symbol._condition) + + def _replace_symbols_by_real_conditions( + self, condition_handler: ConditionHandler + ) -> Tuple[PseudoCustomLogicCondition, Dict[Variable, pseudo.Expression]]: + """ + Return the real condition where the symbols are replaced by the conditions of the condition handler + as well as a mapping between the replaced symbols and the corresponding pseudo-expression. + """ + copied_condition = PseudoCustomLogicCondition(self._condition) + self.context.free_world_condition(copied_condition._variable) + condition_nodes = set(self.context.iter_postorder(copied_condition._variable)) + compared_expressions: Dict[Variable, pseudo.Expression] = dict() + for symbol in self.get_symbols(): + pseudo_condition: Condition = condition_handler.get_condition_of(symbol) + for operand in pseudo_condition.operands: + if not isinstance(operand, pseudo.Constant): + compared_expressions[self.context.variable(self._variable_name_for(operand))] = operand + self._replace_symbol(symbol, condition_handler, condition_nodes) + return copied_condition, compared_expressions + + def _replace_symbol(self, symbol: CustomLogicCondition, condition_handler: ConditionHandler, condition_nodes: Set[WorldObject]): + """ + Replace the given symbol by the corresponding pseudo-condition. + + :symbol: The symbol we want to replace in the custom-logic-condition + :condition_handler: The object handling the connection between the symbols, the pseudo-logic-condition, and the "real" condition + :condition_nodes: The set of all nodes in the world that belong to the custom-logic condition where we replace the symbols. + """ + world_condition = condition_handler.get_z3_condition_of(symbol)._condition + world_symbol = symbol._condition + for parent in [parent for parent in self.context.parent_operation(world_symbol) if parent in condition_nodes]: + for relation in self.context.get_relation(parent, world_symbol): + index = relation.index + self.context.remove_operand(parent, relation.sink) + self.context.add_operand(parent, world_condition, index) + + def serialize(self) -> str: + """Serialize the given condition into a SMT2 string representation.""" + return self._condition.accept(SerializeVisitor()) + + @classmethod + def deserialize(cls, data: str, context: World) -> LOGICCLASS: + """Deserialize the given string representing a z3 expression.""" + return CustomLogicCondition(context.from_string(data)) + + def rich_string_representation(self, condition_map: Dict[LOGICCLASS, pseudo.Condition]): + """Replace each symbol by the condition of the condition map and print this condition as string.""" + return self._rich_string_representation( + self._condition, {symbol._condition: condition for symbol, condition in condition_map.items()} + ) + + # some world-implementation helpers: + + def _is_symbol(self, condition: WorldObject) -> bool: + return isinstance(condition, Variable) and condition.size == 1 and self.context.get_definition(condition) is None + + def _is_literal(self, condition: WorldObject) -> bool: + return self._is_symbol(condition) or (isinstance(condition, BitwiseNegate) and self._is_symbol(condition.operand)) + + def _is_disjunction_of_literals(self, condition: WorldObject) -> bool: + """ + Check whether the given condition is a disjunction of literals, i.e., whether it is + - a symbol, + - the negation of a symbol or + - a disjunction of symbols or negation of symbols. + """ + if self._is_literal(condition): + return True + return isinstance(condition, BitwiseOr) and all(self._is_literal(operand) for operand in condition.operands) + + def _get_symbols(self, condition: WorldObject) -> Iterator[Variable]: + """Get symbols on World-level""" + for node in self.context.iter_postorder(condition): + if self._is_symbol(node): + yield node + + def _get_literals(self, condition: WorldObject) -> Iterator[WorldObject]: + """Get literals on World-level""" + if self._is_literal(condition): + yield condition + elif isinstance(condition, (BitwiseOr, BitwiseAnd, BitwiseNegate)): + for child in condition.operands: + yield from self._get_literals(child) + else: + assert isinstance(condition, Constant) and condition.size == 1, f"The condition {condition} does not consist of literals." + + def _rich_string_representation(self, condition: WorldObject, condition_map: Dict[Variable, pseudo.Condition]) -> str: + """Replace each symbol of the given condition by the pseudo-condition of the condition map and return this condition as string.""" + if self._is_symbol(condition): + if condition in condition_map: + return str(condition_map[condition]) + return f"{condition}" + if isinstance(condition, Constant) and condition.size == 1: + return "false" if condition.unsigned == 0 else "true" + if isinstance(condition, BitwiseNegate): + original_condition = condition.operand + if original_condition in condition_map: + return str(condition_map[original_condition].negate()) + return f"!{self._rich_string_representation(original_condition, condition_map)}" + if isinstance(condition, (BitwiseOr, BitwiseAnd)): + operands = condition.operands + symbol = "|" if isinstance(condition, BitwiseOr) else "&" + if len(operands) == 1: + return self._rich_string_representation(operands[0], condition_map) + return "(" + f" {symbol} ".join([f"{self._rich_string_representation(operand, condition_map)}" for operand in operands]) + ")" + return f"{condition}" + + @staticmethod + def _variable_name_for(expression: pseudo.Expression) -> str: + if isinstance(expression, pseudo.Variable): + return f"{expression},{expression.ssa_name}" + return f"{expression},{[str(var.ssa_name) for var in expression.requirements]}" + + OPERAND_MAPPING = { + "==": pseudo.OperationType.equal, + "!=": pseudo.OperationType.not_equal, + "s<=": pseudo.OperationType.less_or_equal, + "u<=": pseudo.OperationType.less_or_equal_us, + "s>": pseudo.OperationType.greater, + "u>": pseudo.OperationType.greater_us, + "s<": pseudo.OperationType.less, + "u<": pseudo.OperationType.less_us, + "s>=": pseudo.OperationType.greater_or_equal, + "u>=": pseudo.OperationType.greater_or_equal_us, + } + + +class PseudoCustomLogicCondition(PseudoLogicInterface, CustomLogicCondition, Generic[LOGICCLASS, PseudoLOGICCLASS]): + def __init__(self, condition: WorldObject, tmp: bool = False): + super().__init__(condition, tmp) + + @classmethod + def initialize_from_condition(cls, condition: pseudo.Condition, context: World) -> PseudoLOGICCLASS: + """Create the simplified condition from the condition of type Condition.""" + custom_condition = cls._get_custom_condition_of(condition, context) + return cls(custom_condition) + + @classmethod + def initialize_from_conditions_or(cls, conditions: List[pseudo.Condition], context: World) -> PseudoLOGICCLASS: + or_conditions = [] + for cond in conditions: + or_conditions.append(cls._get_custom_condition_of(cond, context)) + return cls(context.bitwise_or(*or_conditions)) + + @classmethod + def initialize_from_formula(cls, condition: LOGICCLASS, condition_map: Dict[LOGICCLASS, PseudoLOGICCLASS]) -> PseudoLOGICCLASS: + """Create the simplified condition from the condition that is a formula of symbols.""" + condition.to_cnf() + if condition.is_true: + return cls.initialize_true(condition.context) + if condition.is_false: + return cls.initialize_false(condition.context) + if condition.is_literal: + return cls._get_condition_of_literal(condition, condition_map) + if condition.is_disjunction: + return cls._get_condition_of_disjunction(condition, condition_map) + + operands = list() + for conjunction in condition.operands: + if conjunction.is_literal: + operands.append(cls._get_condition_of_literal(conjunction, condition_map)._condition) + else: + operands.append(cls._get_condition_of_disjunction(conjunction, condition_map)._condition) + + return cls(condition.context.bitwise_and(*operands)) + + @classmethod + def _get_condition_of_disjunction(cls, disjunction: LOGICCLASS, condition_map: Dict[LOGICCLASS, PseudoLOGICCLASS]) -> PseudoLOGICCLASS: + """Return for a disjunction (Or) the corresponding z3-condition.""" + assert disjunction.is_disjunction, f"The input must be a disjunction, but it is {disjunction}" + operands = [cls._get_condition_of_literal(operand, condition_map)._condition for operand in disjunction.operands] + return cls(disjunction.context.bitwise_or(*operands)) + + @staticmethod + def _get_condition_of_literal(literal: LOGICCLASS, condition_map: Dict[LOGICCLASS, PseudoLOGICCLASS]) -> PseudoLOGICCLASS: + """Given a literal, i.e., a symbol or a negation of a symbol, return the condition the symbol is mapped to.""" + assert literal.is_literal, f"The input must be a literal, but it is {literal}" + if literal.is_symbol: + return condition_map[literal] + return ~condition_map[~literal] + + @staticmethod + def _get_custom_condition_of(condition: pseudo.Condition, world: World) -> WorldObject: + """ + Convert a given condition a op b into the custom-condition bit_vec_a op bit_vec_b. + + a and b can be any type of Expression. The name of the bitvector reflects the expression as well as + the SSA-variable names that occur in the expression. + """ + if condition.left.type.size != condition.right.type.size: + logging.warning( + f"The operands of {condition} have different sizes: {condition.left.type.size} & {condition.right.type.size}. Increase the size of the smaller one." + ) + bit_vec_size = max(condition.left.type.size, condition.right.type.size, 1) + operand_1: BitVector = PseudoCustomLogicCondition._convert_expression(condition.left, bit_vec_size, world) + operand_2: BitVector = PseudoCustomLogicCondition._convert_expression(condition.right, bit_vec_size, world) + return PseudoCustomLogicCondition.SHORTHAND[condition.operation](world, operand_1, operand_2) + + @staticmethod + def _convert_expression(expression: pseudo.Expression, bit_vec_size: int, world: World) -> BitVector: + """Convert the given expression into a z3 bit-vector.""" + if isinstance(expression, pseudo.Constant): + return world.constant(expression.value, bit_vec_size) + else: + return world.variable(PseudoCustomLogicCondition._variable_name_for(expression), bit_vec_size) - # TODO implement all abstract methods + SHORTHAND = { + pseudo.OperationType.equal: lambda world, a, b: world.bool_equal(a, b), + pseudo.OperationType.not_equal: lambda world, a, b: world.bool_unequal(a, b), + pseudo.OperationType.less: lambda world, a, b: world.signed_lt(a, b), + pseudo.OperationType.less_or_equal: lambda world, a, b: world.signed_le(a, b), + pseudo.OperationType.greater: lambda world, a, b: world.signed_gt(a, b), + pseudo.OperationType.greater_or_equal: lambda world, a, b: world.signed_ge(a, b), + pseudo.OperationType.greater_us: lambda world, a, b: world.unsigned_gt(a, b), + pseudo.OperationType.less_us: lambda world, a, b: world.unsigned_lt(a, b), + pseudo.OperationType.greater_or_equal_us: lambda world, a, b: world.unsigned_ge(a, b), + pseudo.OperationType.less_or_equal_us: lambda world, a, b: world.unsigned_le(a, b), + } diff --git a/decompiler/structures/logic/logic_condition.py b/decompiler/structures/logic/logic_condition.py index ed0d27ded..c7f0d5883 100644 --- a/decompiler/structures/logic/logic_condition.py +++ b/decompiler/structures/logic/logic_condition.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import Dict, Type, TypeVar +from typing import TYPE_CHECKING, Dict, Type, TypeVar +from decompiler.structures.logic.custom_logic import CustomLogicCondition, PseudoCustomLogicCondition from decompiler.structures.logic.interface_decorators import ensure_cnf -from decompiler.structures.logic.logic_interface import ConditionInterface from decompiler.structures.logic.z3_logic import PseudoZ3LogicCondition, Z3LogicCondition +if TYPE_CHECKING: + from decompiler.structures.ast.condition_symbol import ConditionHandler + LOGICCLASS = TypeVar("LOGICCLASS", bound="ConditionInterface") PseudoLOGICCLASS = TypeVar("PseudoLOGICCLASS", bound="PseudoLogicInterface") @@ -13,8 +16,8 @@ def generate_logic_condition_class(base) -> Type[LOGICCLASS]: class BLogicCondition(base[LOGICCLASS]): @ensure_cnf - def __init__(self, condition): - super().__init__(condition) + def __init__(self, condition, tmp: bool = False): + super().__init__(condition, tmp) def simplify_to_shortest(self, complexity_bound: int) -> BLogicCondition: """Simplify the condition to the shortest one (CNF or DNF).""" @@ -41,7 +44,7 @@ def _get_complexity_of_simplification(self) -> int: @ensure_cnf def substitute_by_true(self, condition: BLogicCondition) -> BLogicCondition: """ - Substitutes the given condition by true. + Substitutes the given condition by true, i.e., changes the condition under the assumption that the given condition fulfilled. Example: substituting in the expression (a∨b)∧c the condition (a∨b) by true results in the condition c, and substituting the condition c by true in the condition (a∨b) @@ -49,7 +52,7 @@ def substitute_by_true(self, condition: BLogicCondition) -> BLogicCondition: return super().substitute_by_true(condition) @ensure_cnf - def remove_redundancy(self, condition_map: Dict[BLogicCondition, PseudoLogicCondition]) -> BLogicCondition: + def remove_redundancy(self, condition_handler: ConditionHandler) -> BLogicCondition: """ More advanced simplification of conditions. @@ -57,19 +60,21 @@ def remove_redundancy(self, condition_map: Dict[BLogicCondition, PseudoLogicCond - This helps, for example for finding switch cases, because it simplifies the condition 'x1 & x2' if 'x1 = var < 10' and 'x2 = var == 5' to the condition 'x2'. """ - return super().remove_redundancy(condition_map) + return super().remove_redundancy(condition_handler) return BLogicCondition LogicCondition = generate_logic_condition_class(Z3LogicCondition) +# LogicCondition = generate_logic_condition_class(CustomLogicCondition) -def generate_pseudo_logic_condition_class(base) -> Type[PseudoLOGICCLASS]: - class BPseudoLogicCondition(LogicCondition, base[LOGICCLASS, PseudoLOGICCLASS]): +def generate_pseudo_logic_condition_class(base, log_cond=LogicCondition) -> Type[PseudoLOGICCLASS]: + class BPseudoLogicCondition(log_cond, base[LOGICCLASS, PseudoLOGICCLASS]): pass return BPseudoLogicCondition PseudoLogicCondition = generate_pseudo_logic_condition_class(PseudoZ3LogicCondition) +# PseudoLogicCondition = generate_pseudo_logic_condition_class(PseudoCustomLogicCondition) diff --git a/decompiler/structures/logic/logic_interface.py b/decompiler/structures/logic/logic_interface.py index a60289d4e..d7eaa50fe 100644 --- a/decompiler/structures/logic/logic_interface.py +++ b/decompiler/structures/logic/logic_interface.py @@ -1,10 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, Generic, Iterable, Iterator, List, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, Iterable, Iterator, List, Sequence, TypeVar from decompiler.structures.pseudo import Condition +if TYPE_CHECKING: + from decompiler.structures.ast.condition_symbol import ConditionHandler + CONTEXT = TypeVar("CONTEXT") @@ -98,12 +101,12 @@ def initialize_false(cls, context: CONTEXT) -> ConditionInterface: @classmethod @abstractmethod - def disjunction_of(cls, clauses: Iterable[ConditionInterface]) -> ConditionInterface: + def disjunction_of(cls, clauses: Sequence[ConditionInterface]) -> ConditionInterface: """Creates a disjunction for the list of given clauses.""" @classmethod @abstractmethod - def conjunction_of(cls, clauses: Iterable[ConditionInterface]) -> ConditionInterface: + def conjunction_of(cls, clauses: Sequence[ConditionInterface]) -> ConditionInterface: """Creates a conjunction for the list of given clauses.""" @abstractmethod @@ -150,7 +153,7 @@ def is_disjunction_of_literals(self) -> bool: @property def is_cnf_form(self) -> bool: """Check whether the condition is already in cnf-form.""" - if self.is_disjunction_of_literals: + if self.is_true or self.is_false or self.is_disjunction_of_literals: return True return self.is_conjunction and all(clause.is_disjunction_of_literals for clause in self.operands) @@ -209,7 +212,7 @@ def substitute_by_true(self, condition: ConditionInterface) -> ConditionInterfac """ @abstractmethod - def remove_redundancy(self, condition_map: Dict[ConditionInterface, ConditionInterface]) -> ConditionInterface: + def remove_redundancy(self, condition_handler: ConditionHandler) -> ConditionInterface: """ More advanced simplification of conditions. diff --git a/decompiler/structures/logic/z3_logic.py b/decompiler/structures/logic/z3_logic.py index e8b059f1f..ef7ad86eb 100644 --- a/decompiler/structures/logic/z3_logic.py +++ b/decompiler/structures/logic/z3_logic.py @@ -2,13 +2,15 @@ import functools from itertools import product -from typing import Dict, Generic, Iterable, Iterator, List, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, Iterable, Iterator, List, Sequence, TypeVar from decompiler.structures.logic.logic_interface import ConditionInterface, PseudoLogicInterface from decompiler.structures.logic.z3_implementations import Z3Implementation from decompiler.structures.pseudo import Condition from z3 import And, Bool, BoolRef, BoolVal, Context, Not, Or, Solver, is_and, is_false, is_not, is_or, is_true, substitute +if TYPE_CHECKING: + from decompiler.structures.ast.condition_symbol import ConditionHandler LOGICCLASS = TypeVar("LOGICCLASS", bound="Z3LogicCondition") PseudoLOGICCLASS = TypeVar("PseudoLOGICCLASS", bound="PseudoZ3LogicCondition") @@ -19,7 +21,7 @@ class Z3LogicCondition(ConditionInterface, Generic[LOGICCLASS]): SIMPLIFICATION_THRESHOLD = 2000 COMPLEXITY_THRESHOLD = 100000 - def __init__(self, condition: BoolRef): + def __init__(self, condition: BoolRef, tmp: bool = False): self._condition: BoolRef = condition self.z3 = Z3Implementation(True, self.SIMPLIFICATION_THRESHOLD, self.COMPLEXITY_THRESHOLD) @@ -52,12 +54,12 @@ def initialize_false(cls, context: Context) -> LOGICCLASS: return cls(BoolVal(False, ctx=context)) @classmethod - def disjunction_of(cls, clauses: Iterable[LOGICCLASS]) -> LOGICCLASS: + def disjunction_of(cls, clauses: Sequence[LOGICCLASS]) -> LOGICCLASS: """Creates a disjunction for the list of given clauses.""" return cls(functools.reduce(Or, [clause._condition for clause in clauses])) @classmethod - def conjunction_of(cls, clauses: List[LOGICCLASS]) -> LOGICCLASS: + def conjunction_of(cls, clauses: Sequence[LOGICCLASS]) -> LOGICCLASS: """Creates a conjunction for the list of given clauses.""" return cls(functools.reduce(And, [clause._condition for clause in clauses])) @@ -189,7 +191,7 @@ def substitute_by_true(self, condition: LOGICCLASS) -> LOGICCLASS: Example: substituting in the expression (a∨b)∧c the condition (a∨b) by true results in the condition c, and substituting the condition c by true in the condition (a∨b) """ - if self.is_equivalent_to(condition): + if condition.does_imply(self): self._condition = BoolVal(True, ctx=condition.context) return self self.to_cnf() @@ -212,7 +214,7 @@ def substitute_by_true(self, condition: LOGICCLASS) -> LOGICCLASS: self._condition = expression return self - def remove_redundancy(self, condition_map: Dict[LOGICCLASS, PseudoZ3LogicCondition]) -> LOGICCLASS: + def remove_redundancy(self, condition_handler: ConditionHandler) -> LOGICCLASS: """ More advanced simplification of conditions. @@ -220,6 +222,9 @@ def remove_redundancy(self, condition_map: Dict[LOGICCLASS, PseudoZ3LogicConditi - This helps, for example for finding switch cases, because it simplifies the condition 'x1 & x2' if 'x1 = var < 10' and 'x2 = var == 5' to the condition 'x2'. """ + if self.is_literal or self.is_true or self.is_false: + return self + condition_map = condition_handler.get_z3_condition_map() condition: BoolRef = self._condition replacement_to_z3 = list() replacement_to_symbol = list() @@ -255,7 +260,7 @@ def rich_string_representation(self, condition_map: Dict[LOGICCLASS, Condition]) class PseudoZ3LogicCondition(PseudoLogicInterface, Z3LogicCondition, Generic[LOGICCLASS, PseudoLOGICCLASS]): - def __init__(self, condition: BoolRef): + def __init__(self, condition: BoolRef, tmp: bool = False): super().__init__(condition) self.z3 = Z3Implementation(False, self.SIMPLIFICATION_THRESHOLD, self.COMPLEXITY_THRESHOLD) diff --git a/tests/structures/logic/test_custom_logic.py b/tests/structures/logic/test_custom_logic.py new file mode 100644 index 000000000..342eb417c --- /dev/null +++ b/tests/structures/logic/test_custom_logic.py @@ -0,0 +1,937 @@ +from typing import List, Tuple + +import pytest +from decompiler.structures.ast.condition_symbol import ConditionHandler, ConditionSymbol +from decompiler.structures.logic.custom_logic import CustomLogicCondition, PseudoCustomLogicCondition +from decompiler.structures.pseudo import BinaryOperation, Condition, Constant, Integer, OperationType, Variable +from simplifier.world.nodes import TmpVariable, WorldObject +from simplifier.world.world import World + + +class MockConditionHandler(ConditionHandler): + def add_condition(self, condition: Condition) -> ConditionSymbol: + """Adds a condition to the condition map.""" + symbol = self._get_next_symbol() + z3_condition = PseudoCustomLogicCondition.initialize_from_condition(condition, self._logic_context) + condition_symbol = ConditionSymbol(condition, symbol, z3_condition) + self._condition_map[symbol] = condition_symbol + return condition_symbol + + def _get_next_symbol(self) -> CustomLogicCondition: + """Get the next unused symbol name.""" + self._symbol_counter += 1 + return CustomLogicCondition.initialize_symbol(f"x{self._symbol_counter}", self._logic_context) + + +def b_x(i: int, world: World) -> WorldObject: + return world.variable(f"x{i}", 1) + + +def custom_x(i: int, world: World) -> CustomLogicCondition: + return CustomLogicCondition.initialize_symbol(f"x{i}", world) + + +def true_value(world: World) -> CustomLogicCondition: + return CustomLogicCondition.initialize_true(world) + + +def false_value(world: World) -> CustomLogicCondition: + return CustomLogicCondition.initialize_false(world) + + +def custom_variable(world: World, name: str = "a + 0x5,['eax#3']", size: int = 32) -> WorldObject: + return world.variable(name, size) + + +def custom_constant(world: World, const: int, size=32) -> WorldObject: + return world.constant(const, size) + + +def lower(variable: WorldObject, const: int) -> PseudoCustomLogicCondition: + custom_condition = variable.world.signed_lt(variable, custom_constant(variable.world, const)) + return PseudoCustomLogicCondition(custom_condition) + + +def lower_eq(variable: WorldObject, const: int) -> PseudoCustomLogicCondition: + custom_condition = variable.world.signed_lq(variable, custom_constant(variable.world, const)) + return PseudoCustomLogicCondition(custom_condition) + + +def equal(variable: WorldObject, const: int) -> PseudoCustomLogicCondition: + custom_condition = variable.world.bool_equal(variable, custom_constant(variable.world, const)) + return PseudoCustomLogicCondition(custom_condition) + + +def u_lower_eq(variable: WorldObject, const: int) -> PseudoCustomLogicCondition: + custom_condition = variable.world.unsigned_le(variable, custom_constant(variable.world, const)) + return PseudoCustomLogicCondition(custom_condition) + + +def u_greater(variable: WorldObject, const: int) -> PseudoCustomLogicCondition: + custom_condition = variable.world.unsigned_gt(variable, custom_constant(variable.world, const)) + return PseudoCustomLogicCondition(custom_condition) + + +constant_4 = Constant(4, Integer.int32_t()) +constant_5 = Constant(5, Integer.int32_t()) +constant_10 = Constant(10, Integer.int32_t()) +constant_20 = Constant(20, Integer.int32_t()) + +var_a = Variable( + "a", Integer.int32_t(), ssa_label=None, is_aliased=False, ssa_name=Variable("eax", Integer.int32_t(), ssa_label=3, is_aliased=False) +) +var_b = Variable( + "b", Integer.int32_t(), ssa_label=None, is_aliased=False, ssa_name=Variable("edx", Integer.int32_t(), ssa_label=5, is_aliased=False) +) + + +def _get_is_instance_test_case( + world: World, true_val=False, false_val=False, symbol=False, and_f=False, or_f=False, neg_symbol=False +) -> List[Tuple[CustomLogicCondition, bool]]: + return [ + (true_value(world), true_val), + (false_value(world), false_val), + (custom_x(1, world), symbol), + (custom_x(1, world) | custom_x(2, world), or_f), + (custom_x(1, world) & custom_x(2, world), and_f), + (~custom_x(1, world), neg_symbol), + ] + + +def _get_operation_instances(world: World) -> List[Tuple[WorldObject, WorldObject]]: + return [ + (b_x(1, world), b_x(2, world)), + (world.bitwise_and(b_x(1, world), b_x(2, world)), b_x(3, world)), + (b_x(1, world), world.bitwise_or(b_x(2, world), world.bitwise_negate(b_x(3, world)))), + ] + + +def _get_normal_forms(form): + init_world = World() + terms = [ + ~custom_x(1, init_world), + (~custom_x(1, init_world) | custom_x(2, init_world)) & (custom_x(3, init_world) | ~custom_x(1, init_world)), + (~custom_x(1, init_world) | custom_x(2, init_world)) + & (custom_x(3, init_world) | ~custom_x(1, init_world)) + & (custom_x(4, init_world) | (custom_x(2, init_world) & custom_x(3, init_world))), + (custom_x(2, init_world) & ~custom_x(1, init_world)) | (custom_x(3, init_world) & ~custom_x(1, init_world)), + custom_x(1, init_world) + | (custom_x(2, init_world) & ~(custom_x(1, init_world))) + | (custom_x(3, init_world) & ~(custom_x(1, init_world) | custom_x(2, init_world))) + | (custom_x(5, init_world) & custom_x(4, init_world) & ~custom_x(1, init_world)), + ((custom_x(2, init_world) | custom_x(4, init_world)) & ~custom_x(1, init_world)) + | ((custom_x(3, init_world) | custom_x(4, init_world)) & (custom_x(5, init_world) | ~custom_x(1, init_world))), + ] + if form == "cnf": + cnf_world = World() + result = [ + ~custom_x(1, cnf_world), + (custom_x(2, cnf_world) | ~custom_x(1, cnf_world)) & (custom_x(3, cnf_world) | ~custom_x(1, cnf_world)), + CustomLogicCondition.conjunction_of( + [ + (custom_x(2, cnf_world) | ~custom_x(1, cnf_world)), + (custom_x(3, cnf_world) | ~custom_x(1, cnf_world)), + (custom_x(2, cnf_world) | custom_x(4, cnf_world)), + (custom_x(3, cnf_world) | custom_x(4, cnf_world)), + ] + ), + (custom_x(2, cnf_world) | custom_x(3, cnf_world)) & ~custom_x(1, cnf_world), + CustomLogicCondition.disjunction_of( + [custom_x(1, cnf_world), custom_x(2, cnf_world), custom_x(3, cnf_world), custom_x(5, cnf_world)] + ) + & CustomLogicCondition.disjunction_of( + [custom_x(1, cnf_world), custom_x(4, cnf_world), custom_x(2, cnf_world), custom_x(3, cnf_world)] + ), + CustomLogicCondition.conjunction_of( + [ + CustomLogicCondition.disjunction_of([custom_x(2, cnf_world), custom_x(3, cnf_world), custom_x(4, cnf_world)]), + CustomLogicCondition.disjunction_of([~custom_x(1, cnf_world), custom_x(3, cnf_world), custom_x(4, cnf_world)]), + ~custom_x(1, cnf_world) | custom_x(5, cnf_world), + ] + ), + ] + elif form == "dnf": + dnf_world = World() + result = [ + ~custom_x(1, dnf_world), + ~custom_x(1, dnf_world) | (custom_x(3, dnf_world) & custom_x(2, dnf_world)), + (custom_x(3, dnf_world) & custom_x(2, dnf_world)) | (custom_x(4, dnf_world) & ~custom_x(1, dnf_world)), + (custom_x(2, dnf_world) & ~custom_x(1, dnf_world)) | (custom_x(3, dnf_world) & ~custom_x(1, dnf_world)), + CustomLogicCondition.disjunction_of( + [custom_x(1, dnf_world), custom_x(2, dnf_world), custom_x(3, dnf_world), (custom_x(5, dnf_world) & custom_x(4, dnf_world))] + ), + CustomLogicCondition.disjunction_of( + [ + custom_x(2, dnf_world) & ~custom_x(1, dnf_world), + custom_x(4, dnf_world) & ~custom_x(1, dnf_world), + custom_x(3, dnf_world) & ~custom_x(1, dnf_world), + custom_x(3, dnf_world) & custom_x(5, dnf_world), + custom_x(4, dnf_world) & custom_x(5, dnf_world), + ] + ), + ] + else: + raise ValueError(f"wrong input") + return [(term, normal_form) for term, normal_form in zip(terms, result)] + + +class TestCustomLogicCondition: + """Test the z3-logic condition.""" + + # Part implemented in the ConditionInterface + @pytest.mark.parametrize( + "world, term, length", + [ + (world := World(), CustomLogicCondition.initialize_true(world), 0), + (world := World(), CustomLogicCondition.initialize_false(world), 0), + (world := World(), custom_x(1, world), 1), + (world := World(), ~custom_x(1, world), 1), + (world := World(), custom_x(1, world) | custom_x(2, world), 2), + (world := World(), custom_x(1, world) & custom_x(2, world), 2), + (world := World(), (custom_x(1, world) & custom_x(2, world)) | custom_x(3, world), 3), + (world := World(), (custom_x(1, world) & custom_x(2, world)) | (custom_x(1, world) & custom_x(3, world)), 4), + ], + ) + def test_len(self, world, term, length): + assert len(term) == length + + @pytest.mark.parametrize( + "term, result", + _get_is_instance_test_case(world := World(), symbol=True, neg_symbol=True) + [(~(custom_x(1, world) | custom_x(2, world)), False)], + ) + def test_is_literal(self, term: CustomLogicCondition, result: bool): + assert term.is_literal == result + + @pytest.mark.parametrize( + "world, term, result", + [ + (world := World(), custom_x(1, world), True), + (world := World(), ~custom_x(1, world), True), + (world := World(), custom_x(1, world) | custom_x(2, world), True), + (world := World(), ~custom_x(1, world) | custom_x(2, world), True), + (world := World(), (~custom_x(1, world) | custom_x(2, world) | custom_x(3, world)).simplify(), True), + (world := World(), custom_x(1, world) & custom_x(2, world), False), + (world := World(), (custom_x(1, world) | custom_x(2, world)) & custom_x(3, world), False), + (world := World(), (custom_x(1, world) & custom_x(2, world)) | custom_x(3, world), False), + ], + ) + def test_is_disjunction_of_literals(self, world, term, result): + assert term.is_disjunction_of_literals == result + + @pytest.mark.parametrize( + "world, term, result", + [ + (world := World(), custom_x(1, world), True), + (world := World(), ~custom_x(1, world), True), + (world := World(), custom_x(1, world) | custom_x(2, world), True), + (world := World(), ~custom_x(1, world) | custom_x(2, world), True), + (world := World(), (~custom_x(1, world) | custom_x(2, world) | custom_x(3, world)).simplify(), True), + (world := World(), custom_x(1, world) & custom_x(2, world), True), + (world := World(), (custom_x(1, world) | custom_x(2, world)) & custom_x(3, world), True), + (world := World(), (custom_x(1, world) | ~custom_x(2, world)) & ~custom_x(3, world), True), + (world := World(), (custom_x(1, world) & custom_x(2, world)) | custom_x(3, world), False), + (world := World(), ((custom_x(1, world) & custom_x(2, world)) | custom_x(3, world)) & custom_x(4, world), False), + ], + ) + def test_is_cnf_form(self, world, term, result): + assert term.is_cnf_form == result + + @pytest.mark.parametrize( + "world, term1, term2, result", + [ + ( + world := World(), + CustomLogicCondition.disjunction_of( + ( + (custom_x(1, world) & ~custom_x(1, world)), + ~custom_x(2, world), + (custom_x(3, world) & (custom_x(4, world) | ~custom_x(4, world))), + ~(custom_x(5, world) & custom_x(2, world) & ~custom_x(1, world)), + (~(custom_x(5, world) & ~custom_x(5, world)) & custom_x(1, world)), + ~(custom_x(3, world) | ~custom_x(3, world)), + ) + ), + custom_x(1, world) | ~custom_x(5, world) | ~custom_x(2, world) | custom_x(3, world), + True, + ), + ( + world := World(), + custom_x(1, world) + | (custom_x(2, world) & ~custom_x(1, world)) + | (custom_x(3, world) & ~(custom_x(1, world) | custom_x(2, world))) + | (custom_x(5, world) & custom_x(4, world) & ~custom_x(1, world)), + custom_x(1, world) | custom_x(2, world) | custom_x(3, world) | (custom_x(5, world) & custom_x(4, world)), + True, + ), + ( + world := World(), + custom_x(1, world) + | (custom_x(2, world) & ~custom_x(1, world)) + | (custom_x(3, world) & ~(custom_x(1, world) | custom_x(2, world))) + | (custom_x(5, world) & custom_x(4, world) & ~custom_x(1, world)), + (custom_x(1, world) | custom_x(2, world) | custom_x(3, world) | custom_x(5, world)) + & (custom_x(1, world) | custom_x(4, world) | custom_x(2, world) | custom_x(3, world)), + True, + ), + ( + world := World(), + custom_x(1, world) & custom_x(2, world), + custom_x(1, world) & custom_x(2, world) & custom_x(3, world), + False, + ), + ( + world := World(), + custom_x(1, world) & custom_x(2, world), + (custom_x(1, world) & custom_x(2, world)) | custom_x(1, world), + False, + ), + ], + ) + def test_is_equivalent_to(self, world, term1, term2, result): + assert term1.is_equivalent_to(term2) == result + + @pytest.mark.parametrize( + "world, term1, term2, result", + [ + (world := World(), custom_x(1, world), custom_x(1, world) | custom_x(2, world), True), + (world := World(), custom_x(1, world), custom_x(1, world) & custom_x(2, world), False), + ( + world := World(), + (custom_x(1, world) | custom_x(2, world)) & (~custom_x(1, world) | custom_x(3, world)), + (custom_x(1, world) & custom_x(3, world)) + | (~custom_x(1, world) & custom_x(2, world)) + | (custom_x(1, world) & custom_x(4, world)), + True, + ), + ( + world := World(), + (custom_x(1, world) | custom_x(2, world)) & (~custom_x(1, world) | custom_x(3, world)), + (custom_x(1, world) & custom_x(3, world)) + | (custom_x(1, world) & custom_x(2, world)) + | (custom_x(1, world) & custom_x(4, world)), + False, + ), + ], + ) + def test_does_imply(self, world, term1, term2, result): + assert term1.does_imply(term2) == result + + @pytest.mark.parametrize( + "world, term1, term2, result", + [ + (world := World(), true_value(world), false_value(world), False), + (world := World(), false_value(world), true_value(world), False), + (world := World(), custom_x(1, world) & ~custom_x(1, world), true_value(world), False), + (world := World(), custom_x(1, world) | ~custom_x(1, world), false_value(world), False), + (world := World(), custom_x(1, world), ~custom_x(1, world), True), + (world := World(), custom_x(1, world) | custom_x(2, world), ~custom_x(1, world) & ~custom_x(2, world), True), + (world := World(), custom_x(1, world) & custom_x(2, world), ~(custom_x(1, world) & custom_x(2, world)), True), + ( + world := World(), + custom_x(1, world) | custom_x(2, world), + (~custom_x(1, world) & ~custom_x(2, world)) | custom_x(1, world), + False, + ), + ( + world := World(), + custom_x(1, world) & custom_x(2, world), + (~custom_x(1, world) | ~custom_x(2, world)) & custom_x(1, world), + False, + ), + ], + ) + def test_is_complementary_to(self, world, term1, term2, result): + assert term1.is_complementary_to(term2) == result + + @pytest.mark.parametrize( + "world, term", + [ + (world := World(), world.constant(1, 1)), + (world := World(), world.constant(0, 1)), + (world := World(), b_x(1, world)), + (world := World(), world.bitwise_negate(b_x(1, world))), + (world := World(), world.bitwise_and(b_x(1, world), b_x(2, world))), + (world := World(), world.bitwise_or(b_x(1, world), b_x(2, world))), + (world := World(), world.bitwise_and(world.bitwise_or(b_x(1, world), b_x(2, world)), b_x(3, world))), + ], + ) + def test_init(self, world, term): + cond = CustomLogicCondition(term) + assert cond._condition == term + + def test_initialize_symbol(self): + world = World() + cond = CustomLogicCondition.initialize_symbol("x1", world) + assert cond._condition == World().variable("x1", 1) + + def test_initialize_true(self): + world = World() + cond = CustomLogicCondition.initialize_true(world) + assert cond._condition == World().constant(1, 1) + + def test_initialize_false(self): + world = World() + cond = CustomLogicCondition.initialize_false(world) + assert cond._condition == World().constant(0, 1) + + @pytest.mark.parametrize("term1, term2", _get_operation_instances(world := World())) + def test_and(self, term1, term2): + cond = CustomLogicCondition(term1) & CustomLogicCondition(term2) + assert World.compare(cond._condition, World().bitwise_and(term1, term2)) + + @pytest.mark.parametrize("term1, term2", _get_operation_instances(world := World())) + def test_or(self, term1, term2): + cond = CustomLogicCondition(term1) | CustomLogicCondition(term2) + assert World.compare(cond._condition, World().bitwise_or(term1, term2)) + + @pytest.mark.parametrize("term1, term2", _get_operation_instances(world := World())) + def test_negate(self, term1, term2): + cond = ~CustomLogicCondition(term1) + assert World.compare(cond._condition, World().bitwise_negate(term1)) + + @pytest.mark.parametrize( + "world, term, string", + [ + (world := World(), world.constant(1, 1), "true"), + (world := World(), world.constant(0, 1), "false"), + ( + world := World(), + world.bitwise_or( + world.bitwise_and(b_x(1, world)), + world.bitwise_negate(b_x(2, world)), + world.bitwise_and(b_x(3, world), world.bitwise_or(world.bitwise_negate(b_x(4, world)))), + world.bitwise_negate(world.bitwise_and(b_x(5, world), b_x(2, world), world.bitwise_negate(b_x(1, world)))), + world.bitwise_and( + world.bitwise_negate(world.bitwise_and(b_x(5, world), world.bitwise_negate(b_x(5, world)))), + b_x(1, world), + ), + world.bitwise_negate(world.bitwise_or(b_x(3, world), world.bitwise_negate(b_x(3, world)))), + ), + "(x1 | ~x2 | (x3 & ~x4) | ~(x5 & x2 & ~x1) | (~(x5 & ~x5) & x1) | ~(x3 | ~x3))", + ), + ( + world := World(), + world.bitwise_or( + world.bitwise_and(b_x(1, world), world.bitwise_negate(b_x(1, world))), + world.bitwise_negate(b_x(2, world)), + world.bitwise_and(b_x(3, world), world.bitwise_or(b_x(4, world), world.bitwise_negate(b_x(4, world)))), + world.bitwise_negate(world.bitwise_and(b_x(5, world), b_x(2, world), world.bitwise_negate(b_x(1, world)))), + world.bitwise_and( + world.bitwise_negate(world.bitwise_and(b_x(5, world), world.bitwise_negate(b_x(5, world)))), + b_x(1, world), + ), + world.bitwise_negate(world.bitwise_or(b_x(3, world), world.bitwise_negate(b_x(3, world)))), + ), + "((x1 & ~x1) | ~x2 | (x3 & (x4 | ~x4)) | ~(x5 & x2 & ~x1) | (~(x5 & ~x5) & x1) | ~(x3 | ~x3))", + ), + ], + ) + def test_string(self, world, term, string): + cond = CustomLogicCondition(term) + assert str(cond) == string + + @pytest.mark.parametrize("term, result", _get_is_instance_test_case(world=World(), true_val=True)) + def test_is_true(self, term, result): + assert term.is_true == result + + @pytest.mark.parametrize("term, result", _get_is_instance_test_case(world=World(), false_val=True)) + def test_is_false(self, term, result): + assert term.is_false == result + + @pytest.mark.parametrize("term, result", _get_is_instance_test_case(world=World(), or_f=True)) + def test_is_disjunction(self, term, result): + assert term.is_disjunction == result + + @pytest.mark.parametrize("term, result", _get_is_instance_test_case(world=World(), and_f=True)) + def test_is_conjunction(self, term, result): + assert term.is_conjunction == result + + @pytest.mark.parametrize("term, result", _get_is_instance_test_case(world=World(), neg_symbol=True)) + def test_is_negation(self, term, result): + assert term.is_negation == result + + @pytest.mark.parametrize( + "world, term, operands", + [ + (world := World(), true_value(world), []), + (world := World(), false_value(world), []), + (world := World(), custom_x(1, world), []), + (world := World(), custom_x(1, world) | custom_x(2, world), [custom_x(1, world), custom_x(2, world)]), + (world := World(), custom_x(1, world) & custom_x(2, world), [custom_x(1, world), custom_x(2, world)]), + (world := World(), ~custom_x(1, world), [custom_x(1, world)]), + ( + world := World(), + (custom_x(1, world) | custom_x(2, world)) & custom_x(3, world), + [custom_x(1, world) | custom_x(2, world), custom_x(3, world)], + ), + ], + ) + def test_operands(self, world, term, operands): + assert [str(op) for op in term.operands] == [str(op) for op in operands] + + @pytest.mark.parametrize( + "world, term, result", + [ + (world := World(), world.constant(1, 1), False), + (world := World(), world.constant(0, 1), False), + (world := World(), world.bitwise_negate(b_x(1, world)), False), + (world := World(), world.bitwise_and(b_x(1, world), b_x(2, world)), False), + (world := World(), world.bitwise_or(world.bitwise_negate(b_x(1, world)), b_x(1, world)), False), + (world := World(), b_x(1, world), True), + ], + ) + def test_is_symbol(self, world, term, result): + """Check whether the object is a symbol.""" + cond = CustomLogicCondition(term) + return cond.is_symbol == result + + @pytest.mark.parametrize( + "term1, term2, result", + [ + (b_x(1, World()), b_x(2, World()), False), + (b_x(1, World()), (world := World()).bitwise_negate(b_x(1, world)), False), + (b_x(1, World()), (world := World()).bitwise_and(b_x(1, world)), False), + (b_x(1, World()), (world := World()).bitwise_or(b_x(1, world)), False), + ( + (world := World()).bitwise_and(b_x(1, world), b_x(2, world), b_x(2, world)), + (world := World()).bitwise_and(b_x(1, world), b_x(1, world), b_x(2, world)), + False, + ), + ( + (world := World()).bitwise_and(b_x(1, world), world.bitwise_and(b_x(2, world), b_x(3, world))), + (world := World()).bitwise_and(world.bitwise_and(b_x(1, world), b_x(2, world)), b_x(3, world)), + True, + ), + ( + (world := World()).bitwise_and(b_x(1, world), b_x(2, world), b_x(2, world)), + (world := World()).bitwise_and(b_x(1, world), b_x(2, world)), + False, + ), + ( + (world := World()).bitwise_and(b_x(1, world), b_x(2, world)), + (world := World()).bitwise_and(b_x(1, world), b_x(1, world), b_x(2, world)), + False, + ), + ( + (world := World()).bitwise_and(b_x(1, world), b_x(2, world)), + (world := World()).bitwise_and(b_x(2, world), b_x(1, world)), + True, + ), + ( + (world := World()).bitwise_and(b_x(1, world), world.bitwise_or(b_x(2, world), b_x(3, world))), + (world := World()).bitwise_and(world.bitwise_or(b_x(3, world), b_x(2, world)), b_x(1, world)), + True, + ), + ], + ) + def test_is_equal_to(self, term1, term2, result): + cond1 = CustomLogicCondition(term1) + cond2 = CustomLogicCondition(term2) + return cond1.is_equal_to(cond2) == result + + @pytest.mark.parametrize( + "term1, term2, result", + [ + (custom_variable(World(), "x1", 1), custom_variable(World(), "x2", 1), False), + (custom_variable(World(), "x1", 1), custom_variable(World(), "x1", 1), True), + (custom_variable(World(), "x1", 1), (world := World()).bitwise_negate(custom_variable(world, "x1", 1)), False), + (custom_constant(World(), 1, 1), custom_constant(World(), 1, 1), True), + (custom_constant(World(), 0, 1), custom_constant(World(), 0, 1), True), + (custom_constant(World(), 0, 1), custom_constant(World(), 1, 1), False), + ( + (world := World()).bitwise_and( + custom_variable(world, "x1"), custom_variable(world, "x2", 1), custom_variable(world, "x3", 1) + ), + (world := World()).bitwise_and( + custom_variable(world, "x1", 1), custom_variable(world, "x2", 1), custom_variable(world, "x3", 1) + ), + True, + ), + ( + (world := World()).bitwise_and(custom_variable(world, "x1", 1), custom_variable(world, "x2", 2)), + (world := World()).bitwise_and(custom_variable(world, "x2", 1), custom_variable(world, "x1", 1)), + True, + ), + ( + (world := World()).bitwise_and( + custom_variable(world, "x1", 1), world.bitwise_or(custom_variable(world, "x2", 1), custom_variable(world, "x3", 1)) + ), + (world := World()).bitwise_and( + world.bitwise_or(custom_variable(world, "x3", 1), custom_variable(world, "x2", 1)), custom_variable(world, "x1", 1) + ), + True, + ), + ], + ) + def test_is_equal_to_different_context(self, term1, term2, result): + cond1 = CustomLogicCondition(term1) + cond2 = CustomLogicCondition(term2) + assert cond1.is_equal_to(cond2) == result and cond1.context != cond2.context + + @pytest.mark.parametrize("term, cnf_term", _get_normal_forms("cnf")) + def test_to_cnf(self, term, cnf_term): + """Bring condition tag into cnf-form.""" + assert term.to_cnf().is_equal_to(cnf_term) + + @pytest.mark.parametrize("term, dnf_term", _get_normal_forms("dnf")) + def test_to_dnf(self, term, dnf_term): + """Bring condition tag into cnf-form.""" + input_term = str(term) + assert term.to_dnf().is_equal_to(dnf_term) and input_term == str(term) + + @pytest.mark.parametrize( + "term, simplified", + [ + ( + custom_x(1, world := World()) + & ~custom_x(2, world) + & (custom_x(3, world) | ~(custom_x(4, world) & custom_x(2, world))) + & ~(custom_x(5, world) & custom_x(2, world) & ~custom_x(1, world)), + custom_x(1, world := World()) & ~custom_x(2, world), + ), + ( + custom_x(1, world := World()) + | (custom_x(2, world) & ~custom_x(1, world)) + | (custom_x(3, world) & ~(custom_x(1, world) | custom_x(2, world))) + | (custom_x(5, world) & custom_x(4, world) & ~custom_x(1, world)), + CustomLogicCondition.disjunction_of( + [custom_x(1, world := World()), custom_x(2, world), custom_x(3, world), (custom_x(5, world) & custom_x(4, world))] + ), + ), + ( + (custom_x(1, world := World()) & ~custom_x(1, world)) + | ~custom_x(2, world) + | (custom_x(3, world) & (custom_x(4, world) | ~custom_x(4, world))) + | ~(custom_x(5, world) & custom_x(2, world) & ~custom_x(1, world)) + | (~(custom_x(5, world) & ~custom_x(5, world)) & custom_x(1, world)) + | ~(custom_x(3, world) | ~custom_x(3, world)), + CustomLogicCondition.disjunction_of( + [custom_x(1, world := World()), ~custom_x(5, world), ~custom_x(2, world), custom_x(3, world)] + ), + ), + ], + ) + def test_simplify(self, term, simplified): + cond = term.simplify() + assert cond.is_equal_to(simplified) + + def test_simplify_tmp_variable(self): + world = World() + cond = world.bitwise_and(world.variable("x1", 1), world.bitwise_negate(world.variable("x2", 1))) + log_cond = CustomLogicCondition(cond, tmp=True) + log_cond.simplify() + assert log_cond + + @pytest.mark.parametrize( + "term, result", + [ + (true_value(World()), []), + (false_value(World()), []), + (custom_x(1, world := World()), [custom_x(1, world)]), + (~custom_x(1, world := World()), [custom_x(1, world)]), + ( + custom_x(1, world := World()) + & ~custom_x(2, world) + & (custom_x(3, world) | ~(custom_x(4, world) & custom_x(2, world))) + & ~(custom_x(5, world) & custom_x(2, world) & ~custom_x(1, world)), + [custom_x(1, world), custom_x(2, world), custom_x(3, world), custom_x(4, world), custom_x(5, world)], + ), + ], + ) + def test_get_symbols(self, term, result): + assert [str(symbol) for symbol in term.get_symbols()] == [str(symbol) for symbol in result] + + @pytest.mark.parametrize( + "term, result", + [ + (true_value(World()), []), + (false_value(World()), []), + (custom_x(1, world := World()), [custom_x(1, world)]), + (~custom_x(1, world := World()), [~custom_x(1, world)]), + (custom_x(1, world := World()) | custom_x(2, world), [custom_x(1, world), custom_x(2, world)]), + (~custom_x(1, world := World()) | custom_x(2, world), [~custom_x(1, world), custom_x(2, world)]), + (custom_x(1, world := World()) & custom_x(2, world), [custom_x(1, world), custom_x(2, world)]), + ( + custom_x(1, world := World()) + & ~custom_x(2, world) + & (custom_x(3, world) | ~(custom_x(4, world) & custom_x(2, world))) + & ~(custom_x(5, world) & custom_x(2, world) & ~custom_x(1, world)), + [ + custom_x(1, world), + ~custom_x(2, world), + custom_x(3, world), + custom_x(4, world), + custom_x(2, world), + custom_x(5, world), + custom_x(2, world), + ~custom_x(1, world), + ], + ), + ], + ) + def test_get_literals(self, term, result): + assert [str(literal) for literal in term.get_literals()] == [str(literal) for literal in result] + + def test_get_literals_error(self): + init_world = World() + term = CustomLogicCondition( + init_world.bitwise_or( + init_world.bitwise_and(b_x(1, init_world), init_world.signed_lt(init_world.variable("a", 32), init_world.constant(5, 32))), + b_x(3, init_world), + ) + ) + with pytest.raises(AssertionError): + list(term.get_literals()) + + @pytest.mark.parametrize( + "term, condition, result", + [ + (true_value(world := World()), custom_x(2, world), true_value(World())), + (false_value(world := World()), custom_x(2, world), false_value(World())), + (custom_x(2, world := World()), custom_x(2, world), true_value(World())), + (custom_x(2, world := World()), custom_x(3, world), custom_x(2, World())), + (custom_x(1, world := World()) | custom_x(2, world), custom_x(2, world), true_value(World())), + ], + ) + def test_substitute_by_true_basics(self, term, condition, result): + assert term.substitute_by_true(condition).is_equal_to(result) + + @pytest.mark.parametrize( + "condition, result", + [ + ( + (custom_x(1, world := World()) | custom_x(2, world) | custom_x(3, world)) + & (custom_x(4, world) | custom_x(5, world)) + & custom_x(6, world) + & custom_x(7, world), + true_value(World()), + ), + ( + custom_x(6, World()), + (custom_x(1, world := World()) | custom_x(2, world) | custom_x(3, world)) + & (custom_x(4, world) | custom_x(5, world)) + & custom_x(7, world), + ), + ( + custom_x(4, world := World()) | custom_x(5, world), + (custom_x(1, world := World()) | custom_x(2, world) | custom_x(3, world)) & custom_x(6, world) & custom_x(7, world), + ), + ( + custom_x(6, world := World()) & (custom_x(4, world) | custom_x(5, world)), + (custom_x(1, world := World()) | custom_x(2, world) | custom_x(3, world)) & custom_x(7, world), + ), + ( + custom_x(6, world := World()) & custom_x(7, world), + (custom_x(1, world := World()) | custom_x(2, world) | custom_x(3, world)) & (custom_x(4, world) | custom_x(5, world)), + ), + ( + custom_x(1, world := World()) | custom_x(2, world), + (custom_x(1, world := World()) | custom_x(2, world) | custom_x(3, world)) + & (custom_x(4, world) | custom_x(5, world)) + & custom_x(6, world) + & custom_x(7, world), + ), + ( + (custom_x(1, world := World()) | custom_x(2, world) | custom_x(3, world)) + & (custom_x(4, world) | custom_x(5, world)) + & custom_x(6, world) + & custom_x(7, world) + & custom_x(8, world), + true_value(World()), + ), + ], + ) + def test_substitute_by_true(self, condition, result): + world = condition.context + term = ( + (custom_x(1, world) | custom_x(2, world) | custom_x(3, world)) + & (custom_x(4, world) | custom_x(5, world)) + & custom_x(6, world) + & custom_x(7, world) + ) + term.substitute_by_true(condition) + term.simplify() + assert term.is_equal_to(result.simplify()) + + @pytest.mark.parametrize( + "term, conditions, result", + [ + ( + custom_x(1, world := World()) & custom_x(2, world), + [Condition(OperationType.equal, [var_a, constant_5]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], + custom_x(1, world), + ), + ( + custom_x(1, world) & custom_x(2, world) & ~custom_x(3, world), + [ + Condition(OperationType.equal, [var_a, constant_5]), + Condition(OperationType.less_or_equal_us, [var_a, constant_10]), + Condition(OperationType.equal, [var_b, constant_10]), + ], + custom_x(1, world) & ~custom_x(3, world), + ), + ( + custom_x(1, world) & custom_x(2, world), + [Condition(OperationType.less, [var_a, constant_20]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], + custom_x(2, world), + ), + ( + custom_x(1, world) & ~custom_x(2, world), + [Condition(OperationType.less, [var_a, constant_20]), Condition(OperationType.greater_us, [var_a, constant_10])], + ~custom_x(2, world), + ), + ], + ) + def test_remove_redundancy(self, term, conditions, result): + # TODO --> new symbols + condition_handler = MockConditionHandler() + condition_handler._logic_context = term.context + for cond in conditions: + condition_handler.add_condition(cond) + assert term.remove_redundancy(condition_handler).is_equal_to(result) + + def test_remove_redundancy_new_symbol_1(self): + world = World() + term = custom_x(1, world) & custom_x(2, world) + condition_handler = MockConditionHandler() + condition_handler._logic_context = world + for cond in [Condition(OperationType.less, [var_a, constant_5]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])]: + condition_handler.add_condition(cond) + term.remove_redundancy(condition_handler) + assert term.is_symbol + assert condition_handler.get_condition_of(term) == Condition(OperationType.less_or_equal_us, [var_a, constant_4]) + assert condition_handler.get_z3_condition_of(term) == u_lower_eq(custom_variable(world, "a,eax#3"), 4) + + def test_remove_redundancy_new_symbol_2(self): + world = World() + term = custom_x(1, world) & custom_x(2, world) + condition_handler = MockConditionHandler() + condition_handler._logic_context = world + expr = BinaryOperation(OperationType.plus, [var_a, constant_5]) + for cond in [Condition(OperationType.less, [expr, constant_5]), Condition(OperationType.less_or_equal_us, [expr, constant_10])]: + condition_handler.add_condition(cond) + term.remove_redundancy(condition_handler) + assert term.is_symbol + assert condition_handler.get_condition_of(term) == Condition(OperationType.less_or_equal_us, [expr, constant_4]) + assert condition_handler.get_z3_condition_of(term) == u_lower_eq(custom_variable(world), 4) + + @pytest.mark.parametrize( + "world, term, result", + [ + ( + world := World(), + CustomLogicCondition( + world.bitwise_or( + world.bitwise_and(b_x(1, world)), + world.bitwise_negate(b_x(2, world)), + world.bitwise_and(b_x(3, world), world.bitwise_or(world.bitwise_negate(b_x(4, world)))), + world.bitwise_negate(world.bitwise_and(b_x(5, world), b_x(2, world), world.bitwise_negate(b_x(1, world)))), + world.bitwise_and( + world.bitwise_negate(world.bitwise_and(b_x(5, world), world.bitwise_negate(b_x(5, world)))), + b_x(1, world), + ), + world.bitwise_negate(world.bitwise_or(b_x(3, world), world.bitwise_negate(b_x(3, world)))), + ) + ), + "(a < 0x1 | b == 0x2 | (c <= 0x3 & d <= 0x4) | !(e >= 0x5 & b != 0x2 & a >= 0x1) | (!(e >= 0x5 & e < 0x5) & a < 0x1) | " + "!(c <= 0x3 | c > 0x3))", + ), + ( + world := World(), + CustomLogicCondition( + world.bitwise_or( + world.bitwise_and(b_x(1, world), world.bitwise_negate(b_x(1, world))), + world.bitwise_negate(b_x(2, world)), + world.bitwise_and(b_x(3, world), world.bitwise_or(b_x(4, world), world.bitwise_negate(b_x(4, world)))), + world.bitwise_negate(world.bitwise_and(b_x(5, world), b_x(2, world), world.bitwise_negate(b_x(1, world)))), + world.bitwise_and( + world.bitwise_negate(world.bitwise_and(b_x(5, world), world.bitwise_negate(b_x(5, world)))), + b_x(1, world), + ), + world.bitwise_negate(world.bitwise_or(b_x(3, world), world.bitwise_negate(b_x(3, world)))), + ) + ), + "((a < 0x1 & a >= 0x1) | b == 0x2 | (c <= 0x3 & (d > 0x4 | d <= 0x4)) | !(e >= 0x5 & b != 0x2 & a >= 0x1) | " + "(!(e >= 0x5 & e < 0x5) & a < 0x1) | !(c <= 0x3 | c > 0x3))", + ), + ], + ) + def test_rich_string_representation(self, world, term, result): + condition_map = { + custom_x(1, world): Condition(OperationType.less, [Variable("a"), Constant(1)]), + custom_x(2, world): Condition(OperationType.not_equal, [Variable("b"), Constant(2)]), + custom_x(3, world): Condition(OperationType.less_or_equal, [Variable("c"), Constant(3)]), + custom_x(4, world): Condition(OperationType.greater, [Variable("d"), Constant(4)]), + custom_x(5, world): Condition(OperationType.greater_or_equal, [Variable("e"), Constant(5)]), + } + assert term.rich_string_representation(condition_map) == result + + +class TestPseudoCustomLogicCondition: + @pytest.mark.parametrize( + "condition, result", + [ + (Condition(OperationType.equal, [var_a, constant_5]), "(a,eax#3 == 5)"), + ( + Condition(OperationType.less_or_equal, [BinaryOperation(OperationType.plus, [var_a, constant_5]), constant_5]), + "(a + 0x5,['eax#3'] s<= 5)", + ), + ( + Condition(OperationType.greater_or_equal_us, [BinaryOperation(OperationType.plus, [var_a, var_b]), constant_5]), + "(a + b,['eax#3', 'edx#5'] u>= 5)", + ), + ], + ) + def test_initialize_from_condition(self, condition, result): + world = World() + cond = PseudoCustomLogicCondition.initialize_from_condition(condition, world) + assert str(cond) == result and world == cond.context + + def test_initialize_from_formula(self): + pass + + @pytest.mark.parametrize( + "term, result", + [ + ( + PseudoCustomLogicCondition( + (world := World()).bitwise_negate(world.unsigned_le(custom_variable(world), custom_constant(world, 5))) + ), + PseudoCustomLogicCondition((world := World()).unsigned_gt(custom_variable(world), custom_constant(world, 5))), + ), + ( + PseudoCustomLogicCondition( + (world := World()).bitwise_and( + b_x(1, world), + world.bitwise_or(b_x(3, world), world.bitwise_negate(world.bitwise_and(b_x(4, world), b_x(2, world)))), + world.bitwise_negate(world.bitwise_and(b_x(5, world), b_x(2, world), world.bitwise_negate(b_x(1, world)))), + ) + ), + PseudoCustomLogicCondition( + (world := World()).bitwise_and( + b_x(1, world), + world.bitwise_or(b_x(3, world), world.bitwise_negate(b_x(4, world)), world.bitwise_negate(b_x(2, world))), + ) + ), + ), + ], + ) + def test_simplify(self, term, result): + assert term.simplify() == result + + @pytest.mark.parametrize( + "expression, result", + [ + (constant_5, "5"), + (var_a, "a,eax#3"), + (BinaryOperation(OperationType.plus, [var_a, constant_5]), "a + 0x5,['eax#3']"), + (BinaryOperation(OperationType.plus, [var_a, var_b]), "a + b,['eax#3', 'edx#5']"), + ], + ) + def test_convert_expression(self, expression, result): + world = World() + world_5 = World() + custom_expression = PseudoCustomLogicCondition._convert_expression(expression, 32, world) + custom_expression_5 = PseudoCustomLogicCondition._convert_expression(expression, 5, world_5) + assert str(custom_expression) == result and custom_expression.size == 32 + assert str(custom_expression_5) == result and custom_expression_5.size == 5 diff --git a/tests/structures/logic/test_logic_condition.py b/tests/structures/logic/test_logic_condition.py index 48e28ce27..51e094203 100644 --- a/tests/structures/logic/test_logic_condition.py +++ b/tests/structures/logic/test_logic_condition.py @@ -1,11 +1,12 @@ import pytest +from decompiler.structures.ast.condition_symbol import ConditionHandler, ConditionSymbol from decompiler.structures.logic.logic_condition import generate_logic_condition_class, generate_pseudo_logic_condition_class from decompiler.structures.logic.z3_logic import PseudoZ3LogicCondition, Z3LogicCondition from decompiler.structures.pseudo import BinaryOperation, Condition, Constant, Integer, OperationType, Variable from z3 import UGT, ULE, And, BitVec, BitVecVal, Bool, BoolVal, Not, Or LogicCondition = generate_logic_condition_class(Z3LogicCondition) -PseudoLogicCondition = generate_pseudo_logic_condition_class(PseudoZ3LogicCondition) +PseudoLogicCondition = generate_pseudo_logic_condition_class(PseudoZ3LogicCondition, LogicCondition) context = LogicCondition.generate_new_context() z3_symbol = [Bool(f"x{i}", ctx=context) for i in [0, 1, 2, 3, 4, 5, 6]] logic_x = [LogicCondition.initialize_symbol(f"x{i}", context) for i in [0, 1, 2, 3, 4, 5, 6, 7, 8]] @@ -18,6 +19,8 @@ var_ugt_10 = PseudoLogicCondition(UGT(z3_variable, BitVecVal(10, 32, ctx=context))).simplify() constant_5 = Constant(5, Integer.int32_t()) +constant_10 = Constant(10, Integer.int32_t()) +constant_20 = Constant(20, Integer.int32_t()) var_a = Variable( "a", Integer.int32_t(), ssa_label=None, is_aliased=False, ssa_name=Variable("eax", Integer.int32_t(), ssa_label=3, is_aliased=False) @@ -315,6 +318,7 @@ def test_get_literals(self, term, result): ), (logic_x[2].copy(), logic_x[2].copy(), LogicCondition.initialize_true(context)), (logic_x[2].copy(), logic_x[3].copy(), logic_x[2].copy()), + (logic_x[2].copy() | logic_x[3].copy(), logic_x[3].copy(), LogicCondition.initialize_true(context)), ], ) def test_substitute_by_true_basics(self, term, condition, result): @@ -359,10 +363,7 @@ def test_substitute_by_true_basics(self, term, condition, result): & logic_x[6].copy() & logic_x[7].copy() & logic_x[8].copy(), - (logic_x[1].copy() | logic_x[2].copy() | logic_x[3].copy()) - & (logic_x[4].copy() | logic_x[5].copy()) - & logic_x[6].copy() - & logic_x[7].copy(), + LogicCondition.initialize_true(context), ), ], ) @@ -377,20 +378,53 @@ def test_substitute_by_true(self, condition, result): assert term == result @pytest.mark.parametrize( - "term, condition_map, result", + "term, conditions, result", [ - (logic_x[1].copy() & logic_x[2].copy(), {logic_x[1].copy(): var_eq_5, logic_x[2].copy(): var_ule_10}, logic_x[1].copy()), ( logic_x[1].copy() & logic_x[2].copy(), - {logic_x[1].copy(): var_l_5, logic_x[2].copy(): var_ule_10}, + [Condition(OperationType.equal, [var_a, constant_5]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], + logic_x[1].copy(), + ), + ( + logic_x[1].copy() & logic_x[2].copy(), + [Condition(OperationType.less, [var_a, constant_5]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], logic_x[1].copy() & logic_x[2].copy(), ), - (logic_x[1].copy() & logic_x[2].copy(), {logic_x[1].copy(): var_l_20, logic_x[2].copy(): var_ule_10}, logic_x[2].copy()), - (logic_x[1].copy() & ~logic_x[2].copy(), {logic_x[1].copy(): var_l_20, logic_x[2].copy(): var_ugt_10}, ~logic_x[2].copy()), + ( + logic_x[1].copy() & logic_x[2].copy(), + [Condition(OperationType.less, [var_a, constant_20]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], + logic_x[2].copy(), + ), + ( + logic_x[1].copy() & ~logic_x[2].copy(), + [Condition(OperationType.less, [var_a, constant_20]), Condition(OperationType.greater_us, [var_a, constant_10])], + ~logic_x[2].copy(), + ), ], ) - def test_remove_redundancy(self, term, condition_map, result): - term.remove_redundancy(condition_map) + def test_remove_redundancy(self, term, conditions, result): + class MockConditionHandler(ConditionHandler): + LogicCondition = generate_logic_condition_class(Z3LogicCondition) + PseudoLogicCondition = generate_logic_condition_class(Z3LogicCondition) + + def add_condition(self, condition: Condition) -> ConditionSymbol: + """Adds a condition to the condition map.""" + symbol = self._get_next_symbol() + z3_condition = PseudoLogicCondition.initialize_from_condition(condition, self._logic_context) + condition_symbol = ConditionSymbol(condition, symbol, z3_condition) + self._condition_map[symbol] = condition_symbol + return condition_symbol + + def _get_next_symbol(self) -> Z3LogicCondition: + """Get the next unused symbol name.""" + self._symbol_counter += 1 + return LogicCondition.initialize_symbol(f"x{self._symbol_counter}", self._logic_context) + + condition_handler = MockConditionHandler() + condition_handler._logic_context = term.context + for cond in conditions: + condition_handler.add_condition(cond) + term.remove_redundancy(condition_handler) assert term == result @pytest.mark.parametrize( @@ -433,7 +467,6 @@ def test_remove_redundancy(self, term, condition_map, result): ], ) def test_simplify_to_shortest(self, term, bound, result): - assert term.z3.is_equal(term.simplify_to_shortest(bound)._condition, result) @pytest.mark.parametrize( diff --git a/tests/structures/logic/test_z3_logic.py b/tests/structures/logic/test_z3_logic.py index ed53e1606..4296199a0 100644 --- a/tests/structures/logic/test_z3_logic.py +++ b/tests/structures/logic/test_z3_logic.py @@ -1,6 +1,7 @@ from typing import List, Tuple import pytest as pytest +from decompiler.structures.ast.condition_symbol import ConditionHandler, ConditionSymbol from decompiler.structures.logic.z3_implementations import Z3Implementation from decompiler.structures.logic.z3_logic import PseudoZ3LogicCondition, Z3LogicCondition from decompiler.structures.pseudo import BinaryOperation, Condition, Constant, Integer, OperationType, Variable @@ -25,6 +26,8 @@ var_ugt_10 = PseudoZ3LogicCondition(UGT(z3_variable, const10)).simplify() constant_5 = Constant(5, Integer.int32_t()) +constant_10 = Constant(10, Integer.int32_t()) +constant_20 = Constant(20, Integer.int32_t()) var_a = Variable( "a", Integer.int32_t(), ssa_label=None, is_aliased=False, ssa_name=Variable("eax", Integer.int32_t(), ssa_label=3, is_aliased=False) @@ -518,6 +521,7 @@ def test_get_literals_error(self): (false_value, z3_x[2].copy(), false_value), (z3_x[2].copy(), z3_x[2].copy(), true_value), (z3_x[2].copy(), z3_x[3].copy(), z3_x[2].copy()), + (z3_x[1].copy() | z3_x[2].copy(), z3_x[2].copy(), true_value), ], ) def test_substitute_by_true_basics(self, term, condition, result): @@ -544,7 +548,7 @@ def test_substitute_by_true_basics(self, term, condition, result): & z3_x[6].copy() & z3_x[7].copy() & z3_x[8].copy(), - (z3_x[1].copy() | z3_x[2].copy() | z3_x[3].copy()) & (z3_x[4].copy() | z3_x[5].copy()) & z3_x[6].copy() & z3_x[7].copy(), + true_value, ), ], ) @@ -554,16 +558,50 @@ def test_substitute_by_true(self, condition, result): assert term.simplify() == result.simplify() @pytest.mark.parametrize( - "term, condition_map, result", + "term, conditions, result", [ - (z3_x[1].copy() & z3_x[2].copy(), {z3_x[1].copy(): var_eq_5, z3_x[2].copy(): var_ule_10}, z3_x[1].copy()), - (z3_x[1].copy() & z3_x[2].copy(), {z3_x[1].copy(): var_l_5, z3_x[2].copy(): var_ule_10}, z3_x[1].copy() & z3_x[2].copy()), - (z3_x[1].copy() & z3_x[2].copy(), {z3_x[1].copy(): var_l_20, z3_x[2].copy(): var_ule_10}, z3_x[2].copy()), - (z3_x[1].copy() & ~z3_x[2].copy(), {z3_x[1].copy(): var_l_20, z3_x[2].copy(): var_ugt_10}, ~z3_x[2].copy()), + ( + z3_x[1].copy() & z3_x[2].copy(), + [Condition(OperationType.equal, [var_a, constant_5]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], + z3_x[1].copy(), + ), + ( + z3_x[1].copy() & z3_x[2].copy(), + [Condition(OperationType.less, [var_a, constant_5]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], + z3_x[1].copy() & z3_x[2].copy(), + ), + ( + z3_x[1].copy() & z3_x[2].copy(), + [Condition(OperationType.less, [var_a, constant_20]), Condition(OperationType.less_or_equal_us, [var_a, constant_10])], + z3_x[2].copy(), + ), + ( + z3_x[1].copy() & ~z3_x[2].copy(), + [Condition(OperationType.less, [var_a, constant_20]), Condition(OperationType.greater_us, [var_a, constant_10])], + ~z3_x[2].copy(), + ), ], ) - def test_remove_redundancy(self, term, condition_map, result): - assert term.remove_redundancy(condition_map) == result + def test_remove_redundancy(self, term, conditions, result): + class MockConditionHandler(ConditionHandler): + def add_condition(self, condition: Condition) -> ConditionSymbol: + """Adds a condition to the condition map.""" + symbol = self._get_next_symbol() + z3_condition = PseudoZ3LogicCondition.initialize_from_condition(condition, self._logic_context).simplify() + condition_symbol = ConditionSymbol(condition, symbol, z3_condition) + self._condition_map[symbol] = condition_symbol + return condition_symbol + + def _get_next_symbol(self) -> Z3LogicCondition: + """Get the next unused symbol name.""" + self._symbol_counter += 1 + return Z3LogicCondition.initialize_symbol(f"x{self._symbol_counter}", self._logic_context) + + condition_handler = MockConditionHandler() + condition_handler._logic_context = term.context + for cond in conditions: + condition_handler.add_condition(cond) + assert term.remove_redundancy(condition_handler) == result @pytest.mark.parametrize( "term, bound1, bound2, result",