Skip to content

Commit 3392474

Browse files
Merge pull request #9 from nikobockerman:digraph_container
Add Digraph tool class
2 parents 3400bf5 + bdbdcc8 commit 3392474

File tree

3 files changed

+166
-0
lines changed

3 files changed

+166
-0
lines changed

adventofcode/tooling/digraph.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from __future__ import annotations
2+
3+
import typing
4+
from dataclasses import dataclass
5+
from functools import cache
6+
from typing import Iterable, Protocol, runtime_checkable
7+
8+
9+
@runtime_checkable
10+
class NodeId(typing.Hashable, Protocol):
11+
pass
12+
13+
14+
@dataclass(kw_only=True, slots=True)
15+
class Digraph[Id: NodeId, N]:
16+
nodes: dict[Id, N] # TODO: Consider replacing with a frozendict
17+
arcs: tuple[DigraphArc[Id], ...]
18+
19+
def get_arcs_to(self, node_id: Id, /) -> list[DigraphArc[Id]]:
20+
return _get_arcs_to_node(node_id, self.arcs)
21+
22+
def get_arcs_from(self, node_id: Id, /) -> list[DigraphArc[Id]]:
23+
return _get_arcs_from_node(node_id, self.arcs)
24+
25+
26+
class DigraphArc[Id: NodeId](Protocol):
27+
@property
28+
def from_(self) -> Id: ...
29+
@property
30+
def to(self) -> Id: ...
31+
32+
33+
@dataclass(frozen=True, slots=True)
34+
class Arc[Id: NodeId]:
35+
from_: Id
36+
to: Id
37+
38+
39+
class DigraphCreator[Id: NodeId, N]:
40+
def __init__(self) -> None:
41+
self._nodes: dict[Id, N] = dict()
42+
self._arcs: list[DigraphArc[Id]] = list()
43+
44+
def add_node(self, node_id: Id, node: N, /) -> None:
45+
if node_id in self._nodes:
46+
raise ValueError(node_id)
47+
self._nodes[node_id] = node
48+
49+
def add_arc(self, arc: DigraphArc[Id], /) -> None:
50+
if arc.from_ not in self._nodes:
51+
raise ValueError(arc.from_)
52+
if arc.to not in self._nodes:
53+
raise ValueError(arc.to)
54+
self._arcs.append(arc)
55+
56+
def create(self) -> Digraph[Id, N]:
57+
return Digraph(nodes=self._nodes, arcs=tuple(self._arcs))
58+
59+
60+
@cache
61+
def _get_arcs_to_node[Id: NodeId](
62+
node_id: Id, arcs: Iterable[DigraphArc[Id]]
63+
) -> list[DigraphArc[Id]]:
64+
return list(arc for arc in arcs if arc.to == node_id)
65+
66+
67+
@cache
68+
def _get_arcs_from_node[Id: NodeId](
69+
node_id: Id, arcs: Iterable[DigraphArc[Id]]
70+
) -> list[DigraphArc[Id]]:
71+
return list(arc for arc in arcs if arc.from_ == node_id)

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ adventofcode = "adventofcode.main:app"
3939
[tool.mypy]
4040
python_version = "3.12"
4141
strict_optional = true
42+
enable_incomplete_feature = ["NewGenericSyntax"]
4243

4344
[tool.ruff.lint]
4445
select = [

tests/test_digraph.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from dataclasses import dataclass
2+
from typing import assert_type
3+
4+
from adventofcode.tooling.digraph import Arc, Digraph, DigraphCreator
5+
6+
7+
def test_digraph_creator_simple() -> None:
8+
creator = DigraphCreator[int, int]()
9+
creator.add_node(1, 11)
10+
creator.add_node(2, 22)
11+
creator.add_arc(Arc(1, 2))
12+
digraph = creator.create()
13+
assert digraph.nodes == {1: 11, 2: 22}
14+
assert digraph.arcs == (Arc(1, 2),)
15+
assert_type(digraph.nodes, dict[int, int])
16+
17+
18+
def test_digraph_creator_two_types() -> None:
19+
creator = DigraphCreator[int | str, int | str]()
20+
creator.add_node("a", "aa")
21+
creator.add_node(1, 11)
22+
creator.add_node("b", 22)
23+
creator.add_arc(Arc("a", 1))
24+
creator.add_arc(Arc("a", "b"))
25+
creator.add_arc(Arc(1, "b"))
26+
digraph = creator.create()
27+
assert digraph.nodes == {"a": "aa", 1: 11, "b": 22}
28+
assert digraph.arcs == (Arc("a", 1), Arc("a", "b"), Arc(1, "b"))
29+
assert_type(digraph.nodes, dict[int | str, int | str])
30+
31+
32+
def test_digraph_creator_multiple_inherited_classes() -> None:
33+
@dataclass
34+
class Base:
35+
name: str
36+
37+
class Child1(Base):
38+
pass
39+
40+
class Child2(Base):
41+
pass
42+
43+
creator = DigraphCreator[str, Child1 | Child2]()
44+
creator.add_node("a", Child1("a"))
45+
creator.add_node("b", Child2("b"))
46+
creator.add_arc(Arc("a", "b"))
47+
digraph = creator.create()
48+
assert digraph.nodes == {"a": Child1("a"), "b": Child2("b")}
49+
assert digraph.arcs == (Arc("a", "b"),)
50+
assert_type(digraph.nodes, dict[str, Child1 | Child2])
51+
52+
53+
def test_digraph_get_arcs() -> None:
54+
digraph = Digraph[int, int](
55+
nodes={1: 11, 2: 22, 3: 33, 4: 44},
56+
arcs=tuple((Arc(1, 2), Arc(1, 3), Arc(2, 3), Arc(3, 1))),
57+
)
58+
assert digraph.get_arcs_from(1) == [Arc(1, 2), Arc(1, 3)]
59+
assert digraph.get_arcs_from(2) == [Arc(2, 3)]
60+
assert digraph.get_arcs_from(3) == [Arc(3, 1)]
61+
assert digraph.get_arcs_from(4) == []
62+
assert digraph.get_arcs_to(1) == [Arc(3, 1)]
63+
assert digraph.get_arcs_to(2) == [Arc(1, 2)]
64+
assert digraph.get_arcs_to(3) == [Arc(1, 3), Arc(2, 3)]
65+
assert digraph.get_arcs_to(4) == []
66+
67+
68+
def test_digraph_weighted_arcs() -> None:
69+
@dataclass(frozen=True)
70+
class WeightedArc:
71+
from_: str
72+
to: str
73+
weight: int
74+
75+
digraph_creator = DigraphCreator[str, int]()
76+
digraph_creator.add_node("a", 1)
77+
digraph_creator.add_node("b", 2)
78+
digraph_creator.add_node("c", 3)
79+
digraph_creator.add_arc(WeightedArc("a", "b", 3))
80+
digraph_creator.add_arc(WeightedArc("a", "c", 4))
81+
digraph_creator.add_arc(WeightedArc("b", "c", 5))
82+
digraph = digraph_creator.create()
83+
assert digraph.get_arcs_from("a") == [
84+
WeightedArc("a", "b", 3),
85+
WeightedArc("a", "c", 4),
86+
]
87+
assert digraph.get_arcs_from("b") == [WeightedArc("b", "c", 5)]
88+
assert digraph.get_arcs_from("c") == []
89+
assert digraph.get_arcs_to("a") == []
90+
assert digraph.get_arcs_to("b") == [WeightedArc("a", "b", 3)]
91+
assert digraph.get_arcs_to("c") == [
92+
WeightedArc("a", "c", 4),
93+
WeightedArc("b", "c", 5),
94+
]

0 commit comments

Comments
 (0)