Skip to content

Commit

Permalink
Better typing in benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
christofsteel committed Dec 12, 2024
1 parent 99cc45c commit 54d5b35
Showing 1 changed file with 22 additions and 30 deletions.
52 changes: 22 additions & 30 deletions tests/benchmarks/benchmark_labyrinth_clsp_setto_freedelta.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,66 @@
from collections.abc import Callable, Mapping
from abc import ABC, abstractmethod
import timeit
from typing import Any, Generic, TypeVar
from typing import Generic, TypeAlias, TypeVar
from itertools import product

from clsp.dsl import DSL
from clsp.enumeration import enumerate_terms, interpret_term
from clsp.fcl import FiniteCombinatoryLogic

from clsp.types import Constructor, Literal, Param, LVar, Type

from clsp.types import Literal, Param, LVar, Type

Position: TypeAlias = tuple[int, int]
Carrier = TypeVar("Carrier")


class MazeSig(Generic[Carrier], ABC):
@abstractmethod
def up(self, a: tuple[int, int], b: tuple[int, int], p: Carrier) -> Carrier: ...
def up(self, a: Position, b: Position, p: Carrier) -> Carrier: ...
@abstractmethod
def down(self, a: tuple[int, int], b: tuple[int, int], p: Carrier) -> Carrier: ...
def down(self, a: Position, b: Position, p: Carrier) -> Carrier: ...
@abstractmethod
def left(self, a: tuple[int, int], b: tuple[int, int], p: Carrier) -> Carrier: ...
def left(self, a: Position, b: Position, p: Carrier) -> Carrier: ...
@abstractmethod
def right(self, a: tuple[int, int], b: tuple[int, int], p: Carrier) -> Carrier: ...
def right(self, a: Position, b: Position, p: Carrier) -> Carrier: ...
@abstractmethod
def start(self) -> Carrier: ...

def as_dict(self) -> dict[str, Any]:
def as_dict(self) -> dict[str, Callable[..., Carrier]]:
return {"U": self.up, "D": self.down, "L": self.left, "R": self.right, "START": self.start}


class MazeString(MazeSig[str]):
def up(self, a: tuple[int, int], b: tuple[int, int], p: str) -> str:
def up(self, a: Position, b: Position, p: str) -> str:
return f"{p} => UP({b})"

def down(self, a: tuple[int, int], b: tuple[int, int], p: str) -> str:
def down(self, a: Position, b: Position, p: str) -> str:
return f"{p} => DOWN({b})"

def left(self, a: tuple[int, int], b: tuple[int, int], p: str) -> str:
def left(self, a: Position, b: Position, p: str) -> str:
return f"{p} => LEFT({b})"

def right(self, a: tuple[int, int], b: tuple[int, int], p: str) -> str:
def right(self, a: Position, b: Position, p: str) -> str:
return f"{p} => RIGHT({b})"

def start(self) -> str:
return "START"


class MazePoints(MazeSig[list[tuple[int, int]]]):
def up(
self, a: tuple[int, int], b: tuple[int, int], p: list[tuple[int, int]]
) -> list[tuple[int, int]]:
class MazePoints(MazeSig[list[Position]]):
def up(self, a: Position, b: Position, p: list[Position]) -> list[Position]:
return p + [b]

def down(
self, a: tuple[int, int], b: tuple[int, int], p: list[tuple[int, int]]
) -> list[tuple[int, int]]:
def down(self, a: Position, b: Position, p: list[Position]) -> list[Position]:
return p + [b]

def left(
self, a: tuple[int, int], b: tuple[int, int], p: list[tuple[int, int]]
) -> list[tuple[int, int]]:
def left(self, a: Position, b: Position, p: list[Position]) -> list[Position]:
return p + [b]

def right(
self, a: tuple[int, int], b: tuple[int, int], p: list[tuple[int, int]]
) -> list[tuple[int, int]]:
def right(self, a: Position, b: Position, p: list[Position]) -> list[Position]:
return p + [b]

def start(self) -> list[tuple[int, int]]:
def start(self) -> list[Position]:
return [(0, 0)]


Expand All @@ -80,7 +72,7 @@ def is_free(col: int, row: int) -> bool:
else:
return pow(11, (row + col + SEED) * (row + col + SEED) + col + 7, 1000003) % 5 > 0

pos: Callable[[str], Type] = lambda p: Constructor("pos", LVar(p))
pos: Callable[[str], Type] = lambda p: "pos" @ LVar(p)

repo: Mapping[
str,
Expand Down Expand Up @@ -129,15 +121,15 @@ def is_free(col: int, row: int) -> bool:
start = timeit.default_timer()
grammar = fcl.inhabit(fin)

for term in enumerate_terms(fin, grammar, 3):
for term in enumerate_terms(fin, grammar, 10):
t = interpret_term(term, MazeString().as_dict())
p = interpret_term(term, MazePoints().as_dict())
if output:
print(t)
for row in range(SIZE):
for col in range(SIZE):
if (col, row) in p:
print("X", end="")
print("\033[93mX\033[0m", end="")
elif is_free(col, row):
print("-", end="")
else:
Expand All @@ -148,4 +140,4 @@ def is_free(col: int, row: int) -> bool:


if __name__ == "__main__":
main(20)
main()

0 comments on commit 54d5b35

Please sign in to comment.