Skip to content

Commit e58a439

Browse files
committed
Honestly don't know what I was last working on here
1 parent 451b81f commit e58a439

File tree

5 files changed

+236
-39
lines changed

5 files changed

+236
-39
lines changed

notes/invsets.md

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Invsets
2+
3+
Given a set $U$, a permutation oracle $O \rightarrow \{0, 1\}$
4+
5+
An order is defined as a series of sets $S_1$, $S_2$, ..., $S_n$ where $|S_n| = n$ and $S_n \subseteq U$, i.e. $\emptyset \subset S_1 \subset S_2 \subset ... \subset S_n \subseteq U$.
6+
7+
The inversion set ($invset(\pi)$) of a permutation $\pi$ of $n$ elements is the set of pairs $(i, j)$ such that $i < j$ and $\pi(i) > \pi(j)$
8+
9+
For each pair $(i, j) | i \neq j$, initialize counters:
10+
- $M^+_{(i, j)} \gets 0$ // positive (accepted) count
11+
- $M^-_{(i, j)} \gets 0$ // negative (rejected) count
12+
13+
For each permutation $\pi$ queried:
14+
- If $O(\pi) = 1$ (accepted): for each $(i, j) \in invset(\pi)$, increment $M^+_{(i, j)} \gets M^+_{(i, j)} + 1$
15+
- If $O(\pi) = 0$ (rejected): for each $(i, j) \in invset(\pi)$, increment $M^-_{(i, j)} \gets M^-_{(i, j)} + 1$
16+
17+
$P_{observed}^+(i, j) = M^+_{(i, j)} / N^+$ (frequency in positive samples)
18+
19+
$P_{observed}^-(i, j) = M^-_{(i, j)} / N^-$ (frequency in negative samples)
20+
21+
$P_{observed+}(x) = (1-\epsilon)P(1) + \epsilon(1 - P(1))$
22+
$P_{observed-}(x) = (1-\epsilon)P(0) + \epsilon(1 - P(0))$
23+
$P(1) = 1 - P(0)$
24+
25+
26+
Greatest Lower Bound
27+
28+
$meet(A, B) \rightarrow invset(A) \cap invset(B)$
29+
30+
Greatest Upper Bound
31+
$join(A, B) \rightarrow invset(A) \cup invset(B)$
32+
33+
Assume oracle outputs are flipped with probability $\epsilon$. We can estimate the 'true' probability $P(x)$ from our observed samples $P_{obs}(x)$ as follows:
34+
35+
$P_{observed}(x) = (1 - \epsilon) * P(x) + \epsilon * (1 - P(x))$
36+
37+
$P_{observed}(x) = P(x) - \epsilon P(x) + \epsilon - \epsilon P(x)$
38+
39+
$P_{observed}(x) = P(x) - \epsilon P(x) - \epsilon P(x) + \epsilon$
40+
41+
$P_{observed}(x) = P(x) (1 - \epsilon - \epsilon) + \epsilon$
42+
43+
$P_{observed}(x) = P(x) (1 - 2\epsilon) + \epsilon$
44+
45+
$P_{observed}(x) - \epsilon = P(x) (1 - 2\epsilon)$
46+
47+
$\frac{P_{observed}(x) - \epsilon}{1 - 2\epsilon} = P(x)$
48+
49+
50+
$$ f = (1-p)q + p(1-q) $$
51+
$$ f = q - pq + p - pq $$
52+
$$ f = q - pq + p - pq $$
53+
$$ f = q(1-p-p) + p $$
54+
$$ f = q(1-2p) + p $$
55+
$$ f - p = q (1 - 2p) $$
56+
$$ \frac{f - p}{1 - 2p} = q $$
57+
x = (1 - p) * x + p * (x)
58+
59+
60+
$P(\text{observed} = 1 \mid \text{true} = 0) = p$
61+
$P(\text{observed} = 1 \mid \text{true} = 1) = 1-p$
62+
$P(\text{observed} = 0 \mid \text{true} = 0) = 1-p$
63+
$P(\text{observed} = 0 \mid \text{true} = 1) = p$
64+
65+
66+
if $O(A) = 1$ and $O(B) = 1$
67+
68+
?
69+
70+
if $O(A) = 0$ and $O(B) = 0$
71+
72+
?
73+
74+
if $O(A) = 1$ and $O(B) = 0$
75+
76+
?
77+
78+
if $O(A) = 0$ and $O(B) = 1$
79+
80+
?
81+
82+
# DB
83+
84+
Record:
85+
- Key
86+
- Version
87+
- Payload

plugin_oracle/base/db.py

Lines changed: 107 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pickle
33
import random
44
from typing import TypeVar, Generic, Protocol
5+
from copy import deepcopy
56

67
from plugin_oracle.util.mod.mod import Mod
78

@@ -14,9 +15,9 @@ class HasHashAndInd(Protocol):
1415
class MSet(Generic[T]):
1516
def __init__(self) -> None:
1617
self._U: dict[bytes, T] = {}
17-
self._fsets: dict[bytes, set[bytes]] = {}
18-
self._isets: dict[bytes, set[bytes]] = {}
18+
self._fsets: dict[bytes, dict[bytes, float]] = {}
1919
self._bound: tuple[list[bytes], list[bytes]] = ([], [])
20+
self._e: float = 0.2
2021

2122
@property
2223
def U(self) -> dict[bytes, T]:
@@ -30,42 +31,65 @@ def min(self) -> list[bytes]:
3031
def max(self) -> list[bytes]:
3132
return self._bound[1]
3233

33-
def __getitem__(self, hash: bytes) -> tuple[T | None, set[bytes] | None, set[bytes] | None]:
34-
return (self._U.get(hash, None), self._fsets.get(hash, None), self._isets.get(hash, None))
34+
def __getitem__(self, hash: bytes) -> tuple[T, dict[bytes, float]]:
35+
return (self._U[hash], self._fsets[hash])
36+
37+
def __contains__(self, hash: bytes) -> bool:
38+
return hash in self._U
3539

3640
def addU(self, v: T) -> None:
3741
if v.hash in self._U:
3842
return
3943
self._U[v.hash] = v
40-
self._fsets[v.hash] = self._fsets.get(v.hash, set())
41-
self._fsets[v.hash].update(self._U.keys())
42-
self._fsets[v.hash].discard(v.hash)
43-
for k in self._U.keys():
44-
if k != v.hash:
45-
self._fsets[k].add(v.hash)
46-
self._isets[v.hash] = self._isets.get(v.hash, set())
47-
self._isets[v.hash].update(self._U.keys())
48-
self._isets[v.hash].discard(v.hash)
44+
self._fsets[v.hash] = self._fsets.get(v.hash, {})
45+
self._fsets[v.hash].update({k: 1.0 for k in self._U.keys()})
46+
del self._fsets[v.hash][v.hash]
4947
for k in self._U.keys():
5048
if k != v.hash:
51-
self._isets[k].add(v.hash)
49+
self._fsets[k][v.hash] = 1.0
5250

5351
def permutation(self, perm: list[bytes], state: bool) -> None:
5452
if any(hash not in self._U for hash in perm):
5553
raise ValueError('Unfiltered permutation encountered')
5654
if state:
57-
if perm < self._bound[0]:
55+
if perm < self._bound[0] or self._bound[0] == []:
5856
self._bound = (perm, self._bound[1])
59-
if perm > self._bound[1]:
57+
if perm > self._bound[1] or self._bound[1] == []:
6058
self._bound = (self._bound[0], perm)
61-
62-
for i in range(len(perm)):
63-
fset = self._fsets[perm[i]] if state else self._isets[perm[i]]
64-
for j in range(i):
65-
fset.discard(perm[j])
59+
if state:
60+
for i in range(len(perm)):
61+
fset = self._fsets[perm[i]]
62+
for j in range(i):
63+
fset[perm[j]] *= self._e
64+
else:
65+
l = len(perm)
66+
tU = set(self._U.keys())
67+
tlen = len(tU)
68+
for i in range(l):
69+
tU.discard(perm[i])
70+
fset = self._fsets[perm[i]]
71+
tlen -= 1
72+
for j in fset.keys():
73+
if j in tU:
74+
fset[j] += ((1 - self._e) / (tlen * (tlen - 1)))
75+
if fset[j] > 1.0:
76+
fset[j] = 1.0
77+
78+
def edrop(self) -> 'MSet[T]':
79+
out = MSet[T]()
80+
out._U = self._U.copy() # pyright: ignore [reportConstantRedefinition]
81+
out._fsets = deepcopy(self._fsets)
82+
for k in out._U.keys():
83+
fset = out._fsets[k]
84+
follow = list(fset.keys())
85+
for m in follow:
86+
if fset[m] < (1 - self._e):
87+
del fset[m]
88+
out._bound = self._bound
89+
return out
6690

67-
def rtoposort(self, state: bool, seed: int = 0) -> list[bytes] | None:
68-
adj = self._fsets if state else self._isets
91+
def rtoposort(self, seed: int = 0) -> list[bytes] | None:
92+
adj = self._fsets
6993
indegree = {k: 0 for k in adj}
7094
for vs in adj.values():
7195
for v in vs:
@@ -76,7 +100,7 @@ def rtoposort(self, state: bool, seed: int = 0) -> list[bytes] | None:
76100
while queue:
77101
n = queue.pop(random.randrange(len(queue)))
78102
L.append(n)
79-
for m in adj.get(n, set()):
103+
for m in adj.get(n, {}):
80104
indegree[m] -= 1
81105
if indegree[m] == 0:
82106
queue.append(m)
@@ -85,8 +109,8 @@ def rtoposort(self, state: bool, seed: int = 0) -> list[bytes] | None:
85109
return None
86110
return L
87111

88-
def toposort(self, state: bool) -> list[bytes] | None:
89-
adj = self._fsets if state else self._isets
112+
def toposort(self) -> list[bytes] | None:
113+
adj = self._fsets
90114
indegree = {k: 0 for k in adj}
91115
for vs in adj.values():
92116
for v in vs:
@@ -96,16 +120,71 @@ def toposort(self, state: bool) -> list[bytes] | None:
96120
while queue:
97121
n = queue.pop()
98122
L.append(n)
99-
for m in adj.get(n, set()):
123+
for m in adj.get(n, {}):
100124
indegree[m] -= 1
101125
if indegree[m] == 0:
102126
queue.append(m)
103127
if len(L) != len(indegree):
104128
return None
105129
return L
130+
131+
def assemble(self) -> None | list[tuple[bytes, bytes]]:
132+
leximin: list[bytes] = self._bound[0]
133+
leximax: list[bytes] = self._bound[1]
134+
if leximin == [] or leximax == []:
135+
return None
136+
anc: dict[bytes, list[None | bytes]] = {b: [None, None] for b in leximin}
137+
# Boolean markers for validation.
138+
hv: dict[bytes, bool] = {b: False for b in leximin}
139+
stack: list[bytes] = []
140+
141+
# First pass: based on leximin.
142+
for b in leximin:
143+
while stack and stack[-1] < b:
144+
_ = stack.pop()
145+
if stack:
146+
anc[b][0] = stack[-1]
147+
stack.append(b)
148+
149+
stack.clear()
150+
151+
# Second pass: based on leximax.
152+
for b in leximax:
153+
while stack and stack[-1] > b:
154+
_ = stack.pop()
155+
hv[b] = True
156+
# If a parent was set in first pass but hasn't been marked yet, it's inconsistent.
157+
r = anc[b][0]
158+
if r is not None and not hv.get(r, False):
159+
return None
160+
if stack:
161+
if anc[b][0] is not None:
162+
anc[b][1] = stack[-1]
163+
else:
164+
anc[b][0] = stack[-1]
165+
stack.append(b)
166+
167+
# Reset hv markers.
168+
for b in leximin:
169+
hv[b] = False
170+
171+
# Third pass: validate that all assigned parents appear earlier.
172+
for b in leximin:
173+
hv[b] = True
174+
for parent in anc[b]:
175+
if parent is not None and not hv[parent]:
176+
return None
177+
178+
# Build the edge list: for every byte b, each non-None parent becomes an edge (parent -> b).
179+
elist: list[tuple[bytes, bytes]] = []
180+
for b in leximin:
181+
for parent in anc[b]:
182+
if parent is not None:
183+
elist.append((parent, b))
184+
return elist
106185

107186
class MDB:
108-
_fname: str = '/db_0.pkl'
187+
_fname: str = '/db_1.pkl'
109188

110189
def __init__(self) -> None:
111190
self.mod: MSet[Mod] = MSet[Mod]()

plugin_oracle/base/oracle/oracle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def observe(self, result: bool, mlist: IModList, organizer: IOrganizer) -> None:
9393

9494
def sample(self, mlist: IModList, plist: IPluginList, organizer: IOrganizer) -> None:
9595
t0 = time()
96-
ld = self.db.mod.rtoposort(True)
96+
ld = self.db.mod.edrop().rtoposort()
9797
if ld is None:
9898
self._log.warning('Failed to find a topological sort!')
9999
return
@@ -126,9 +126,9 @@ def predict(self, mlist: IModList, organizer: IOrganizer) -> str:
126126
loadorder = self.permutation(mlist, organizer)
127127
report: list[tuple[str, str]] = []
128128
for i, hash in enumerate(loadorder):
129-
_, fset, _ = self.db.mod[hash]
129+
_, fset = self.db.mod[hash]
130130
for j in range(i + 1, len(loadorder)):
131-
if fset is None or loadorder[j] not in fset:
131+
if loadorder[j] not in fset:
132132
m0: Mod = self.db.mod.U.get(hash, Mod(hash))
133133
m1: Mod = self.db.mod.U.get(loadorder[j], Mod(loadorder[j]))
134134
self.db.mod.addU(m0)

plugin_oracle/base/window.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,28 @@ def __init__(self, oracle: Oracle, samplers: list[Callable[[bool], None]], repor
1818
layout = QVBoxLayout()
1919
self.setLayout(layout)
2020
tab_widget = QTabWidget()
21-
tabs: list[QWidget] = [QWidget()]
22-
layouts: list[QVBoxLayout] = [QVBoxLayout()]
23-
tabnames: list[str] = ['Graph']
21+
tabs: list[QWidget] = [QWidget(), QWidget()]
22+
layouts: list[QVBoxLayout] = [QVBoxLayout(), QVBoxLayout()]
23+
tabnames: list[str] = ['Graph', 'Bounds']
2424

25-
tsort = self.oracle.db.mod.rtoposort(True)
25+
graph = self.oracle.db.mod.edrop()
26+
tsort = graph.toposort()
2627
if tsort is None:
2728
label = QLabel("Cannot render: the mod graph contains a cycle.")
2829
layouts[0].addWidget(label)
2930
else:
30-
graph_widget = OracleGraph(tsort, self.oracle.db.mod) # pyright: ignore [reportArgumentType]
31+
graph_widget = OracleGraph(tsort, graph) # pyright: ignore [reportArgumentType]
3132
layouts[0].addWidget(graph_widget)
33+
34+
elist = self.oracle.db.mod.assemble()
35+
if not elist:
36+
label = QLabel("Cannot render: the mod graph is empty.")
37+
layouts[1].addWidget(label)
38+
else:
39+
labels = {k: v.name for k, v in self.oracle.db.mod.U.items()}
40+
bound_graph_widget = OracleGraph.from_elist(elist, labels)
41+
layouts[1].addWidget(bound_graph_widget)
42+
3243
for i in range(len(tabs)):
3344
tabs[i].setLayout(layouts[i])
3445
_ = tab_widget.addTab(tabs[i], tabnames[i])
@@ -66,10 +77,22 @@ def on_predict(self):
6677
_ = QMessageBox.warning(self, "Predict Error", str(e))
6778

6879
class OracleGraph(QWidget):
69-
def __init__(self, order: list[bytes], mset: MSet[HasHashAndInd], parent: QWidget | None = None) -> None:
80+
@staticmethod
81+
def from_mset(order: list[bytes], mset: MSet[HasHashAndInd]) -> 'OracleGraph':
82+
config: MetroConfig = MetroConfig()
83+
renderer: MetroRender = MetroRender.from_mset(config, order, mset)
84+
return OracleGraph(config, renderer)
85+
86+
@staticmethod
87+
def from_elist(elist: list[tuple[bytes, bytes]], labels: dict[bytes, str]) -> 'OracleGraph':
88+
config: MetroConfig = MetroConfig()
89+
renderer: MetroRender = MetroRender.from_edgelist_and_labels(config, elist, labels)
90+
return OracleGraph(config, renderer)
91+
92+
def __init__(self, config: MetroConfig, renderer: MetroRender, parent: QWidget | None = None) -> None:
7093
super().__init__(parent)
71-
self.config: MetroConfig = MetroConfig()
72-
self.renderer: MetroRender = MetroRender.from_mset(self.config, order, mset)
94+
self.config: MetroConfig = config
95+
self.renderer: MetroRender = renderer
7396
self._scale: float = 4.0
7497
self._offset_x: float = 100
7598
self._offset_y: float = 0

plugin_oracle/util/ml/order.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
def inversion(x: list[bytes]) -> set[tuple[bytes, bytes]]:
3+
oset: set[tuple[bytes, bytes]] = set()
4+
for i in range(len(x)):
5+
for j in range(i + 1, len(x)):
6+
if x[i] > x[j]:
7+
oset.add((x[i], x[j]))
8+
return oset

0 commit comments

Comments
 (0)