diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py index f803d1695..ecddafdca 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement.py @@ -15,7 +15,6 @@ SwitchExtractor, ) from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest -from decompiler.structures.logic.logic_condition import LogicCondition class ConditionAwareRefinement(BaseClassConditionAwareRefinement): @@ -28,10 +27,6 @@ class ConditionAwareRefinement(BaseClassConditionAwareRefinement): MissingCaseFinder.find_in_sequence, ] - def __init__(self, asforest: AbstractSyntaxForest): - self.asforest = asforest - super().__init__(asforest.condition_handler) - @classmethod def refine(cls, asforest: AbstractSyntaxForest): condition_aware_refinement = cls(asforest) diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py index be5a8ecfc..b96377065 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/base_class_car.py @@ -3,16 +3,10 @@ from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, SwitchNode from decompiler.structures.ast.condition_symbol import ConditionHandler +from decompiler.structures.ast.switch_node_handler import ExpressionUsages +from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition -from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType, Variable - - -@dataclass(frozen=True) -class ExpressionUsages: - """Dataclass that maintain for a condition the used SSA-variables.""" - - expression: Expression - ssa_usages: Tuple[Optional[Variable]] +from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType @dataclass @@ -48,12 +42,13 @@ def __hash__(self) -> int: class BaseClassConditionAwareRefinement: """Base Class in charge of logic and condition related things we need during the condition aware refinement.""" - def __init__(self, condition_handler: ConditionHandler): - self.condition_handler: ConditionHandler = condition_handler + def __init__(self, asforest: AbstractSyntaxForest): + self.asforest: AbstractSyntaxForest = asforest + self.condition_handler: ConditionHandler = asforest.condition_handler def _get_constant_equality_check_expressions_and_conditions( self, condition: LogicCondition - ) -> Iterator[Tuple[Expression, LogicCondition]]: + ) -> Iterator[Tuple[ExpressionUsages, LogicCondition]]: """ Check whether the given condition is a simple comparison of an expression with one or more constants + perhaps a conjunction with another condition. @@ -65,11 +60,11 @@ def _get_constant_equality_check_expressions_and_conditions( if condition.is_conjunction: for disjunction in condition.operands: if expression := self._get_const_eq_check_expression_of_disjunction(disjunction): - yield (expression, disjunction) + yield expression, disjunction elif expression := self._get_const_eq_check_expression_of_disjunction(condition): - yield (expression, condition) + yield expression, condition - def _get_const_eq_check_expression_of_disjunction(self, condition: LogicCondition) -> Optional[Expression]: + def _get_const_eq_check_expression_of_disjunction(self, condition: LogicCondition) -> Optional[ExpressionUsages]: """ Check whether the given condition is a composition of comparisons of the same expression with constants. @@ -89,47 +84,21 @@ def _get_const_eq_check_expression_of_disjunction(self, condition: LogicConditio compared_expressions = [self._get_expression_compared_with_constant(literal) for literal in operands] if len(set(compared_expressions)) != 1 or compared_expressions[0] is None: return None - used_variables = tuple(var.ssa_name for var in compared_expressions[0].requirements) - return ( - compared_expressions[0] - if all(used_variables == tuple(var.ssa_name for var in expression.requirements) for expression in compared_expressions[1:]) - else None - ) - - def _get_expression_compared_with_constant(self, reaching_condition: LogicCondition) -> Optional[Expression]: + return compared_expressions[0] + + def _get_expression_compared_with_constant(self, reaching_condition: LogicCondition) -> Optional[ExpressionUsages]: """ Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`. If this is the case, then we return the expression `expr`. """ - condition = self._get_literal_condition(reaching_condition) - if condition is not None and condition.operation == OperationType.equal: - return self._get_expression_compared_with_constant_in(condition) - return None - - def _get_literal_condition(self, condition: LogicCondition) -> Optional[Condition]: - """Check whether the given condition is a literal. If this is the case then it returns the condition that belongs to the literal.""" - if condition.is_symbol: - return self.condition_handler.get_condition_of(condition) - if condition.is_negation and (neg_cond := ~condition).is_symbol: - return self.condition_handler.get_condition_of(neg_cond).negate() - return None - - @staticmethod - def _get_expression_compared_with_constant_in(condition: Condition) -> Optional[Expression]: - """ - Check whether the given condition, of type Condition, compares a constant with an expression + return self.asforest.switch_node_handler.get_potential_switch_expression(reaching_condition) - - If this is the case, the function returns the expression - - Otherwise, it returns None. + def _get_constant_compared_with_expression(self, reaching_condition: LogicCondition) -> Optional[Constant]: """ - non_constants = [operand for operand in condition.operands if not isinstance(operand, Constant)] - return non_constants[0] if len(non_constants) == 1 else None - - @staticmethod - def _get_constant_compared_in_condition(condition: Condition) -> Optional[Constant]: - """Return the constant of a Condition, i.e., for `expr == const` it returns `const`.""" - constant_operands = [operand for operand in condition.operands if isinstance(operand, Constant)] - return constant_operands[0] if len(constant_operands) == 1 else None + Check whether the given reaching condition, which is a literal, i.e., a z3-symbol or its negation is of the form `expr == const`. + If this is the case, then we return the constant `const`. + """ + return self.asforest.switch_node_handler.get_potential_switch_constant(reaching_condition) def _convert_to_z3_condition(self, condition: LogicCondition) -> PseudoLogicCondition: return PseudoLogicCondition.initialize_from_formula(condition, self.condition_handler.get_z3_condition_map()) @@ -145,7 +114,7 @@ def _condition_is_redundant_for_switch_node(self, switch_node: AbstractSyntaxTre """ 1. Check whether the given node is a switch node. 2. If this is the case then we check whether condition is always fulfilled when one of the switch cases is fulfilled - and return the switch node. Otherwise we return None. + and return the switch node. Otherwise, we return None. - If the switch node has a default case, then we can not add any more cases. """ diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py index 8d9a4907e..6a8e7cc48 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/initial_switch_node_constructer.py @@ -5,13 +5,13 @@ from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.base_class_car import ( BaseClassConditionAwareRefinement, CaseNodeCandidate, - ExpressionUsages, ) -from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, SeqNode, SwitchNode +from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, CodeNode, ConditionNode, SeqNode, SwitchNode, TrueNode from decompiler.structures.ast.reachability_graph import CaseDependencyGraph, LinearOrderDependency, SiblingReachability +from decompiler.structures.ast.switch_node_handler import ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition -from decompiler.structures.pseudo import Condition, Constant, Expression +from decompiler.structures.pseudo import Break, Constant, Expression @dataclass @@ -24,26 +24,117 @@ class SwitchNodeCandidate: def construct_switch_cases(self) -> Iterator[Tuple[CaseNode, AbstractSyntaxTreeNode]]: """Construct Switch-case for itself.""" for case_candidate in self.cases: - yield (case_candidate.construct_case_node(self.expression), case_candidate.node) + yield case_candidate.construct_case_node(self.expression), case_candidate.node class InitialSwitchNodeConstructor(BaseClassConditionAwareRefinement): """Class that constructs switch nodes.""" - def __init__(self, asforest: AbstractSyntaxForest): - """ - self.asforest: The asforst where we try to construct switch nodes - """ - self.asforest = asforest - super().__init__(asforest.condition_handler) - @classmethod def construct(cls, asforest: AbstractSyntaxForest): """Constructs initial switch nodes if possible.""" initial_switch_constructor = cls(asforest) + for cond_node in asforest.get_condition_nodes_post_order(asforest.current_root): + initial_switch_constructor._extract_case_nodes_from_nested_condition(cond_node) for seq_node in asforest.get_sequence_nodes_post_order(asforest.current_root): initial_switch_constructor._try_to_construct_initial_switch_node_for(seq_node) + def _extract_case_nodes_from_nested_condition(self, cond_node: ConditionNode) -> None: + """ + Extract CaseNodeCandidates from nested if-conditions. + + - Nested if-conditions can belong to a switch, i.e., Condition node whose condition is a '==' or '!=' comparison of a variable v and + a constant, i.e., v == 2 or v != 2 + - The branch with the '!=' condition is + (i) either a Condition node whose condition is a '==' or '!=' comparison of the same variable v and a different constant, or a + Code node whose reaching condition is of this form, i.e., v == 1 or v != 1 + (ii) a sequence node whose first and last node is a condition node or code node with the properties described in (i) + - We extract the conditions into a sequence, such that _try_to_construct_initial_switch_node_for can reconstruct the switch. + """ + if cond_node.false_branch is None: + return + if first_case_candidate_expression := self._get_possible_case_candidate_for_condition_node(cond_node): + if second_case_candidate := self._second_case_candidate_exists_in_branch( + cond_node.false_branch_child, first_case_candidate_expression + ): + self._extract_conditions_to_obtain_switch(cond_node, second_case_candidate) + + def _get_possible_case_candidate_for_condition_node(self, cond_node: ConditionNode) -> Optional[ExpressionUsages]: + """ + Check whether one branch condition is a possible switch case + + - Make sure, that the possible switch case is always the true-branch + - If we find a candidate, return a CaseNodeCandidate containing the branch and the switch expression, else return None. + """ + possible_expressions: List[Tuple[ExpressionUsages, LogicCondition]] = list( + self._get_constant_equality_check_expressions_and_conditions(cond_node.condition) + ) + if not possible_expressions and cond_node.false_branch_child: + if possible_expressions := list(self._get_constant_equality_check_expressions_and_conditions(~cond_node.condition)): + cond_node.switch_branches() + + if len(possible_expressions) == 1: + return possible_expressions[0][0] + + def _second_case_candidate_exists_in_branch( + self, ast_node: AbstractSyntaxTreeNode, first_case_expression: ExpressionUsages + ) -> Optional[AbstractSyntaxTreeNode]: + """ + Check whether a possible case candidate whose expression is equal to first_case_expression, is contained in the given ast_node. + + - The case candidate can either be: + - the ast-node itself if the reaching condition matches a case-condition + - the true or false branch if the ast_node is a condition node where the condition or negation matches a case-condition + - the first or last child, if the node is a Sequence node and it has one of the above conditions. + """ + candidates = [ast_node] + if isinstance(ast_node, SeqNode): + candidates += [ast_node.children[0], ast_node.children[-1]] + for node in candidates: + second_case_candidate = self._find_second_case_candidate_in(node) + if second_case_candidate is not None and second_case_candidate[0] == first_case_expression: + return second_case_candidate[1] + + def _find_second_case_candidate_in(self, ast_node: AbstractSyntaxTreeNode) -> Optional[Tuple[ExpressionUsages, AbstractSyntaxTreeNode]]: + """Check whether the ast-node fulfills the properties of the second-case node to extract from nested conditions.""" + if isinstance(ast_node, ConditionNode): + return self._get_possible_case_candidate_for_condition_node(ast_node), ast_node.true_branch_child + if case_candidate := self._get_possible_case_candidate_for(ast_node): + return case_candidate.expression, ast_node + + def _extract_conditions_to_obtain_switch(self, cond_node: ConditionNode, second_case_node: AbstractSyntaxTreeNode) -> None: + """ + First of all, we extract both branches of the condition node and handle the reaching conditions. + If a branch contains a sequence node, we propagate the reaching condition to its children. This ensures that + the sequence node can be cleaned and the possible case candidates are all children of the same sequence node. + """ + first_case_node = cond_node.true_branch_child + first_case_node.reaching_condition &= cond_node.condition + + common_condition = LogicCondition.conjunction_of(self.__parent_conditions(second_case_node, cond_node)) + second_case_node.reaching_condition &= common_condition + + default_case_node = None + + if isinstance(second_case_node.parent, TrueNode): + inner_condition_node = second_case_node.parent.parent + assert isinstance(inner_condition_node, ConditionNode), "parent of True Branch must be a condition node." + second_case_node.reaching_condition &= inner_condition_node.condition + if default_case_node := inner_condition_node.false_branch_child: + default_case_node.reaching_condition &= LogicCondition.conjunction_of( + (common_condition, ~inner_condition_node.condition, ~cond_node.condition) + ) + + cond_node.reaching_condition = self.condition_handler.get_true_value() + self.asforest.extract_branch_from_condition_node(cond_node, cond_node.true_branch, update_reachability=False) + new_seq_node = cond_node.parent + if default_case_node: + self.asforest._remove_edge(default_case_node.parent, default_case_node) + self.asforest._add_edge(new_seq_node, default_case_node) + self.asforest._remove_edge(second_case_node.parent, second_case_node) + self.asforest._add_edge(new_seq_node, second_case_node) + self.asforest.clean_up(new_seq_node) + def _try_to_construct_initial_switch_node_for(self, seq_node: SeqNode) -> None: """ Construct a switch node whose cases are children of the current sequence node. @@ -96,14 +187,13 @@ def _get_possible_case_candidate_for(self, ast_node: AbstractSyntaxTreeNode) -> - Otherwise, the function returns None. - Note: Cases can not end with a loop-break statement """ - possible_expressions: List[Tuple[Expression, LogicCondition]] = list() + possible_conditions: List[Tuple[ExpressionUsages, LogicCondition]] = list() if (possible_case_condition := ast_node.get_possible_case_candidate_condition()) is not None: - possible_expressions = list(self._get_constant_equality_check_expressions_and_conditions(possible_case_condition)) + possible_conditions = list(self._get_constant_equality_check_expressions_and_conditions(possible_case_condition)) - if len(possible_expressions) == 1: - expression, condition = possible_expressions[0] - used_variables = tuple(var.ssa_name for var in expression.requirements) - return CaseNodeCandidate(ast_node, ExpressionUsages(expression, used_variables), possible_expressions[0][1]) + if len(possible_conditions) == 1: + expression_usage, condition = possible_conditions[0] + return CaseNodeCandidate(ast_node, expression_usage, condition) return None @@ -168,6 +258,8 @@ def _add_constants_to_cases(self, switch_node: SwitchNode, case_dependency_graph new_start_node = self._add_constants_for_linear_order_starting_at( starting_case, linear_ordering_starting_at, linear_order_dependency_graph, considered_conditions ) + if starting_case in cross_nodes and starting_case != new_start_node: + cross_nodes = [new_start_node if id(n) == id(starting_case) else n for n in cross_nodes] conditions_considered_at[new_start_node] = considered_conditions self._get_linear_order_for(cross_nodes, linear_ordering_starting_at, linear_order_dependency_graph) else: @@ -237,8 +329,7 @@ def _add_constants_to_cases_for( self._update_reaching_condition_of(case_node, considered_conditions) if case_node.reaching_condition.is_literal: - condition: Condition = self._get_literal_condition(case_node.reaching_condition) - case_node.constant = self._get_constant_compared_in_condition(condition) + case_node.constant = self._get_constant_compared_with_expression(case_node.reaching_condition) considered_conditions.add(case_node.reaching_condition) elif case_node.reaching_condition.is_false: case_node.constant = Constant("add_to_previous_case") @@ -320,8 +411,8 @@ def prepend_empty_cases_to_case_with_or_condition(self, case: CaseNode) -> List[ """ condition_for_constant: Dict[Constant, LogicCondition] = dict() for literal in case.reaching_condition.operands: - if condition := self._get_literal_condition(literal): - condition_for_constant[self._get_constant_compared_in_condition(condition)] = literal + if constant := self._get_constant_compared_with_expression(literal): + condition_for_constant[constant] = literal else: raise ValueError( f"The case node should have a reaching-condition that is a disjunction of literals, but it has the clause {literal}." @@ -471,3 +562,9 @@ def _clean_up_reaching_conditions(self, switch_node: SwitchNode) -> None: continue elif not case_node.reaching_condition.is_true: raise ValueError(f"{case_node} should have a literal as reaching condition, but RC = {case_node.reaching_condition}.") + + def __parent_conditions(self, second_case_node: AbstractSyntaxTreeNode, cond_node: ConditionNode): + yield self.condition_handler.get_true_value() + current_node = second_case_node + while (current_node := current_node.parent) != cond_node: + yield current_node.reaching_condition diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py index 07c71eb36..91c3ad815 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/missing_case_finder.py @@ -5,10 +5,10 @@ from decompiler.pipeline.controlflowanalysis.restructuring_commons.condition_aware_refinement_commons.base_class_car import ( BaseClassConditionAwareRefinement, CaseNodeCandidate, - ExpressionUsages, ) from decompiler.structures.ast.ast_nodes import AbstractSyntaxTreeNode, CaseNode, ConditionNode, FalseNode, SeqNode, SwitchNode, TrueNode from decompiler.structures.ast.reachability_graph import SiblingReachabilityGraph +from decompiler.structures.ast.switch_node_handler import ExpressionUsages from decompiler.structures.ast.syntaxforest import AbstractSyntaxForest from decompiler.structures.logic.logic_condition import LogicCondition, PseudoLogicCondition from decompiler.structures.pseudo import Condition, Constant, OperationType @@ -28,8 +28,7 @@ def __init__(self, asforest: AbstractSyntaxForest): self._current_seq_node: The seq_node which we consider to find missing cases. self._switch_node_of_expression: a dictionary that maps to each expression the corresponding switch node. """ - self.asforest = asforest - super().__init__(asforest.condition_handler) + super().__init__(asforest) self._current_seq_node: Optional[SeqNode] = None self._switch_node_of_expression: Dict[ExpressionUsages, SwitchNode] = dict() @@ -92,7 +91,12 @@ def _can_insert_missing_case_node(self, condition_node: ConditionNode) -> Option if not switch_node.reaching_condition.is_true or possible_case_node._has_descendant_code_node_breaking_ancestor_loop(): return None - if not self._get_const_eq_check_expression_of_disjunction(case_condition) == switch_node.expression: + expression_usage = self._get_const_eq_check_expression_of_disjunction(case_condition) + if ( + expression_usage is None + or expression_usage.expression != switch_node.expression + or expression_usage.ssa_usages != tuple(var.ssa_name for var in switch_node.expression.requirements) + ): return None new_case_constants = set(self._get_case_constants_for_condition(case_condition)) @@ -196,9 +200,7 @@ def _find_switch_expression_and_case_condition_for( :param condition: The reaching condition of the AST node of which we want to know whether it can be a case node of a switch node. :return: If we find a switch node, the tuple of switch node and case condition and None otherwise. """ - for expression, cond in self._get_constant_equality_check_expressions_and_conditions(condition): - used_variables = tuple(var.ssa_name for var in expression.requirements) - expression_usage = ExpressionUsages(expression, used_variables) + for expression_usage, cond in self._get_constant_equality_check_expressions_and_conditions(condition): if expression_usage in self._switch_node_of_expression: return expression_usage, cond return None @@ -233,12 +235,11 @@ def _add_new_case_nodes_to_switch_node( def _get_case_constants_for_condition(self, case_condition: LogicCondition) -> Iterable[Constant]: """Return all constants for the given condition.""" assert case_condition.is_disjunction_of_literals, f"The condition {case_condition} can not be the condition of a case node." - if condition := self._get_literal_condition(case_condition): - yield self._get_constant_compared_in_condition(condition) + if constant := self._get_constant_compared_with_expression(case_condition): + yield constant else: for literal in case_condition.operands: - condition = self._get_literal_condition(literal) - yield self._get_constant_compared_in_condition(condition) + yield self._get_constant_compared_with_expression(literal) def _insert_case_node(self, new_case_node: AbstractSyntaxTreeNode, case_constants: Set[Constant], switch_node: SwitchNode) -> None: """Insert new case node into switch node with the given set of constants.""" diff --git a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py index 1100472eb..87d258cdb 100644 --- a/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py +++ b/decompiler/pipeline/controlflowanalysis/restructuring_commons/condition_aware_refinement_commons/switch_extractor.py @@ -12,15 +12,13 @@ class SwitchExtractor(BaseClassConditionAwareRefinement): def __init__(self, asforest: AbstractSyntaxForest): """ - self.asforest: The asforst where we try to construct switch nodes self.current_cond_node: The condition node which we consider to extract switch nodes. """ - self.asforest = asforest + super().__init__(asforest) self._current_cond_node: Optional[ConditionNode] = None - super().__init__(asforest.condition_handler) @classmethod - def extract(cls, asforest): + def extract(cls, asforest: AbstractSyntaxForest): """ Extract switch nodes from condition nodes, i.e., if a switch node is a branch of a condition node whose condition is redundant for the switch node, we extract it from the condition node. diff --git a/decompiler/pipeline/dataflowanalysis/dead_loop_elimination.py b/decompiler/pipeline/dataflowanalysis/dead_loop_elimination.py index f821f336d..9daedf3e6 100644 --- a/decompiler/pipeline/dataflowanalysis/dead_loop_elimination.py +++ b/decompiler/pipeline/dataflowanalysis/dead_loop_elimination.py @@ -38,8 +38,8 @@ def __init__(self): def run(self, task: DecompilerTask) -> None: """Run dead loop elimination on the given task object.""" self._timeout = task.options.getint(f"{self.name}.timeout_satisfiable") - self.engine = task.options.getstring("logic-engine.engine") # choice of z3 or delogic - if self.engine == "delogic": + engine = task.options.getstring("logic-engine.engine") # choice of z3 or delogic + if engine == "delogic": self._logic_converter = DelogicConverter() if not task.graph.root: warning(f"[{self.__class__.__name__}] Can not detect dead blocks because the cfg has no head.") diff --git a/decompiler/pipeline/dataflowanalysis/dead_path_elimination.py b/decompiler/pipeline/dataflowanalysis/dead_path_elimination.py index db62dfcf3..d7771877a 100644 --- a/decompiler/pipeline/dataflowanalysis/dead_path_elimination.py +++ b/decompiler/pipeline/dataflowanalysis/dead_path_elimination.py @@ -26,8 +26,8 @@ def __init__(self): def run(self, task: DecompilerTask) -> None: """Run dead path elimination on the given task object.""" self._timeout = task.options.getint(f"{self.name}.timeout_satisfiable") - self.engine = task.options.getstring("logic-engine.engine") # choice of z3 or delogic - if self.engine == "delogic": + engine = task.options.getstring("logic-engine.engine") # choice of z3 or delogic + if engine == "delogic": self._logic_converter = DelogicConverter() if task.graph.root is None: warning(f"[{self.__class__.__name__}] Can not detect dead blocks because the cfg has no head.") diff --git a/decompiler/structures/ast/ast_nodes.py b/decompiler/structures/ast/ast_nodes.py index 07db1b62e..67905670f 100644 --- a/decompiler/structures/ast/ast_nodes.py +++ b/decompiler/structures/ast/ast_nodes.py @@ -556,7 +556,7 @@ def get_possible_case_candidate_condition(self) -> Optional[LogicCondition]: self.clean() if self.false_branch is None and not self._has_descendant_code_node_breaking_ancestor_loop(): return self.reaching_condition & self.condition - return None + return super().get_possible_case_candidate_condition() def simplify_reaching_condition(self, condition_handler: ConditionHandler): """ diff --git a/decompiler/structures/ast/condition_symbol.py b/decompiler/structures/ast/condition_symbol.py index a7ef2b79d..d6c5f3850 100644 --- a/decompiler/structures/ast/condition_symbol.py +++ b/decompiler/structures/ast/condition_symbol.py @@ -29,7 +29,7 @@ class ConditionHandler: """Class that handles all the conditions of a transition graph and syntax-forest.""" def __init__(self, condition_map: Optional[Dict[LogicCondition, ConditionSymbol]] = None): - """Initialize a new condition handler with an dictionary that maps the symbol to its according ConditionSymbol.""" + """Initialize a new condition handler with a dictionary that maps the symbol to its according ConditionSymbol.""" self._condition_map: Dict[LogicCondition, ConditionSymbol] = dict() if condition_map is None else condition_map self._symbol_counter = 0 self._logic_context = next(iter(self._condition_map)).context if self._condition_map else LogicCondition.generate_new_context() @@ -79,6 +79,13 @@ def get_z3_condition_map(self) -> Dict[LogicCondition, PseudoLogicCondition]: """Return the z3-condition map that maps symbols to z3-conditions.""" return dict((symbol, condition_symbol.z3_condition) for symbol, condition_symbol in self._condition_map.items()) + def update_z3_condition_of(self, symbol: LogicCondition, condition: Condition): + """Change the z3-condition of the given symbol according to the given condition.""" + assert symbol.is_symbol, "Input must be a symbol!" + z3_condition = PseudoLogicCondition.initialize_from_condition(condition, self._logic_context) + pseudo_condition = self.get_condition_of(symbol) + self._condition_map[symbol] = ConditionSymbol(pseudo_condition, symbol, z3_condition) + def add_condition(self, condition: Condition) -> LogicCondition: """Adds a new condition to the condition map and returns the corresponding condition_symbol""" z3_condition = PseudoLogicCondition.initialize_from_condition(condition, self._logic_context) diff --git a/decompiler/structures/ast/switch_node_handler.py b/decompiler/structures/ast/switch_node_handler.py new file mode 100644 index 000000000..02efbe47a --- /dev/null +++ b/decompiler/structures/ast/switch_node_handler.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterator, List, Optional, Set, Tuple + +from decompiler.structures.ast.condition_symbol import ConditionHandler +from decompiler.structures.logic.logic_condition import LogicCondition +from decompiler.structures.logic.z3_implementations import Z3Implementation +from decompiler.structures.pseudo import Condition, Constant, Expression, OperationType, Variable, Z3Converter +from z3 import BoolRef + + +@dataclass(frozen=True) +class ExpressionUsages: + """Dataclass maintaining for a condition the used SSA-variables.""" + + expression: Expression + ssa_usages: Tuple[Optional[Variable]] + + +@dataclass +class ZeroCaseCondition: + """Possible switch expression together with its zero-case condition.""" + + expression: Expression + ssa_usages: Set[Optional[Variable]] + z3_condition: BoolRef + + +@dataclass +class CaseNodeProperties: + """ + Class for mapping possible expression and constant of a symbol for a switch-case. + + -> symbol: symbol that belongs to the expression and constant + -> constant: the compared constant + -> The condition that the new case node should get. + """ + + symbol: LogicCondition + expression: ExpressionUsages + constant: Constant + negation: bool + + def __eq__(self, other) -> bool: + """ + We want to be able to compare CaseNodeCandidates with AST-nodes, more precisely, + we want that an CaseNodeCandidate 'case_node' is equal to the AST node 'case_node.node'. + """ + if isinstance(other, CaseNodeProperties): + return self.symbol == other.symbol + return False + + +class SwitchNodeHandler: + """Handler for switch node reconstruction knowing possible constants and expressions for switch-nodes for each symbol.""" + + def __init__(self, condition_handler: ConditionHandler): + """ + Initialize the switch-node constructor. + + self._zero_case_of_switch_expression: maps to each possible switch-expression the possible zero-case condition. + self._case_node_property_of_symbol: maps to each symbol the possible expression and constant for a switch it can belong to. + """ + self._condition_handler: ConditionHandler = condition_handler + self._z3_converter: Z3Converter = Z3Converter() + self._zero_case_of_switch_expression: Dict[ExpressionUsages, ZeroCaseCondition] = dict() + self._get_zero_cases_for_possible_switch_expressions() + self._case_node_properties_of_symbol: Dict[LogicCondition, Optional[CaseNodeProperties]] = dict() + self._initialize_case_node_properties_for_symbols() + + def is_potential_switch_case(self, condition: LogicCondition) -> bool: + """Check whether the given condition is a potential switch case.""" + return self._get_case_node_property_of(condition) is not None + + def get_potential_switch_expression(self, condition: LogicCondition) -> Optional[ExpressionUsages]: + """Check whether the given condition is a potential switch case, and if return the corresponding expression.""" + if (case_node_property := self._get_case_node_property_of(condition)) is not None: + return case_node_property.expression + + def get_potential_switch_constant(self, condition: LogicCondition) -> Optional[Constant]: + """Check whether the given condition is a potential switch case, and if return the corresponding constant.""" + if (case_node_property := self._get_case_node_property_of(condition)) is not None: + return case_node_property.constant + + def _get_case_node_property_of(self, condition: LogicCondition) -> Optional[CaseNodeProperties]: + """Return the case-property of a given literal.""" + negation = False + if condition.is_negation: + condition = condition.operands[0] + negation = True + if condition.is_symbol: + if condition not in self._case_node_properties_of_symbol: + self._case_node_properties_of_symbol[condition] = self.__get_case_node_property_of_symbol(condition) + if (case_property := self._case_node_properties_of_symbol[condition]) is not None and case_property.negation == negation: + return case_property + return None + + def _get_zero_cases_for_possible_switch_expressions(self) -> None: + """Get all possible switch expressions, i.e., all expression compared with a constant, together with the potential zero case.""" + for symbol in self._condition_handler.get_all_symbols(): + self.__add_switch_expression_and_zero_case_for_symbol(symbol) + + def __add_switch_expression_and_zero_case_for_symbol(self, symbol: LogicCondition) -> None: + """Add possible switch condition for symbol if comparison of expression with constant.""" + assert symbol.is_symbol, f"Each symbol should be a single Literal, but we have {symbol}" + non_constants = [op for op in self._condition_handler.get_condition_of(symbol).operands if not isinstance(op, Constant)] + if len(non_constants) != 1: + return None + expression_usage = ExpressionUsages(non_constants[0], tuple(var.ssa_name for var in non_constants[0].requirements)) + if expression_usage not in self._zero_case_of_switch_expression: + self.__add_switch_expression(expression_usage) + + def __add_switch_expression(self, expression_usage: ExpressionUsages) -> None: + """Construct the zero case condition and add it to the dictionary.""" + ssa_expression = self.__get_ssa_expression(expression_usage) + try: + z3_condition = self._z3_converter.convert(Condition(OperationType.equal, [ssa_expression, Constant(0, ssa_expression.type)])) + except ValueError: + return + self._zero_case_of_switch_expression[expression_usage] = ZeroCaseCondition( + expression_usage.expression, set(expression_usage.ssa_usages), z3_condition + ) + + @staticmethod + def __get_ssa_expression(expression_usage: ExpressionUsages) -> Expression: + """Construct SSA-expression of the given expression.""" + if isinstance(expression_usage.expression, Variable): + return expression_usage.expression.ssa_name if expression_usage.expression.ssa_name else expression_usage.expression + ssa_expression = expression_usage.expression.copy() + for variable in [var for var in ssa_expression.requirements if var.ssa_name]: + ssa_expression.substitute(variable, variable.ssa_name) + return ssa_expression + + def _initialize_case_node_properties_for_symbols(self) -> None: + """Initialize for each symbol the possible switch case properties""" + for symbol in self._condition_handler.get_all_symbols(): + self._case_node_properties_of_symbol[symbol] = self.__get_case_node_property_of_symbol(symbol) + + def __get_case_node_property_of_symbol(self, symbol: LogicCondition) -> Optional[CaseNodeProperties]: + """Return CaseNodeProperty of the given symbol, if it exists.""" + condition = self._condition_handler.get_condition_of(symbol) + if condition.operation not in {OperationType.equal, OperationType.not_equal}: + return None + constants: List[Constant] = [operand for operand in condition.operands if isinstance(operand, Constant)] + expressions: List[Expression] = [operand for operand in condition.operands if not isinstance(operand, Constant)] + + if len(constants) == 1 or len(expressions) == 1: + expression_usage = ExpressionUsages(expressions[0], tuple(var.ssa_name for var in expressions[0].requirements)) + const: Constant = constants[0] + elif len(constants) == 0 and (zero_case_condition := self.__check_for_zero_case_condition(condition)): + expression_usage, const = zero_case_condition + self._condition_handler.update_z3_condition_of(symbol, Condition(condition.operation, [expression_usage.expression, const])) + else: + return None + if expression_usage not in self._zero_case_of_switch_expression: + self.__add_switch_expression(expression_usage) + return CaseNodeProperties(symbol, expression_usage, const, condition.operation == OperationType.not_equal) + + def __check_for_zero_case_condition(self, condition: Condition) -> Optional[Tuple[ExpressionUsages, Constant]]: + """ + Check whether the condition belongs to a zero-case of a switch expression. + + If this is the case, we return the switch expression and the zero-constant + """ + tuple_ssa_usages = tuple(var.ssa_name for var in condition.requirements) + ssa_usages = set(tuple_ssa_usages) + ssa_condition = None + for expression_usage, zero_case_condition in self._zero_case_of_switch_expression.items(): + if zero_case_condition.ssa_usages != ssa_usages: + continue + if ssa_condition is None: + ssa_condition = self.__get_z3_condition(ExpressionUsages(condition, tuple_ssa_usages)) + zero_case_z3_condition = zero_case_condition.z3_condition + if self.__is_equivalent(ssa_condition, zero_case_z3_condition): + return expression_usage, Constant(0, expression_usage.expression.type) + + def __get_z3_condition(self, expression_usage: ExpressionUsages) -> BoolRef: + """Get z3-condition of the expression usage in SSA-form""" + ssa_condition = self.__get_ssa_expression(expression_usage) + assert isinstance(ssa_condition, Condition), f"{ssa_condition} must be of type Condition!" + ssa_condition = ssa_condition.negate() if ssa_condition.operation == OperationType.not_equal else ssa_condition + z3_condition = self._z3_converter.convert(ssa_condition) + return z3_condition + + @staticmethod + def __is_equivalent(cond1: BoolRef, cond2: BoolRef): + """Check whether the given conditions are equivalent.""" + z3_implementation = Z3Implementation(True) + if z3_implementation.is_equal(cond1, cond2): + return True + return z3_implementation.does_imply(cond1, cond2) and z3_implementation.does_imply(cond2, cond1) diff --git a/decompiler/structures/ast/syntaxforest.py b/decompiler/structures/ast/syntaxforest.py index 11ffe8462..d689f11e4 100644 --- a/decompiler/structures/ast/syntaxforest.py +++ b/decompiler/structures/ast/syntaxforest.py @@ -17,6 +17,7 @@ VirtualRootNode, ) from decompiler.structures.ast.condition_symbol import ConditionHandler +from decompiler.structures.ast.switch_node_handler import SwitchNodeHandler from decompiler.structures.ast.syntaxgraph import AbstractSyntaxInterface from decompiler.structures.graphs.restructuring_graph.transition_cfg import TransitionBlock from decompiler.structures.logic.logic_condition import LogicCondition @@ -37,6 +38,7 @@ def __init__(self, condition_handler: ConditionHandler): self.condition_handler: ConditionHandler = condition_handler self._current_root: VirtualRootNode = self.factory.create_virtual_node() self._add_node(self._current_root) + self.switch_node_handler: SwitchNodeHandler = SwitchNodeHandler(condition_handler) @property def current_root(self) -> Optional[AbstractSyntaxTreeNode]: @@ -473,7 +475,7 @@ def __reverse_iterate_case_conditions(self, switch_node: SwitchNode) -> Iterable yield child, condition, case_node.break_case def __add_condition_before_nodes( - self, condition: LogicCondition, true_branch: AbstractSyntaxTreeNode, false_branch: Optional[AbstractSyntaxTreeNode] = None + self, condition: LogicCondition, true_branch: AbstractSyntaxTreeNode, false_branch: Optional[AbstractSyntaxTreeNode] = None ) -> ConditionNode: """ Add the given condition before the true_branch and its negation before the false branch. @@ -508,4 +510,3 @@ def __handle_fall_through_case(self, case_node, case_condition, condition_node) new_condition_node = self.__add_condition_before_nodes(case_condition, case_node) self._add_edge(true_branch, new_condition_node) true_branch._sorted_children = (new_condition_node,) + true_branch._sorted_children - diff --git a/decompiler/structures/logic/logic_interface.py b/decompiler/structures/logic/logic_interface.py index 458e5e37b..2a0ba97c6 100644 --- a/decompiler/structures/logic/logic_interface.py +++ b/decompiler/structures/logic/logic_interface.py @@ -235,10 +235,10 @@ def rich_string_representation(self, condition_map: Dict[LogicInterface, Conditi """Replaces each symbol by the condition of the condition map.""" def get_complexity(self, condition_map: Dict[LogicInterface, Condition]) -> int: - """ Returns the complexity of a logic condition""" + """Returns the complexity of a logic condition""" complexity_sum = 0 for literal in self.get_literals(): - if literal.is_negation: + if literal.is_negation: complexity_sum += condition_map[~literal].complexity else: complexity_sum += condition_map[literal].complexity diff --git a/decompiler/structures/pseudo/delogic_logic.py b/decompiler/structures/pseudo/delogic_logic.py index e78f40d1e..e097206c1 100644 --- a/decompiler/structures/pseudo/delogic_logic.py +++ b/decompiler/structures/pseudo/delogic_logic.py @@ -1,9 +1,8 @@ """Implements translating pseudo instructions into delogic statements.""" from __future__ import annotations -from typing import Generic, Union +from typing import Union -from simplifier.world.nodes import Operation as WorldOperation from simplifier.world.nodes import Variable as WorldVariable from simplifier.world.nodes import WorldObject from simplifier.world.world import World @@ -70,7 +69,7 @@ def _full_simplification(self, condition: WorldObject, timeout: int = 2000) -> W def _convert_variable(self, variable: Variable, default_size: int = 32) -> WorldObject: """Represent the given variable as a WorldObject.""" - return self._world.from_string(f"{variable.name}@{variable.type.size or default_size}") + return WorldVariable(self._world, str(variable), variable.type.size or default_size) def _convert_constant(self, constant: Constant, default_size: int = 32) -> WorldObject: """Represent the given constant as a WorldObject.""" diff --git a/decompiler/structures/pseudo/logic.py b/decompiler/structures/pseudo/logic.py index 1bfbeccde..c5c22feaa 100644 --- a/decompiler/structures/pseudo/logic.py +++ b/decompiler/structures/pseudo/logic.py @@ -7,7 +7,7 @@ from .expressions import Constant, Expression, Variable from .instructions import Branch, GenericBranch -from .operations import Condition, Operation +from .operations import Condition, Operation, OperationType, UnaryOperation T = TypeVar("T") @@ -29,6 +29,8 @@ def convert(self, expression: Union[Expression, Branch], **kwargs: T) -> T: return self._convert_branch(expression) if isinstance(expression, Condition): return self._convert_condition(expression) + if isinstance(expression, UnaryOperation) and expression.operation == OperationType.dereference: + return self._convert_variable(Variable(str(expression), expression.type)) if isinstance(expression, Operation): return self._convert_operation(expression) raise ValueError(f"Could not convert {expression} into a logic statement.") diff --git a/decompiler/structures/pseudo/z3_logic.py b/decompiler/structures/pseudo/z3_logic.py index ddce26c59..5a052fbe5 100644 --- a/decompiler/structures/pseudo/z3_logic.py +++ b/decompiler/structures/pseudo/z3_logic.py @@ -3,7 +3,7 @@ import logging import operator -from typing import Iterator, List, Union +from typing import Iterator, List, Type, TypeVar, Union from z3 import ( UGE, @@ -19,6 +19,7 @@ ExprRef, Extract, If, + LShR, Not, Or, RotateLeft, @@ -37,8 +38,10 @@ from .logic import BaseConverter from .operations import Condition, Operation, OperationType +OP = TypeVar("OP", operator.sub, operator.add, operator.mul, operator.truediv, operator.lshift, operator.rshift, operator.mod) -def _convert_invalid_boolref_op(a: BoolRef, b: BoolRef, op: Operation) -> BitVecRef: + +def _convert_invalid_boolref_op(a: BoolRef, b: BoolRef, op: Type[OP]) -> BitVecRef: return op( If(a, BitVecVal(1, 1, ctx=a.ctx), BitVecVal(0, 1, ctx=a.ctx), ctx=a.ctx), If(b, BitVecVal(1, 1, ctx=b.ctx), BitVecVal(0, 1, ctx=b.ctx), ctx=b.ctx), @@ -60,15 +63,15 @@ def negate(self, expr: BoolRef) -> BoolRef: """Negate a given expression.""" return Not(expr) - def _convert_variable(self, variable: Variable) -> BitVecRef: + def _convert_variable(self, variable: Variable, **kwargs) -> BitVecRef: """Represent the given Variable as a BitVector in z3.""" - return BitVec(variable.name, variable.type.size if variable.type.size else 32, ctx=self._context) + return BitVec(str(variable), variable.type.size if variable.type.size else 32, ctx=self._context) - def _convert_constant(self, constant: Constant) -> BitVecRef: + def _convert_constant(self, constant: Constant, **kwargs) -> BitVecRef: """Represent the given variable as a BitVector (no types).""" return BitVecVal(constant.value, constant.type.size if constant.type.size else 32, ctx=self._context) - def _convert_branch(self, branch: Branch) -> BoolRef: + def _convert_branch(self, branch: Branch, **kwargs) -> BoolRef: """ Convert the given branch into z3 logic. @@ -79,7 +82,7 @@ def _convert_branch(self, branch: Branch) -> BoolRef: return self._convert_condition(branch.condition) return self._convert_condition(Condition(OperationType.not_equal, [branch.condition, Constant(0, branch.condition.type)])) - def _convert_condition(self, condition: Condition) -> BoolRef: + def _convert_condition(self, condition: Condition, **kwargs) -> BoolRef: """ Convert the given condition into z3 logic. @@ -88,7 +91,7 @@ def _convert_condition(self, condition: Condition) -> BoolRef: _operation = self._get_operation(condition) return self._ensure_bool_sort(_operation) - def _convert_operation(self, operation: Operation) -> BitVecRef: + def _convert_operation(self, operation: Operation, **kwargs) -> BitVecRef: """ Convert the given operation into a z3 logic. @@ -100,12 +103,12 @@ def _convert_operation(self, operation: Operation) -> BitVecRef: def _get_operation(self, operation: Operation) -> Union[BoolRef, BitVecRef]: """Convert the given operation into a z3 expression utilizing the handler functions.""" operands = self._ensure_same_sort([self.convert(operand) for operand in operation.operands]) - if isinstance(operands[0], BoolRef) and operation.operation in self.OPERATIONS_BOOLREF: + if not operands: + raise ValueError("FOUND") + if operands and isinstance(operands[0], BoolRef) and operation.operation in self.OPERATIONS_BOOLREF: converter = self.OPERATIONS_BOOLREF.get(operation.operation, None) - elif isinstance(operands[0], BoolRef) and operation.operation in self.OPERATIONS_INVALID_BOOLREF_OP: - converter = lambda a, b: _convert_invalid_boolref_op( - a, b, self.OPERATIONS_INVALID_BOOLREF_OP.get(operation.operation, None) - ) + elif operands and isinstance(operands[0], BoolRef) and operation.operation in self.OPERATIONS_INVALID_BOOLREF_OP: + converter = lambda a, b: _convert_invalid_boolref_op(a, b, self.OPERATIONS_INVALID_BOOLREF_OP.get(operation.operation, None)) else: converter = self.OPERATIONS.get(operation.operation, None) if not converter: @@ -168,7 +171,6 @@ def check(self, *condition: BoolRef, timeout: int = 2000) -> str: return BaseConverter.UNSAT return BaseConverter.SAT - LOGIC_OPERATIONS = { OperationType.bitwise_or, OperationType.bitwise_and, @@ -183,8 +185,11 @@ def check(self, *condition: BoolRef, timeout: int = 2000) -> str: OperationType.bitwise_xor: lambda a, b: a ^ b, OperationType.bitwise_or: lambda a, b: a | b, OperationType.bitwise_and: lambda a, b: a & b, + OperationType.logical_or: lambda a, b: Or(a != 0, b != 0), + OperationType.logical_and: lambda a, b: And(a != 0, b != 0), OperationType.left_shift: lambda a, b: a << b, OperationType.right_shift: lambda a, b: a >> b, + OperationType.right_shift_us: LShR, OperationType.left_rotate: RotateLeft, OperationType.right_rotate: RotateRight, OperationType.equal: lambda a, b: a == b, @@ -205,7 +210,6 @@ def check(self, *condition: BoolRef, timeout: int = 2000) -> str: OperationType.less_or_equal_us: ULE, } - OPERATIONS_BOOLREF = { OperationType.bitwise_and: And, OperationType.bitwise_xor: Xor, @@ -214,7 +218,6 @@ def check(self, *condition: BoolRef, timeout: int = 2000) -> str: OperationType.negate: Not, } - OPERATIONS_INVALID_BOOLREF_OP = { OperationType.minus: operator.sub, OperationType.plus: operator.add, diff --git a/pyproject.toml b/pyproject.toml index 95f7ea9bc..f72691a2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 140 -target-version = ['py38'] +target-version = ['py310'] [tool.isort] profile = "black" diff --git a/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py b/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py index 4ce6f64e0..7a518d0b0 100644 --- a/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py +++ b/tests/pipeline/controlflowanalysis/restructuring_commons/test_condition_aware_refinement.py @@ -1850,10 +1850,178 @@ def test_switch_in_switch_complicated(task): pass -@pytest.mark.skip("Not implemented yet") def test_switch_only_if_else(task): - """test_condition test6 -> later""" - pass + """ + test_condition test6 + +------------------+ +----------------------------------------------------------------+ + | | | 0. | + | | | __x86.get_pc_thunk.bx() | + | 2. | | printf("Enter week number (1-7): ") | + | printf("Monday") | | var_1 = &(var_0) | + | | | __isoc99_scanf("%d", var_1) | + | | <-- | if(var_0 != 0x1) | + +------------------+ +----------------------------------------------------------------+ + | | + | | + | v + | +----------------------------------------------------------------+ +--------------------+ + | | 1. | | 4. | + | | if(var_0 != 0x2) | --> | printf("Tuesday") | + | +----------------------------------------------------------------+ +--------------------+ + | | | + | | | + | v | + +---------------------+ | +----------------------------------------------------------------+ | + | 7. | | | 3. | | + | printf("Wednesday") | <-----+--------------------- | if(var_0 != 0x3) | | + +---------------------+ | +----------------------------------------------------------------+ | + | | | | + | | | | + | | v | + | | +----------------------------------------------------------------+ | +--------------------+ + | | | 6. | | | 9. | + | | | if(var_0 != 0x4) | ------+----------------------> | printf("Thursday") | + | | +----------------------------------------------------------------+ | +--------------------+ + | | | | | + | | | | | + | | v | | + | | +----------------------------------------------------------------+ | | + | | | 8. | | | + | | +- | if(var_0 != 0x5) | | | + | | | +----------------------------------------------------------------+ | | + | | | | | | + | | | | | | + | | | v | | + | | | +----------------------------------------------------------------+ | | + | | | | 10. | | | + | | | | if(var_0 != 0x6) | -+ | | + | | | +----------------------------------------------------------------+ | | | + | | | | | | | + | | | | | | | + | | | v | | | + | | | +----------------------------------------------------------------+ | | | + | | | | 12. | | | | + | | | | if(var_0 != 0x7) | -+----+--------------------------+--------------------------+ + | | | +----------------------------------------------------------------+ | | | | + | | | | | | | | + | | | | | | | | + | | | v | | | | + | | | +----------------------------------------------------------------+ | | | | + | | | | 14. | | | | | + | | | | printf("Invalid Input! Please enter week number between 1-7.") | +----+--------------------------+---------------------+ | + | | | +----------------------------------------------------------------+ | | | | + | | | | | | | | + | +----+-------------------+ | | | | | + | | | v v v | | + | | | +----------------------------------------------------------------------------------------------------------------------+ | | + | | +--------------------> | | | | + | | | 5. | | | + | | | return 0x0 | | | + +----------------------+-------------------------> | | | | + | +----------------------------------------------------------------------------------------------------------------------+ | | + | ^ ^ ^ | | + | | | | | | + | | | | | | + | +----------------------------------------------------------------+ +--------------------+ | | | + | | 11. | | 13. | | | | + +-------------------------> | printf("Friday") | | printf("Saturday") | <-----+---------------------+ | + +----------------------------------------------------------------+ +--------------------+ | | + +----------------------------------------------------------------+ | | + | 15. | | | + | printf("Sunday") | ---------------------------------+ | + +----------------------------------------------------------------+ | + ^ | + +----------------------------------------------------------------------------------------------------------------------------+ + """ + var_1 = Variable( + "var_1", Pointer(Integer(32, True), 32), None, False, Variable("var_28", Pointer(Integer(32, True), 32), 1, False, None) + ) + var_0 = Variable("var_0", Integer(32, True), None, True, Variable("var_10", Integer(32, True), 0, True, None)) + task.graph.add_nodes_from( + vertices := [ + BasicBlock( + 0, + [ + Assignment( + ListOperation([]), Call(imp_function_symbol("__x86.get_pc_thunk.bx"), [], Pointer(CustomType("void", 0), 32), 1) + ), + Assignment(ListOperation([]), print_call("Enter week number(1-7): ", 1)), + Assignment(var_1, UnaryOperation(OperationType.address, [var_0], Pointer(Integer(32, True), 32), None, False)), + Assignment(ListOperation([]), scanf_call(var_1, 134524965, 2)), + Branch(Condition(OperationType.not_equal, [var_1, Constant(1, Integer(32, True))], CustomType("bool", 1))), + ], + ), + BasicBlock(1, [Branch(Condition(OperationType.not_equal, [var_1, Constant(2, Integer(32, True))], CustomType("bool", 1)))]), + BasicBlock(2, [Branch(Condition(OperationType.not_equal, [var_1, Constant(3, Integer(32, True))], CustomType("bool", 1)))]), + BasicBlock(3, [Branch(Condition(OperationType.not_equal, [var_1, Constant(4, Integer(32, True))], CustomType("bool", 1)))]), + BasicBlock(4, [Branch(Condition(OperationType.not_equal, [var_1, Constant(5, Integer(32, True))], CustomType("bool", 1)))]), + BasicBlock(5, [Branch(Condition(OperationType.not_equal, [var_1, Constant(6, Integer(32, True))], CustomType("bool", 1)))]), + BasicBlock(6, [Branch(Condition(OperationType.not_equal, [var_1, Constant(7, Integer(32, True))], CustomType("bool", 1)))]), + BasicBlock(7, [Assignment(ListOperation([]), print_call("Invalid input! Please enter week number between 1-7.", 14))]), + BasicBlock(8, [Assignment(ListOperation([]), print_call("Monday", 3))]), + BasicBlock(9, [Assignment(ListOperation([]), print_call("Tuesday", 5))]), + BasicBlock(10, [Assignment(ListOperation([]), print_call("Wednesday", 6))]), + BasicBlock(11, [Assignment(ListOperation([]), print_call("Thursday", 8))]), + BasicBlock(12, [Assignment(ListOperation([]), print_call("Friday", 9))]), + BasicBlock(13, [Assignment(ListOperation([]), print_call("Saturday", 11))]), + BasicBlock(14, [Assignment(ListOperation([]), print_call("Sunday", 13))]), + BasicBlock(15, [Return(ListOperation([Constant(0, Integer(32, True))]))]), + ] + ) + task.graph.add_edges_from( + [ + FalseCase(vertices[0], vertices[8]), + TrueCase(vertices[0], vertices[1]), + FalseCase(vertices[1], vertices[9]), + TrueCase(vertices[1], vertices[2]), + FalseCase(vertices[2], vertices[10]), + TrueCase(vertices[2], vertices[3]), + FalseCase(vertices[3], vertices[11]), + TrueCase(vertices[3], vertices[4]), + FalseCase(vertices[4], vertices[12]), + TrueCase(vertices[4], vertices[5]), + FalseCase(vertices[5], vertices[13]), + TrueCase(vertices[5], vertices[6]), + FalseCase(vertices[6], vertices[14]), + TrueCase(vertices[6], vertices[7]), + UnconditionalEdge(vertices[7], vertices[15]), + UnconditionalEdge(vertices[8], vertices[15]), + UnconditionalEdge(vertices[9], vertices[15]), + UnconditionalEdge(vertices[10], vertices[15]), + UnconditionalEdge(vertices[11], vertices[15]), + UnconditionalEdge(vertices[12], vertices[15]), + UnconditionalEdge(vertices[13], vertices[15]), + UnconditionalEdge(vertices[14], vertices[15]), + ] + ) + + PatternIndependentRestructuring().run(task) + + assert isinstance(seq_node := task._ast.root, SeqNode) and len(seq_node.children) == 3 + assert isinstance(seq_node.children[0], CodeNode) and seq_node.children[0].instructions == vertices[0].instructions[:-1] + assert isinstance(switch := seq_node.children[1], SwitchNode) + assert isinstance(seq_node.children[2], CodeNode) and seq_node.children[2].instructions == vertices[-1].instructions + + # switch node: + assert switch.expression == var_1 and len(switch.children) == 8 + assert isinstance(case1 := switch.cases[0], CaseNode) and case1.constant == Constant(1, Integer(32, True)) and case1.break_case is True + assert isinstance(case2 := switch.cases[1], CaseNode) and case2.constant == Constant(2, Integer(32, True)) and case2.break_case is True + assert isinstance(case3 := switch.cases[2], CaseNode) and case3.constant == Constant(3, Integer(32, True)) and case3.break_case is True + assert isinstance(case4 := switch.cases[3], CaseNode) and case4.constant == Constant(4, Integer(32, True)) and case4.break_case is True + assert isinstance(case5 := switch.cases[4], CaseNode) and case5.constant == Constant(5, Integer(32, True)) and case5.break_case is True + assert isinstance(case6 := switch.cases[5], CaseNode) and case6.constant == Constant(6, Integer(32, True)) and case6.break_case is True + assert isinstance(case7 := switch.cases[6], CaseNode) and case7.constant == Constant(7, Integer(32, True)) and case7.break_case is True + assert isinstance(default := switch.default, CaseNode) and default.constant == "default" and default.break_case is False + + # children of cases + assert isinstance(case1.child, CodeNode) and case1.child.instructions == vertices[8].instructions + assert isinstance(case2.child, CodeNode) and case2.child.instructions == vertices[9].instructions + assert isinstance(case3.child, CodeNode) and case3.child.instructions == vertices[10].instructions + assert isinstance(case4.child, CodeNode) and case4.child.instructions == vertices[11].instructions + assert isinstance(case5.child, CodeNode) and case5.child.instructions == vertices[12].instructions + assert isinstance(case6.child, CodeNode) and case6.child.instructions == vertices[13].instructions + assert isinstance(case7.child, CodeNode) and case7.child.instructions == vertices[14].instructions + assert isinstance(default.child, CodeNode) and default.child.instructions == vertices[7].instructions def test_two_entries_to_one_case(task): @@ -3918,6 +4086,79 @@ def test_only_one_occurrence_of_each_case(task): assert isinstance(seq_node.children[2], ConditionNode) +def test_case_0_different_condition(task): + """ + Consideration of conditions as "a == b" as case 0 conditions for switch-statements with expressions a-b and b-a + + simplified version of test-samples/coreutils/shred main + """ + argc = Variable("argc", Integer(32, True), None, False, Variable("argc", Integer(32, True), 0, False, None)) + var_0 = Variable("var_0", Integer(32, True), None, True, Variable("var_10", Integer(32, True), 0, True, None)) + var_4 = Variable("arg1", Integer(32, True), None, True, Variable("eax", Integer(32, True), 1, True, None)) + task.graph.add_nodes_from( + vertices := [ + BasicBlock( + 0, + [ + Assignment(ListOperation([]), print_call("Enter any number: ", 1)), + Assignment( + ListOperation([]), + scanf_call( + UnaryOperation(OperationType.address, [var_4], Pointer(Integer(32, True), 32), None, False), 134524965, 2 + ), + ), + Branch(Condition(OperationType.equal, [argc, var_4], CustomType("bool", 1))), + ], + ), + BasicBlock(1, [Assignment(var_4, BinaryOperation(OperationType.plus, [var_4, Constant(1, Integer(32, True))]))]), + BasicBlock( + 2, + [ + Assignment(var_0, BinaryOperation(OperationType.left_shift, [var_4, Constant(3, Integer(32, True))])), + Branch( + Condition( + OperationType.not_equal, + [BinaryOperation(OperationType.minus, [argc, var_4]), Constant(1, Integer(32, True))], + CustomType("bool", 1), + ) + ), + ], + ), + BasicBlock( + 3, + [Return(ListOperation([var_4]))], + ), + BasicBlock( + 4, + [ + Assignment(ListOperation([]), print_call("var_0", 3)), + Assignment( + ListOperation([]), Call(FunctionSymbol("usage", 10832), [Constant(1, Integer(32, True))], Integer(32, True), 11) + ), + ], + ), + BasicBlock(5, [Assignment(var_4, BinaryOperation(OperationType.minus, [var_4, var_0]))]), + ] + ) + task.graph.add_edges_from( + [ + TrueCase(vertices[0], vertices[1]), + FalseCase(vertices[0], vertices[2]), + UnconditionalEdge(vertices[1], vertices[3]), + TrueCase(vertices[2], vertices[4]), + FalseCase(vertices[2], vertices[5]), + UnconditionalEdge(vertices[5], vertices[3]), + ] + ) + + PatternIndependentRestructuring().run(task) + + switch_nodes = list(task.syntax_tree.get_switch_nodes_post_order()) + assert len(switch_nodes) == 1 + assert len(switch_nodes[0].cases) == 2 and switch_nodes[0].default is not None + assert vertices[0].instructions[-1].condition in task.syntax_tree.condition_map.values() + + @pytest.mark.parametrize( "graph", [_basic_switch_cfg, _switch_empty_fallthrough, _switch_no_empty_fallthrough, _switch_in_switch, _switch_test_19] ) diff --git a/tests/samples/src/systemtests/test_condition.c b/tests/samples/src/systemtests/test_condition.c index 7e72f4ee9..9326bd4e4 100644 --- a/tests/samples/src/systemtests/test_condition.c +++ b/tests/samples/src/systemtests/test_condition.c @@ -153,6 +153,247 @@ int test6() return 0; } +int test6b() +{ + int week; + printf("Enter week number (1-7): "); + scanf("%d", &week); + + + if(week != 1) + { + if(week != 2) + { + if(week != 3) + { + if(week != 4) + { + if(week != 5) + { + if(week != 6) + { + if(week != 7) + { + printf("Invalid Input! Please enter week number between 1-7."); + } + else + { + printf("Sunday"); + } + } + else + { + printf("Saturday"); + } + } + else + { + printf("Friday"); + } + } + else + { + printf("Thursday"); + } + } + else + { + printf("Wednesday"); + } + } + else + { + printf("Tuesday"); + } + } + else + { + printf("Monday"); + } + + return 0; +} + +int test6c() +{ + int week; + printf("Enter week number (1-7): "); + scanf("%d", &week); + + + if(week == 1) + { + printf("Monday"); + } + else if(week != 2) + { + if(week != 3) + { + if(week == 4) + { + printf("Thursday"); + } + else if(week != 5) + { + if(week == 6) + { + printf("Saturday"); + } + else if(week != 7) + { + printf("Invalid Input! Please enter week number between 1-7."); + } + else + { + printf("Sunday"); + } + } + else + { + printf("Friday"); + } + } + else + { + printf("Wednesday"); + } + } + else + { + printf("Tuesday"); + } + + return 0; +} + +int test6d() +{ + int week; + printf("Enter week number (1-7): "); + scanf("%d", &week); + + + if(week == 1) + { + printf("Monday"); + } + else if(week != 2) + { + if(week != 3) + { + if(week == 4) + { + printf("Thursday"); + } + else if(week != 5) + { + if(week == 6) + { + printf("Saturday"); + } + else if(week != 7) + { + printf("Invalid Input! Please enter week number between 1-7."); + } + else + { + printf("Sunday"); + } + } + else + { + printf("Friday"); + } + } + else + { + printf("Wednesday"); + } + } + + return 0; +} + +int test6e() +{ + int week; + printf("Enter week number (1-7): "); + scanf("%d", &week); + + + if(week == 1) + { + printf("Monday"); + } + else if(week != 2) + { + if(week == 3) + { + printf("Wednesday"); + } + else if(week == 4) + { + printf("Thursday"); + } + else if(week == 5) + { + printf("Friday"); + } + else if(week == 6) + { + printf("Saturday"); + } + else if(week == 7) + { + printf("Sunday"); + } + else + { + printf("Invalid Input! Please enter week number between 1-7."); + } + } + + return 0; +} + +int test6f() +{ + int week; + printf("Enter week number (1-7): "); + scanf("%d", &week); + + + if(week != 2) + { + if(week == 3) + { + printf("Wednesday"); + } + else if(week == 4) + { + printf("Thursday"); + } + else if(week == 5) + { + printf("Friday"); + } + else if(week == 6) + { + printf("Saturday"); + } + else if(week == 7) + { + printf("Sunday"); + } + else + { + printf("Invalid Input! Please enter week number between 1-7."); + } + } + + return 0; +} + int test7() { int side1, side2, side3; diff --git a/tests/samples/src/systemtests/test_switch.c b/tests/samples/src/systemtests/test_switch.c index b07f5032b..97ebfa4e9 100644 --- a/tests/samples/src/systemtests/test_switch.c +++ b/tests/samples/src/systemtests/test_switch.c @@ -39,6 +39,85 @@ int test0(int a, int b) return b; } +int test0_b(int a, int b) +{ + for(int i = 0; i < 10; i++){ + switch(a){ + case 1: + printf("You chose the 1\n"); + break; + case 2: + printf("You chose the prime number 2\n"); + case 4: + printf("You chose an even number\n"); + break; + case 5: + printf("both numbers are 5\n"); + goto L; + case 3: + printf("Another prime\n"); + break; + case 7: + printf("The 7 is a prime"); + goto L; + default: + printf("Number not between 1 and 5\n"); + if(a > 5){ + a -= 5; + } + else{ + a += 5; + } + + } + b += i; + printf("b= %d\n", b); + } + L: printf("final b= %d\n", b); + return b; +} + +int test0_c(int a, int b) +{ + for(int i = 0; i < 10; i++){ + switch(a){ + case 1: + printf("You chose the 1\n"); + break; + case 2: + printf("You chose the prime number 2\n"); + case 4: + printf("You chose an even number\n"); + break; + case 5: + printf("both numbers are 5\n"); + goto L; + case 3: + printf("Another prime\n"); + break; + case 7: + if(b > 7){ + b = b - 7; + } + printf("The 7 is a prime"); + goto L; + default: + printf("Number not between 1 and 5\n"); + if(a > 5){ + a -= 5; + } + else{ + a += 5; + } + + } + b += i; + printf("b= %d\n", b); + } + L: printf("final b= %d\n", b); + return b; +} + int test1() { int week; diff --git a/tests/structures/ast/test_syntaxforest.py b/tests/structures/ast/test_syntaxforest.py index 6c7362ddf..df125302c 100644 --- a/tests/structures/ast/test_syntaxforest.py +++ b/tests/structures/ast/test_syntaxforest.py @@ -406,8 +406,10 @@ def test_generate_from_code_nodes(): TransitionEdge(vertices[2], vertices[1], LogicCondition.initialize_true(context), EdgeProperty.back), ] ) + symbol_x1 = LogicCondition.initialize_symbol("x1", context) + pseudo_cond = Condition(OperationType.not_equal, [var("i"), Constant(3)]) t_cfg.condition_handler = ConditionHandler( - {LogicCondition.initialize_symbol("x1", context): Condition(OperationType.not_equal, [var("i"), Constant(3)])} + {symbol_x1: ConditionSymbol(pseudo_cond, symbol_x1, PseudoLogicCondition.initialize_from_condition(pseudo_cond, context))} ) asforest = AbstractSyntaxForest.generate_from_code_nodes([node.ast for node in vertices], t_cfg.condition_handler) @@ -421,8 +423,12 @@ def test_generate_from_code_nodes(): def test_construct_initial_ast_for_region(): context = LogicCondition.generate_new_context() + symbol_x1 = LogicCondition.initialize_symbol("x1", context) + pseudo_cond = Condition(OperationType.not_equal, [var("i"), Constant(3)]) asforest = AbstractSyntaxForest( - ConditionHandler({LogicCondition.initialize_symbol("x1", context): Condition(OperationType.not_equal, [var("i"), Constant(3)])}) + ConditionHandler( + {symbol_x1: ConditionSymbol(pseudo_cond, symbol_x1, PseudoLogicCondition.initialize_from_condition(pseudo_cond, context))} + ) ) code_node_0 = asforest.add_code_node([Assignment(var("i"), Constant(0)), Assignment(var("x"), Constant(42))]) code_node_1 = asforest.add_code_node() diff --git a/tests/structures/logic/test_custom_logic.py b/tests/structures/logic/test_custom_logic.py index 09131c665..f86f295ed 100644 --- a/tests/structures/logic/test_custom_logic.py +++ b/tests/structures/logic/test_custom_logic.py @@ -504,9 +504,9 @@ def test_is_symbol(self, world, term, result): False, ), ( - (world := World()).bitwise_and(b_x(1, world), b_x(2, world), b_x(3, world)), - (world := World()).bitwise_and(b_x(1, world), b_x(3, world), b_x(2, world)), - True, + (world := World()).bitwise_and(b_x(1, world), b_x(2, world), b_x(3, world)), + (world := World()).bitwise_and(b_x(1, world), b_x(3, world), b_x(2, world)), + True, ), ( (world := World()).bitwise_and(b_x(1, world), b_x(2, world), b_x(2, world)), diff --git a/tests/structures/pseudo/test_delogic_converter.py b/tests/structures/pseudo/test_delogic_converter.py index 8ce0301fe..f0a3bb191 100644 --- a/tests/structures/pseudo/test_delogic_converter.py +++ b/tests/structures/pseudo/test_delogic_converter.py @@ -5,6 +5,7 @@ from decompiler.structures.pseudo.logic import BaseConverter from decompiler.structures.pseudo.operations import BinaryOperation, Condition, OperationType, UnaryOperation from decompiler.structures.pseudo.typing import Float, Integer, Pointer +from simplifier.world.nodes import Variable as WorldVariable var_a = Variable("a", Integer.int32_t()) var_b = Variable("b", Integer.int32_t()) @@ -21,8 +22,6 @@ def converter(): def test_unsupported(converter): with pytest.raises(ValueError): converter.convert(Return([Variable("x")])) - with pytest.raises(ValueError): - converter.convert(UnaryOperation(OperationType.dereference, [Variable("x")])) with pytest.raises(ValueError): converter.convert(UnaryOperation(OperationType.address, [Variable("x")])) @@ -38,15 +37,16 @@ def test_constant(converter): @pytest.mark.parametrize( "to_parse, output", [ - (Variable("x", Integer.int32_t(), ssa_label=0), "x@32"), - (Variable("x", Integer.int32_t(), ssa_label=1), "x@32"), - (Variable("x", Float.float(), ssa_label=1), "x@32"), + (Variable("x", Integer.int32_t(), ssa_label=0), "x#0"), + (Variable("x", Integer.int32_t(), ssa_label=1), "x#1"), + (Variable("x", Integer.int32_t(), ssa_label=None), "x"), + (Variable("x", Float.float(), ssa_label=1), "x#1"), ], ) def test_variable(converter, to_parse, output): """When generating a variable, we can not transpose ssa labels or type information.""" w = converter._world - assert converter.convert(to_parse) == w.from_string(output) + assert converter.convert(to_parse) == WorldVariable(w, output, 32) @pytest.mark.parametrize( @@ -62,6 +62,18 @@ def test_unary_operation(converter, to_parse, output): assert converter.convert(to_parse) == w.from_string(output) +@pytest.mark.parametrize( + "to_parse, output", + [ + (UnaryOperation(OperationType.dereference, [Variable("x", Integer.int32_t(), ssa_label=1)]), "*(x#1)"), + (UnaryOperation(OperationType.dereference, [Variable("x", Integer.int32_t(), ssa_label=None)]), "*(x)"), + ], +) +def test_unary_dereference(converter, to_parse, output): + w = converter._world + assert converter.convert(to_parse) == WorldVariable(w, output, 32) + + @pytest.mark.parametrize( "to_parse, output", [ diff --git a/tests/structures/pseudo/test_z3_converter.py b/tests/structures/pseudo/test_z3_converter.py index d30f3c11b..9427cb47d 100644 --- a/tests/structures/pseudo/test_z3_converter.py +++ b/tests/structures/pseudo/test_z3_converter.py @@ -21,8 +21,6 @@ def converter(): def test_unsupported(converter): with pytest.raises(ValueError): converter.convert(Return([Variable("x")])) - with pytest.raises(ValueError): - converter.convert(UnaryOperation(OperationType.dereference, [Variable("x")])) with pytest.raises(ValueError): converter.convert(UnaryOperation(OperationType.address, [Variable("x")])) @@ -36,9 +34,10 @@ def test_constant(converter): def test_variable(converter): """When generating a variable, we can not transpose ssa labels or type information.""" - assert converter.convert(Variable("x", Integer.int32_t(), ssa_label=0)) == BitVec("x", 32, ctx=converter.context) - assert converter.convert(Variable("x", Integer.int32_t(), ssa_label=1)) == BitVec("x", 32, ctx=converter.context) - assert converter.convert(Variable("x", Float.float(), ssa_label=1)) == BitVec("x", 32, ctx=converter.context) + assert converter.convert(Variable("x", Integer.int32_t(), ssa_label=0)) == BitVec("x#0", 32, ctx=converter.context) + assert converter.convert(Variable("x", Integer.int32_t(), ssa_label=1)) == BitVec("x#1", 32, ctx=converter.context) + assert converter.convert(Variable("x", Integer.int32_t(), ssa_label=None)) == BitVec("x", 32, ctx=converter.context) + assert converter.convert(Variable("x", Float.float(), ssa_label=1)) == BitVec("x#1", 32, ctx=converter.context) def test_unary_operation(converter): @@ -46,6 +45,12 @@ def test_unary_operation(converter): assert converter.convert(UnaryOperation(OperationType.logical_not, [Variable("x", Integer(1))])) == ~BitVec( "x", 1, ctx=converter.context ) + assert converter.convert(UnaryOperation(OperationType.dereference, [Variable("x", Integer.int32_t(), ssa_label=0)])) == BitVec( + "*(x#0)", 32, ctx=converter.context + ) + assert converter.convert(UnaryOperation(OperationType.dereference, [Variable("x", Integer.int32_t(), ssa_label=None)])) == BitVec( + "*(x)", 32, ctx=converter.context + ) def test_binary_operation(converter):