Skip to content

Commit

Permalink
Allow for As/SetTo to return multiple values
Browse files Browse the repository at this point in the history
  • Loading branch information
christofsteel committed Sep 12, 2024
1 parent f0b4fd7 commit 4feb10e
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 57 deletions.
145 changes: 103 additions & 42 deletions clsp/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from __future__ import annotations

from collections.abc import Callable, Mapping
from collections.abc import Callable, Iterable, Mapping
from functools import reduce
from typing import Any
from typing import Any, overload
from inspect import signature
import typing

from .types import Arrow, Literal, Param, SetTo, Type

Expand Down Expand Up @@ -138,7 +139,54 @@ def Use(self, name: str, group: str | Type) -> DSL:
self._accumulator.append((name, group, [DSL.TRUE]))
return self

def As(self, set_to: Callable[..., Any], override: bool = False) -> DSL:
@overload
def As(
self,
set_to: Callable[..., Any],
/,
raw: typing.Literal[False] = ...,
multi_value: typing.Literal[False] = ...,
override: bool = False,
) -> DSL: ...

@overload
def As(
self,
set_to: Callable[[Mapping[str, Any]], Iterable[Any]],
/,
raw: typing.Literal[True] = ...,
multi_value: typing.Literal[True] = ...,
override: bool = False,
) -> DSL: ...

@overload
def As(
self,
set_to: Callable[[Mapping[str, Any]], Any],
/,
raw: typing.Literal[True] = ...,
multi_value: typing.Literal[False] = ...,
override: bool = False,
) -> DSL: ...

@overload
def As(
self,
set_to: Callable[..., Iterable[Any]],
/,
raw: typing.Literal[False] = ...,
multi_value: typing.Literal[True] = ...,
override: bool = False,
) -> DSL: ...

def As(
self,
set_to: Callable[..., Any],
/,
raw: bool = False,
multi_value: bool = False,
override: bool = False,
) -> DSL:
"""
Set the previous variable directly to the result of a computation.
Expand All @@ -151,43 +199,67 @@ def As(self, set_to: Callable[..., Any], override: bool = False) -> DSL:
:param override: Whether the result of the computation should be discarded, if it
is not in the literal set for the group. Default is False (discard).
:type override: bool
:param raw: If True, `set_to` will be called with a dictionary, mapping variable names
to values, instead of each variable as a parameter.
:param multi_value: If `multi_value` is True, the function `set_to` needs to return an
`Iterable`. Potential values for the variable range over those computed elements.
:type multi_value: bool
:return: The DSL object. To concatenate multiple calls.
:rtype: DSL
"""

unwrapper = DSL._unwrap_predicate if not raw else DSL._extracted_values
last_element = self._accumulator[-1]
self._accumulator[-1] = (
last_element[0],
last_element[1],
last_element[2] + [SetTo(DSL._unwrap_predicate(set_to), override)],
last_element[2] + [SetTo(unwrapper(set_to), override, multi_value)],
)
return self

def AsRaw(self, set_to: Callable[[Mapping[str, Any]], Any], override: bool = False) -> DSL:
@overload
def AsRaw(
self,
set_to: Callable[[Mapping[str, Any]], Any],
/,
override: bool = False,
multi_value: typing.Literal[False] = False,
) -> DSL: ...

@overload
def AsRaw(
self,
set_to: Callable[[Mapping[str, Any]], Iterable[Any]],
/,
override: bool = False,
multi_value: typing.Literal[True] = True,
) -> DSL: ...

def AsRaw(
self,
set_to: Callable[[Mapping[str, Any]], Any],
/,
override: bool = False,
multi_value: bool = False,
) -> DSL:
"""
Set the previous variable directly to the result of a computation.
Similar to `As`, but the `set_to` function gets a dictionary, mapping the variable names
to their values instead.
Only available to `Literal` variables. And can only access `Literal` variables.
:param set_to: The function computing the value for the variable.
:type set_to: Callable[[Mapping[str, Any]], Any]
:param override: Whether the result of the computation should be discarded, if it
is not in the literal set for the group. Default is False (discard).
:type override: bool
:return: The DSL object. To concatenate multiple calls.
:rtype: DSL
Deprecated, use As(... , raw = True) instead
"""
last_element = self._accumulator[-1]
self._accumulator[-1] = (
last_element[0],
last_element[1],
last_element[2] + [SetTo(DSL._extracted_values(set_to))],
)
if multi_value: # This is for typing reasons
self.As(set_to, override=override, raw=True, multi_value=True)
else:
self.As(set_to, override=override, raw=True, multi_value=False)
return self

def With(self, predicate: Callable[..., Any]) -> DSL:
@overload
def With(
self, predicate: Callable[[Mapping[str, Any]], Any], /, raw: typing.Literal[True] = True
) -> DSL: ...

@overload
def With(self, predicate: Callable[..., Any], /, raw: typing.Literal[False] = False) -> DSL: ...

def With(self, predicate: Callable[..., Any], /, raw: bool = False) -> DSL:
"""
Filter on the previously definied variables.
Expand All @@ -207,32 +279,21 @@ def With(self, predicate: Callable[..., Any]) -> DSL:
:return: The DSL object.
:rtype: DSL
"""
unwrapper = DSL._unwrap_predicate if not raw else DSL._extracted_values

last_element = self._accumulator[-1]
self._accumulator[-1] = (
last_element[0],
last_element[1],
last_element[2] + [DSL._unwrap_predicate(predicate)],
last_element[2] + [unwrapper(predicate)],
)
return self

def WithRaw(self, predicate: Callable[[Mapping[str, Any]], Any]) -> DSL:
"""
Filter on the previously definied variables.
Similar to `With`, but the `pred` function gets a dictionary, mapping the variable names
to their values instead.
:param predicate: A predicate deciding, if the currently chosen values are valid.
:type predicate: Callable[[Mapping[str, Any]], bool]
:return: The DSL object.
:rtype: DSL
Deprecated, use As(... , raw = True) instead
"""
last_element = self._accumulator[-1]
self._accumulator[-1] = (
last_element[0],
last_element[1],
last_element[2] + [DSL._extracted_values(predicate)],
)
self.With(predicate, raw=True)
return self

def In(self, ty: Type) -> Param | Type:
Expand Down
19 changes: 11 additions & 8 deletions clsp/fcl.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ def __init__(
self.subtypes = subtypes

@staticmethod
def _function_types(
p_or_ty: Param | Type, literals: Mapping[str, Sequence[Any]]
) -> tuple[
def _function_types(p_or_ty: Param | Type, literals: Mapping[str, Sequence[Any]]) -> tuple[
ParamInfo,
None,
list[list[MultiArrow]],
Expand Down Expand Up @@ -178,13 +176,18 @@ def _add_set_to(
literals: Mapping[str, Sequence[Any]],
) -> Iterable[dict[str, Literal]]:
for s in substitutions:
values = {pred.compute(s) for pred in set_to_preds}
if len(values) != 1:
all_values = {
frozenset(pred.compute(s)) if pred.multi_value else frozenset({pred.compute(s)})
for pred in set_to_preds
}
values = reduce(lambda acc, v: acc & v, all_values)
if len(values) == 0:
continue
value = tuple(values)[0]
# value = tuple(values)[0]

if any(pred.override for pred in set_to_preds) or value in literals[group]:
yield s | {name: Literal(value, group)}
for value in values:
if any(pred.override for pred in set_to_preds) or value in literals[group]:
yield s | {name: Literal(value, group)}

@staticmethod
def _instantiate(
Expand Down
1 change: 1 addition & 0 deletions clsp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ class TermParamSpec:
class SetTo:
compute: Callable[[dict[str, Any]], Any]
override: bool = field(default=False)
multi_value: bool = field(default=False)


# @dataclass(frozen=True)
Expand Down
87 changes: 80 additions & 7 deletions tests/test_dsl_as.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from clsp.dsl import DSL
from clsp.enumeration import enumerate_terms, interpret_term
from clsp.fcl import FiniteCombinatoryLogic
from clsp.types import (
Type, Constructor, LVar, Literal
)
from clsp.types import Type, Constructor, LVar, Literal


class TestDSLAs(unittest.TestCase):
logger = logging.getLogger(__name__)
Expand All @@ -18,17 +17,27 @@ def test_param(self) -> None:
def X(x: bool, y: bool, z: bool) -> str:
return f"X {x} {y} {z}"

Gamma = {X: DSL()
Gamma = {
X: DSL()
.Use("x", "bool")
.Use("y", "bool")
.As(lambda x: True)
.Use("z", "bool")
.As(lambda x: x)
.In(Constructor("a", LVar('x')) & Constructor("b", LVar('y')) & Constructor("c", LVar('z')))}
.In(
Constructor("a", LVar("x"))
& Constructor("b", LVar("y"))
& Constructor("c", LVar("z"))
)
}

def xyz(x: bool, y: bool, z: bool) -> Type:
return Constructor("a", Literal(x, "bool")) & Constructor("b", Literal(y, "bool")) & Constructor("c", Literal(z, "bool"))

return (
Constructor("a", Literal(x, "bool"))
& Constructor("b", Literal(y, "bool"))
& Constructor("c", Literal(z, "bool"))
)

fcl = FiniteCombinatoryLogic(Gamma, literals={"bool": [True, False]})

for x in [True, False]:
Expand All @@ -40,5 +49,69 @@ def xyz(x: bool, y: bool, z: bool) -> Type:
self.assertLessEqual(len(result), 1)
self.assertTrue(result.issubset({"X True True True", "X False True False"}))

def test_multi_as1(self) -> None:
def X(a: int, b: int) -> str:
return f"X {a} {b}"

literals = {"int": [0, 1, 2, 3]}
Gamma = {
X: DSL(cache=True)
.Use("a", "int")
.Use("b", "int")
.As(lambda a: {a - 1, a + 1}, multi_value=True)
.In(Constructor("c", LVar("a")))
}

fcl = FiniteCombinatoryLogic(Gamma, literals=literals)
target = Constructor("c", Literal(0, "int"))

result = fcl.inhabit(target)
self.assertEqual(
list(interpret_term(x) for x in enumerate_terms(target, result)), ["X 0 1"]
)

def test_multi_as2(self) -> None:
def X(a: int, b: int) -> str:
return f"X {a} {b}"

literals = {"int": [0, 1, 2, 3]}
Gamma = {
X: DSL(cache=True)
.Use("a", "int")
.Use("b", "int")
.As(lambda a: {a - 1, a + 1}, multi_value=True)
.In(Constructor("c", LVar("a")))
}

fcl = FiniteCombinatoryLogic(Gamma, literals=literals)
target = Constructor("c", Literal(1, "int"))

result = fcl.inhabit(target)
self.assertSetEqual(
set(interpret_term(x) for x in enumerate_terms(target, result)), {"X 1 2", "X 1 0"}
)

def test_multi_as3(self) -> None:
def X(a: int, b: int) -> str:
return f"X {a} {b}"

literals = {"int": [0, 1, 2, 3]}
Gamma = {
X: DSL(cache=True)
.Use("a", "int")
.Use("b", "int")
.As(lambda a: {a - 1, a + 1}, multi_value=True, override=True)
.In(Constructor("c", LVar("a")))
}

fcl = FiniteCombinatoryLogic(Gamma, literals=literals)
target = Constructor("c", Literal(0, "int"))

result = fcl.inhabit(target)
self.assertSetEqual(
set(interpret_term(x) for x in enumerate_terms(target, result)), {"X 0 -1", "X 0 1"}
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4feb10e

Please sign in to comment.