Skip to content

Commit

Permalink
Changed default behavior for As/SetTo. If the computed value is not i…
Browse files Browse the repository at this point in the history
…n the respective literal set, it is discarded.

An option override was added to mimic the old behavior
  • Loading branch information
christofsteel committed Apr 8, 2024
1 parent ed5002e commit 0ddcb30
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
12 changes: 9 additions & 3 deletions clsp/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def Use(self, name: str, group: str | Type) -> DSL:
self._accumulator.append((name, group, [DSL.TRUE]))
return self

def As(self, set_to: Callable[..., Any]) -> DSL:
def As(self, set_to: Callable[..., Any], override: bool = False) -> DSL:
"""
Set the previous variable directly to the result of a computation.
Expand All @@ -138,18 +138,21 @@ def As(self, set_to: Callable[..., Any]) -> DSL:
parameters to this function correspond directly to the names of the variables,
previously introduced.
:type set_to: Callable[..., Any]
:param override: Whether the result of the computation should be discarded, if it
is not in the literal set for the group. Default is False (discard).
:type override: bool
:return: The DSL object. To concatenate multiple calls.
:rtype: DSL
"""
last_element = self._accumulator[-1]
self._accumulator[-1] = (
last_element[0],
last_element[1],
last_element[2] + [SetTo(DSL._unwrap_predicate(set_to))],
last_element[2] + [SetTo(DSL._unwrap_predicate(set_to), override)],
)
return self

def AsRaw(self, set_to: Callable[[Mapping[str, Any]], Any]) -> DSL:
def AsRaw(self, set_to: Callable[[Mapping[str, Any]], Any], override: bool = False) -> DSL:
"""
Set the previous variable directly to the result of a computation.
Expand All @@ -160,6 +163,9 @@ def AsRaw(self, set_to: Callable[[Mapping[str, Any]], Any]) -> DSL:
:param set_to: The function computing the value for the variable.
:type set_to: Callable[[Mapping[str, Any]], Any]
:param override: Whether the result of the computation should be discarded, if it
is not in the literal set for the group. Default is False (discard).
:type override: bool
:return: The DSL object. To concatenate multiple calls.
:rtype: DSL
"""
Expand Down
20 changes: 16 additions & 4 deletions clsp/fcl.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,25 @@ def split_params(
]
return (instantiations, multiarrows)

@staticmethod
def _add_set_to(
name: str,
set_to_pred: SetTo,
substitutions: deque[dict[str, Literal]],
group: str,
literals: Mapping[str, list[Any]],
) -> Iterable[dict[str, Literal]]:
for s in substitutions:
value = set_to_pred.compute(s)
if set_to_pred.override or value in literals[group]:
yield s | {name: Literal(value, group)}

@staticmethod
def _instantiate(
literals: Mapping[str, list[Any]],
params: Sequence[LitParamSpec | TermParamSpec],
) -> Iterable[InstantiationMeta]:
substitutions: Iterable[dict[str, Literal]] = deque([{}])
substitutions: deque[dict[str, Literal]] = deque([{}])
args: deque[str | GVar] = deque()
term_params: list[TermParamSpec] = []

Expand All @@ -137,9 +150,8 @@ def _instantiate(
lambda substs: all(
callable(npred) and npred(substs) for npred in normal_preds
),
(
s | {param.name: Literal(pred.compute(s), param.group)}
for s in substitutions
FiniteCombinatoryLogic._add_set_to(
param.name, pred, substitutions, param.group, literals
),
)
)
Expand Down
1 change: 1 addition & 0 deletions clsp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ class TermParamSpec:
@dataclass
class SetTo:
compute: Callable[[dict[str, Any]], Any]
override: bool = field(default=False)


# @dataclass(frozen=True)
Expand Down

0 comments on commit 0ddcb30

Please sign in to comment.