From 4feb10e4e1708e69e7754df0c2c2c2938298f0ce Mon Sep 17 00:00:00 2001 From: Christoph Stahl Date: Thu, 12 Sep 2024 14:33:05 +0200 Subject: [PATCH] Allow for As/SetTo to return multiple values --- clsp/dsl.py | 145 ++++++++++++++++++++++++++++++------------- clsp/fcl.py | 19 +++--- clsp/types.py | 1 + tests/test_dsl_as.py | 87 +++++++++++++++++++++++--- 4 files changed, 195 insertions(+), 57 deletions(-) diff --git a/clsp/dsl.py b/clsp/dsl.py index 7f63d96..4997c2d 100644 --- a/clsp/dsl.py +++ b/clsp/dsl.py @@ -7,10 +7,11 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterable, Mapping from functools import reduce -from typing import Any +from typing import Any, overload from inspect import signature +import typing from .types import Arrow, Literal, Param, SetTo, Type @@ -138,7 +139,54 @@ def Use(self, name: str, group: str | Type) -> DSL: self._accumulator.append((name, group, [DSL.TRUE])) return self - def As(self, set_to: Callable[..., Any], override: bool = False) -> DSL: + @overload + def As( + self, + set_to: Callable[..., Any], + /, + raw: typing.Literal[False] = ..., + multi_value: typing.Literal[False] = ..., + override: bool = False, + ) -> DSL: ... + + @overload + def As( + self, + set_to: Callable[[Mapping[str, Any]], Iterable[Any]], + /, + raw: typing.Literal[True] = ..., + multi_value: typing.Literal[True] = ..., + override: bool = False, + ) -> DSL: ... + + @overload + def As( + self, + set_to: Callable[[Mapping[str, Any]], Any], + /, + raw: typing.Literal[True] = ..., + multi_value: typing.Literal[False] = ..., + override: bool = False, + ) -> DSL: ... + + @overload + def As( + self, + set_to: Callable[..., Iterable[Any]], + /, + raw: typing.Literal[False] = ..., + multi_value: typing.Literal[True] = ..., + override: bool = False, + ) -> DSL: ... + + def As( + self, + set_to: Callable[..., Any], + /, + raw: bool = False, + multi_value: bool = False, + override: bool = False, + ) -> DSL: """ Set the previous variable directly to the result of a computation. @@ -151,43 +199,67 @@ def As(self, set_to: Callable[..., Any], override: bool = False) -> DSL: :param override: Whether the result of the computation should be discarded, if it is not in the literal set for the group. Default is False (discard). :type override: bool + :param raw: If True, `set_to` will be called with a dictionary, mapping variable names + to values, instead of each variable as a parameter. + :param multi_value: If `multi_value` is True, the function `set_to` needs to return an + `Iterable`. Potential values for the variable range over those computed elements. + :type multi_value: bool :return: The DSL object. To concatenate multiple calls. :rtype: DSL """ + + unwrapper = DSL._unwrap_predicate if not raw else DSL._extracted_values last_element = self._accumulator[-1] self._accumulator[-1] = ( last_element[0], last_element[1], - last_element[2] + [SetTo(DSL._unwrap_predicate(set_to), override)], + last_element[2] + [SetTo(unwrapper(set_to), override, multi_value)], ) return self - def AsRaw(self, set_to: Callable[[Mapping[str, Any]], Any], override: bool = False) -> DSL: + @overload + def AsRaw( + self, + set_to: Callable[[Mapping[str, Any]], Any], + /, + override: bool = False, + multi_value: typing.Literal[False] = False, + ) -> DSL: ... + + @overload + def AsRaw( + self, + set_to: Callable[[Mapping[str, Any]], Iterable[Any]], + /, + override: bool = False, + multi_value: typing.Literal[True] = True, + ) -> DSL: ... + + def AsRaw( + self, + set_to: Callable[[Mapping[str, Any]], Any], + /, + override: bool = False, + multi_value: bool = False, + ) -> DSL: """ - Set the previous variable directly to the result of a computation. - - Similar to `As`, but the `set_to` function gets a dictionary, mapping the variable names - to their values instead. - - Only available to `Literal` variables. And can only access `Literal` variables. - - :param set_to: The function computing the value for the variable. - :type set_to: Callable[[Mapping[str, Any]], Any] - :param override: Whether the result of the computation should be discarded, if it - is not in the literal set for the group. Default is False (discard). - :type override: bool - :return: The DSL object. To concatenate multiple calls. - :rtype: DSL + Deprecated, use As(... , raw = True) instead """ - last_element = self._accumulator[-1] - self._accumulator[-1] = ( - last_element[0], - last_element[1], - last_element[2] + [SetTo(DSL._extracted_values(set_to))], - ) + if multi_value: # This is for typing reasons + self.As(set_to, override=override, raw=True, multi_value=True) + else: + self.As(set_to, override=override, raw=True, multi_value=False) return self - def With(self, predicate: Callable[..., Any]) -> DSL: + @overload + def With( + self, predicate: Callable[[Mapping[str, Any]], Any], /, raw: typing.Literal[True] = True + ) -> DSL: ... + + @overload + def With(self, predicate: Callable[..., Any], /, raw: typing.Literal[False] = False) -> DSL: ... + + def With(self, predicate: Callable[..., Any], /, raw: bool = False) -> DSL: """ Filter on the previously definied variables. @@ -207,32 +279,21 @@ def With(self, predicate: Callable[..., Any]) -> DSL: :return: The DSL object. :rtype: DSL """ + unwrapper = DSL._unwrap_predicate if not raw else DSL._extracted_values + last_element = self._accumulator[-1] self._accumulator[-1] = ( last_element[0], last_element[1], - last_element[2] + [DSL._unwrap_predicate(predicate)], + last_element[2] + [unwrapper(predicate)], ) return self def WithRaw(self, predicate: Callable[[Mapping[str, Any]], Any]) -> DSL: """ - Filter on the previously definied variables. - - Similar to `With`, but the `pred` function gets a dictionary, mapping the variable names - to their values instead. - - :param predicate: A predicate deciding, if the currently chosen values are valid. - :type predicate: Callable[[Mapping[str, Any]], bool] - :return: The DSL object. - :rtype: DSL + Deprecated, use As(... , raw = True) instead """ - last_element = self._accumulator[-1] - self._accumulator[-1] = ( - last_element[0], - last_element[1], - last_element[2] + [DSL._extracted_values(predicate)], - ) + self.With(predicate, raw=True) return self def In(self, ty: Type) -> Param | Type: diff --git a/clsp/fcl.py b/clsp/fcl.py index c4642fc..406a6d8 100644 --- a/clsp/fcl.py +++ b/clsp/fcl.py @@ -112,9 +112,7 @@ def __init__( self.subtypes = subtypes @staticmethod - def _function_types( - p_or_ty: Param | Type, literals: Mapping[str, Sequence[Any]] - ) -> tuple[ + def _function_types(p_or_ty: Param | Type, literals: Mapping[str, Sequence[Any]]) -> tuple[ ParamInfo, None, list[list[MultiArrow]], @@ -178,13 +176,18 @@ def _add_set_to( literals: Mapping[str, Sequence[Any]], ) -> Iterable[dict[str, Literal]]: for s in substitutions: - values = {pred.compute(s) for pred in set_to_preds} - if len(values) != 1: + all_values = { + frozenset(pred.compute(s)) if pred.multi_value else frozenset({pred.compute(s)}) + for pred in set_to_preds + } + values = reduce(lambda acc, v: acc & v, all_values) + if len(values) == 0: continue - value = tuple(values)[0] + # value = tuple(values)[0] - if any(pred.override for pred in set_to_preds) or value in literals[group]: - yield s | {name: Literal(value, group)} + for value in values: + if any(pred.override for pred in set_to_preds) or value in literals[group]: + yield s | {name: Literal(value, group)} @staticmethod def _instantiate( diff --git a/clsp/types.py b/clsp/types.py index 16e505a..eaa7f98 100644 --- a/clsp/types.py +++ b/clsp/types.py @@ -421,6 +421,7 @@ class TermParamSpec: class SetTo: compute: Callable[[dict[str, Any]], Any] override: bool = field(default=False) + multi_value: bool = field(default=False) # @dataclass(frozen=True) diff --git a/tests/test_dsl_as.py b/tests/test_dsl_as.py index fd1a098..e035186 100644 --- a/tests/test_dsl_as.py +++ b/tests/test_dsl_as.py @@ -3,9 +3,8 @@ from clsp.dsl import DSL from clsp.enumeration import enumerate_terms, interpret_term from clsp.fcl import FiniteCombinatoryLogic -from clsp.types import ( - Type, Constructor, LVar, Literal -) +from clsp.types import Type, Constructor, LVar, Literal + class TestDSLAs(unittest.TestCase): logger = logging.getLogger(__name__) @@ -18,17 +17,27 @@ def test_param(self) -> None: def X(x: bool, y: bool, z: bool) -> str: return f"X {x} {y} {z}" - Gamma = {X: DSL() + Gamma = { + X: DSL() .Use("x", "bool") .Use("y", "bool") .As(lambda x: True) .Use("z", "bool") .As(lambda x: x) - .In(Constructor("a", LVar('x')) & Constructor("b", LVar('y')) & Constructor("c", LVar('z')))} + .In( + Constructor("a", LVar("x")) + & Constructor("b", LVar("y")) + & Constructor("c", LVar("z")) + ) + } def xyz(x: bool, y: bool, z: bool) -> Type: - return Constructor("a", Literal(x, "bool")) & Constructor("b", Literal(y, "bool")) & Constructor("c", Literal(z, "bool")) - + return ( + Constructor("a", Literal(x, "bool")) + & Constructor("b", Literal(y, "bool")) + & Constructor("c", Literal(z, "bool")) + ) + fcl = FiniteCombinatoryLogic(Gamma, literals={"bool": [True, False]}) for x in [True, False]: @@ -40,5 +49,69 @@ def xyz(x: bool, y: bool, z: bool) -> Type: self.assertLessEqual(len(result), 1) self.assertTrue(result.issubset({"X True True True", "X False True False"})) + def test_multi_as1(self) -> None: + def X(a: int, b: int) -> str: + return f"X {a} {b}" + + literals = {"int": [0, 1, 2, 3]} + Gamma = { + X: DSL(cache=True) + .Use("a", "int") + .Use("b", "int") + .As(lambda a: {a - 1, a + 1}, multi_value=True) + .In(Constructor("c", LVar("a"))) + } + + fcl = FiniteCombinatoryLogic(Gamma, literals=literals) + target = Constructor("c", Literal(0, "int")) + + result = fcl.inhabit(target) + self.assertEqual( + list(interpret_term(x) for x in enumerate_terms(target, result)), ["X 0 1"] + ) + + def test_multi_as2(self) -> None: + def X(a: int, b: int) -> str: + return f"X {a} {b}" + + literals = {"int": [0, 1, 2, 3]} + Gamma = { + X: DSL(cache=True) + .Use("a", "int") + .Use("b", "int") + .As(lambda a: {a - 1, a + 1}, multi_value=True) + .In(Constructor("c", LVar("a"))) + } + + fcl = FiniteCombinatoryLogic(Gamma, literals=literals) + target = Constructor("c", Literal(1, "int")) + + result = fcl.inhabit(target) + self.assertSetEqual( + set(interpret_term(x) for x in enumerate_terms(target, result)), {"X 1 2", "X 1 0"} + ) + + def test_multi_as3(self) -> None: + def X(a: int, b: int) -> str: + return f"X {a} {b}" + + literals = {"int": [0, 1, 2, 3]} + Gamma = { + X: DSL(cache=True) + .Use("a", "int") + .Use("b", "int") + .As(lambda a: {a - 1, a + 1}, multi_value=True, override=True) + .In(Constructor("c", LVar("a"))) + } + + fcl = FiniteCombinatoryLogic(Gamma, literals=literals) + target = Constructor("c", Literal(0, "int")) + + result = fcl.inhabit(target) + self.assertSetEqual( + set(interpret_term(x) for x in enumerate_terms(target, result)), {"X 0 -1", "X 0 1"} + ) + + if __name__ == "__main__": unittest.main()