Skip to content

Commit

Permalink
fixed enumeration for non-linear rules
Browse files Browse the repository at this point in the history
  • Loading branch information
mrhaandi committed Aug 26, 2024
1 parent 6ec3d83 commit c5dd4dc
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 104 deletions.
162 changes: 65 additions & 97 deletions clsp/enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .grammar import (
GVar,
Literal,
ParameterizedTreeGrammar,
Predicate,
RHSRule,
Expand All @@ -26,25 +27,19 @@
S = TypeVar("S") # non-terminals
T = TypeVar("T", covariant=True, bound=Hashable)


# Tree: TypeAlias = tuple[T, tuple["Tree[T]", ...]]
@dataclass(slots=True)
class Tree(Generic[T]):
root: T
children: tuple["Tree[T]", ...] = field(default=())
variable_names: list[str] = field(default_factory=list)
hashing_function: Optional[Callable[["Tree[T]"], int]] = field(
default=None, compare=False, hash=False, repr=False
)

size: int = field(init=False, compare=True, repr=False)
_hash: int = field(init=False, compare=False, repr=False)

def __post_init__(self) -> None:
self.size = 1 + sum(child.size for child in self.children)
if self.hashing_function is not None:
self._hash = self.hashing_function(self)
else:
self._hash = hash((self.root, self.children))
self._hash = hash((self.root, self.children))

@property
def parameters(self) -> dict[str, "Tree[T]"]:
Expand Down Expand Up @@ -74,14 +69,6 @@ def __getitem__(self, i: typing.Literal[0] | typing.Literal[1]) -> T | tuple["Tr
def __hash__(self) -> int:
return self._hash

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Tree):
return False

if self.hashing_function is None:
return self.root == other.root and self.children == other.children
return hash(self) == hash(other)

def __lt__(self, other: "Tree[T]") -> bool:
return self.size < other.size

Expand Down Expand Up @@ -121,94 +108,78 @@ def tree_size(tree: Tree[T]) -> int:
# return


def enumerate_term_vectors(
    non_terminals: "tuple[S, ...] | list[S]",
    existing_terms: "dict[S, set[Tree[T]]]",
    nt_term: "Optional[tuple[S, Tree[T]]]" = None,
) -> "Iterable[tuple[Tree[T], ...]]":
    """Enumerate term vectors for a sequence of non-terminals.

    A term vector is a tuple holding one term per entry of ``non_terminals``,
    in the same order.

    Args:
        non_terminals: The non-terminals to fill, one vector slot each.
        existing_terms: Maps a non-terminal to the terms available for it.
            It is indexed with ``existing_terms[n]``, so a plain ``dict``
            raises ``KeyError`` (on iteration) for a missing non-terminal.
        nt_term: Optional pair ``(nt, term)``. If given, every yielded vector
            uses ``term`` at least once, at a slot whose non-terminal is ``nt``.

    Yields:
        Tuples of terms.  Without ``nt_term`` this is the full Cartesian
        product of the available terms.  With ``nt_term``, for each slot that
        holds ``nt`` the product is taken with that slot pinned to ``term``
        and all other slots drawn from ``existing_terms``; nothing is yielded
        if ``nt`` does not occur in ``non_terminals``.

    Note:
        If ``nt`` occurs in several slots and ``term`` is itself contained in
        ``existing_terms[nt]``, the same vector can be yielded more than once.
        Collect the results in a set if uniqueness matters.
    """
    if nt_term is None:
        yield from itertools.product(*(existing_terms[n] for n in non_terminals))
        return

    nt, term = nt_term
    for pinned, candidate in enumerate(non_terminals):
        if candidate == nt:
            # Pin `term` at this occurrence of `nt`; every other slot ranges
            # over the terms already known for its non-terminal.
            choices = [
                [term] if position == pinned else existing_terms[other]
                for position, other in enumerate(non_terminals)
            ]
            yield from itertools.product(*choices)

def generate_new_terms(
rule: RHSRule[S, T],
existing_terms: dict[S, set[Tree[T]]],
max_count: Optional[int] = None,
nt_old_term: Optional[tuple[S, Tree[T]]] = None,
hashing_function: Optional[Callable[["Tree[T]"], int]] = None,
) -> set[Tree[T]]:
# Generate new terms for rule `rule` from existing terms, up to `max_count`.
# The term `old_term` should be a subterm of all resulting terms, at a position that corresponds to `nt`.

output_set: set[Tree[T]] = set()
number_of_parameters = len(rule.parameters)

# Get all possible positions of a term that is built by nt in the output term
all_arguments = [
rule.binder[p.name] if isinstance(p, GVar) else p for p in rule.parameters
] + rule.args
if max_count == 0:
return output_set

names, param_nts = zip(*rule.binder.items()) if len(rule.binder) > 0 else ((), ())
literals = [Tree(p.value) if isinstance(p, Literal) else p.name for p in rule.parameters]
interleave = lambda substitution: tuple(substitution[t] if isinstance(t, str) else t for t in literals)
new_term = lambda params_args: Tree(rule.terminal, params_args, variable_names=rule.variable_names,)

if nt_old_term is None:
positions_of_nt = [-1]
nt = None
old_term = None
all_args = list(enumerate_term_vectors(rule.args, existing_terms, None))
all_params = [
interleave(substitution)
for param_terms in enumerate_term_vectors(param_nts, existing_terms, None)
for substitution in (dict(zip(names, param_terms)),)
if all(predicate.eval(substitution) for predicate in rule.predicates)
]
for params in all_params:
for args in all_args:
output_set.add(new_term(params + args))
if max_count is not None and len(output_set) >= max_count:
return output_set
else:
nt, old_term = nt_old_term
positions_of_nt = [i for i, e in enumerate(all_arguments) if e == nt]

cached_complete_parameter_parts: Optional[list[tuple[Tree[T], ...]]] = None
cached_complete_argument_parts: Optional[list[tuple[Tree[T], ...]]] = None

for pos in positions_of_nt:
pos_in_parameters = pos < number_of_parameters
all_parameter_parts: list[tuple[Tree[T], ...]] = []

if 0 < pos < number_of_parameters or cached_complete_parameter_parts is None:
for parameter_part in itertools.product(
*(
(
(
existing_terms[rule.binder[param.name]]
if i != pos
else [old_term] if old_term is not None else []
)
if isinstance(param, GVar)
else [Tree(param.value, hashing_function=hashing_function)]
)
for i, param in enumerate(rule.parameters)
)
):
if rule.check(parameter_part):
all_parameter_parts.append(parameter_part)
if not pos_in_parameters:
cached_complete_parameter_parts = all_parameter_parts
else:
all_parameter_parts = cached_complete_parameter_parts

if pos >= number_of_parameters or cached_complete_argument_parts is None:
all_argument_parts = list(
itertools.product(
*(
(
existing_terms[arg]
if i + number_of_parameters != pos
else [old_term] if old_term is not None else []
)
for i, arg in enumerate(rule.args)
)
)
)
if pos_in_parameters:
cached_complete_argument_parts = all_argument_parts
else:
all_argument_parts = cached_complete_argument_parts

new_terms = (
Tree(
rule.terminal,
param_part + arg_part,
variable_names=rule.variable_names,
hashing_function=hashing_function,
)
for param_part, arg_part in itertools.product(all_parameter_parts, all_argument_parts)
)

# Add new terms to the output set and check max_count
output_set.update(new_terms)
if max_count is not None and len(output_set) >= max_count:
return set(itertools.islice(output_set, max_count))

if nt in param_nts:
cached_all_args = None
for param_terms in enumerate_term_vectors(param_nts, existing_terms, (nt, old_term)):
substitution = dict(zip(names, param_terms))
if all(predicate.eval(substitution) for predicate in rule.predicates):
cached_all_args = list(enumerate_term_vectors(rule.args, existing_terms, None)) if cached_all_args is None else cached_all_args
for args in cached_all_args:
output_set.add(new_term(interleave(substitution) + args))
if max_count is not None and len(output_set) >= max_count:
return output_set

if nt in rule.args:
cached_all_params = None
for args in enumerate_term_vectors(rule.args, existing_terms, (nt, old_term)):
cached_all_params = [
interleave(substitution)
for param_terms in enumerate_term_vectors(param_nts, existing_terms, None)
for substitution in (dict(zip(names, param_terms)),)
if all(predicate.eval(substitution) for predicate in rule.predicates)
] if cached_all_params is None else cached_all_params
for params in cached_all_params:
output_set.add(new_term(params + args))
if max_count is not None and len(output_set) >= max_count:
return output_set

return output_set


Expand All @@ -229,7 +200,6 @@ def enumerate_terms_fast(
grammar: ParameterizedTreeGrammar[S, T],
max_count: Optional[int] = None,
max_bucket_size: Optional[int] = None,
hashing_function: Optional[Callable[[Tree[T]], int]] = None,
) -> Iterable[Tree[T]]:
"""
Enumerate terms as an iterator efficiently - all terms are enumerated, no guaranteed term order.
Expand All @@ -248,9 +218,7 @@ def enumerate_terms_fast(
for expr in exprs:
for m in expr.non_terminals():
inverse_grammar[m].append((n, expr))
for new_term in generate_new_terms(
expr, existing_terms, hashing_function=hashing_function
):
for new_term in generate_new_terms(expr, existing_terms):
queues[n].put(new_term)
if n == start and new_term not in all_results:
if max_count is not None and len(all_results) >= max_count:
Expand Down Expand Up @@ -278,7 +246,7 @@ def enumerate_terms_fast(
non_terminals.add(m)
if m == start:
for new_term in generate_new_terms(
expr, existing_terms, max_count, (n, term), hashing_function
expr, existing_terms, max_count, (n, term)
):
if new_term not in all_results:
if max_count is not None and len(all_results) >= max_count:
Expand All @@ -288,7 +256,7 @@ def enumerate_terms_fast(
queues[start].put(new_term)
else:
for new_term in generate_new_terms(
expr, existing_terms, max_bucket_size, (n, term), hashing_function
expr, existing_terms, max_bucket_size, (n, term)
):
queues[m].put(new_term)
current_bucket_size += 1
Expand Down Expand Up @@ -791,7 +759,7 @@ def __call__(self, a: str, b: str) -> str:

for i, r in enumerate(
itertools.islice(
enumerate_terms_fast("X", d, max_count=10_000, hashing_function=lambda t: t.size),
enumerate_terms_fast("X", d, max_count=10_000),
10,
)
):
Expand Down
19 changes: 12 additions & 7 deletions tests/test_parameter_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,28 @@ def setUp(self) -> None:
)
self.grammar.add_rule(
"Y",
RHSRule({}, [Predicate(lambda _: True, "⊤")], "y1", [Literal(3, "int")], ["n"], []),
RHSRule({}, [Predicate(lambda _: True, "⊤")], "y1", [Literal(1, "int")], ["n"], []),
)
self.grammar.add_rule("Y", RHSRule({}, [Predicate(lambda _: False, "⊥")], "y2", [], [], []))
self.grammar.add_rule(
"Y",
RHSRule({}, [], "y2", [], [], []),
)
self.grammar.add_rule("Y", RHSRule({}, [Predicate(lambda _: False, "⊥")], "y3", [], [], []))

def test_grammar(self) -> None:
self.logger.info(self.grammar.show())
self.assertEqual(
"X ~> ∀(y:Y).x(<y>)(<y>)\nY ~> ⊤ ⇛ y1([3, int]) | ⊥ ⇛ y2",
"X ~> ∀(y:Y).x(<y>)(<y>)\nY ~> ⊤ ⇛ y1([1, int]) | y2 | ⊥ ⇛ y3",
self.grammar.show(),
)

def test_enum(self) -> None:
enumeration = enumerate_terms("X", self.grammar)

for t in enumeration:
self.logger.info(t)
self.assertEqual(Tree("x", (Tree("y1", (Tree(3, ()),)), Tree("y1", (Tree(3, ()),)))), t)
expected_results = [
Tree("x", (Tree("y1", (Tree(1, ()),), ["n"]), Tree("y1", (Tree(1, ()),), ["n"])), ["y", "y"],),
Tree("x", (Tree("y2", ()), Tree("y2", ())), ["y", "y"]),
]
self.assertCountEqual(enumeration, expected_results)


if __name__ == "__main__":
Expand Down

0 comments on commit c5dd4dc

Please sign in to comment.