22import pickle
33import random
44from typing import TypeVar , Generic , Protocol
5+ from copy import deepcopy
56
67from plugin_oracle .util .mod .mod import Mod
78
@@ -14,9 +15,9 @@ class HasHashAndInd(Protocol):
1415class MSet (Generic [T ]):
1516 def __init__ (self ) -> None :
1617 self ._U : dict [bytes , T ] = {}
17- self ._fsets : dict [bytes , set [bytes ]] = {}
18- self ._isets : dict [bytes , set [bytes ]] = {}
18+ self ._fsets : dict [bytes , dict [bytes , float ]] = {}
1919 self ._bound : tuple [list [bytes ], list [bytes ]] = ([], [])
20+ self ._e : float = 0.2
2021
2122 @property
2223 def U (self ) -> dict [bytes , T ]:
@@ -30,42 +31,65 @@ def min(self) -> list[bytes]:
3031 def max (self ) -> list [bytes ]:
3132 return self ._bound [1 ]
3233
33- def __getitem__ (self , hash : bytes ) -> tuple [T | None , set [bytes ] | None , set [bytes ] | None ]:
34- return (self ._U .get (hash , None ), self ._fsets .get (hash , None ), self ._isets .get (hash , None ))
34+ def __getitem__ (self , hash : bytes ) -> tuple [T , dict [bytes , float ]]:
35+ return (self ._U [hash ], self ._fsets [hash ])
36+
37+ def __contains__ (self , hash : bytes ) -> bool :
38+ return hash in self ._U
3539
3640 def addU (self , v : T ) -> None :
3741 if v .hash in self ._U :
3842 return
3943 self ._U [v .hash ] = v
40- self ._fsets [v .hash ] = self ._fsets .get (v .hash , set ())
41- self ._fsets [v .hash ].update (self ._U .keys ())
42- self ._fsets [v .hash ].discard (v .hash )
43- for k in self ._U .keys ():
44- if k != v .hash :
45- self ._fsets [k ].add (v .hash )
46- self ._isets [v .hash ] = self ._isets .get (v .hash , set ())
47- self ._isets [v .hash ].update (self ._U .keys ())
48- self ._isets [v .hash ].discard (v .hash )
44+ self ._fsets [v .hash ] = self ._fsets .get (v .hash , {})
45+ self ._fsets [v .hash ].update ({k : 1.0 for k in self ._U .keys ()})
46+ del self ._fsets [v .hash ][v .hash ]
4947 for k in self ._U .keys ():
5048 if k != v .hash :
51- self ._isets [k ]. add ( v .hash )
49+ self ._fsets [k ][ v .hash ] = 1.0
5250
5351 def permutation (self , perm : list [bytes ], state : bool ) -> None :
5452 if any (hash not in self ._U for hash in perm ):
5553 raise ValueError ('Unfiltered permutation encountered' )
5654 if state :
57- if perm < self ._bound [0 ]:
55+ if perm < self ._bound [0 ] or self . _bound [ 0 ] == [] :
5856 self ._bound = (perm , self ._bound [1 ])
59- if perm > self ._bound [1 ]:
57+ if perm > self ._bound [1 ] or self . _bound [ 1 ] == [] :
6058 self ._bound = (self ._bound [0 ], perm )
61-
62- for i in range (len (perm )):
63- fset = self ._fsets [perm [i ]] if state else self ._isets [perm [i ]]
64- for j in range (i ):
65- fset .discard (perm [j ])
59+ if state :
60+ for i in range (len (perm )):
61+ fset = self ._fsets [perm [i ]]
62+ for j in range (i ):
63+ fset [perm [j ]] *= self ._e
64+ else :
65+ l = len (perm )
66+ tU = set (self ._U .keys ())
67+ tlen = len (tU )
68+ for i in range (l ):
69+ tU .discard (perm [i ])
70+ fset = self ._fsets [perm [i ]]
71+ tlen -= 1
72+ for j in fset .keys ():
73+ if j in tU :
74+ fset [j ] += ((1 - self ._e ) / (tlen * (tlen - 1 )))
75+ if fset [j ] > 1.0 :
76+ fset [j ] = 1.0
77+
78+ def edrop (self ) -> 'MSet[T]' :
79+ out = MSet [T ]()
80+ out ._U = self ._U .copy () # pyright: ignore [reportConstantRedefinition]
81+ out ._fsets = deepcopy (self ._fsets )
82+ for k in out ._U .keys ():
83+ fset = out ._fsets [k ]
84+ follow = list (fset .keys ())
85+ for m in follow :
86+ if fset [m ] < (1 - self ._e ):
87+ del fset [m ]
88+ out ._bound = self ._bound
89+ return out
6690
67- def rtoposort (self , state : bool , seed : int = 0 ) -> list [bytes ] | None :
68- adj = self ._fsets if state else self . _isets
91+ def rtoposort (self , seed : int = 0 ) -> list [bytes ] | None :
92+ adj = self ._fsets
6993 indegree = {k : 0 for k in adj }
7094 for vs in adj .values ():
7195 for v in vs :
@@ -76,7 +100,7 @@ def rtoposort(self, state: bool, seed: int = 0) -> list[bytes] | None:
76100 while queue :
77101 n = queue .pop (random .randrange (len (queue )))
78102 L .append (n )
79- for m in adj .get (n , set () ):
103+ for m in adj .get (n , {} ):
80104 indegree [m ] -= 1
81105 if indegree [m ] == 0 :
82106 queue .append (m )
@@ -85,8 +109,8 @@ def rtoposort(self, state: bool, seed: int = 0) -> list[bytes] | None:
85109 return None
86110 return L
87111
88- def toposort (self , state : bool ) -> list [bytes ] | None :
89- adj = self ._fsets if state else self . _isets
112+ def toposort (self ) -> list [bytes ] | None :
113+ adj = self ._fsets
90114 indegree = {k : 0 for k in adj }
91115 for vs in adj .values ():
92116 for v in vs :
@@ -96,16 +120,71 @@ def toposort(self, state: bool) -> list[bytes] | None:
96120 while queue :
97121 n = queue .pop ()
98122 L .append (n )
99- for m in adj .get (n , set () ):
123+ for m in adj .get (n , {} ):
100124 indegree [m ] -= 1
101125 if indegree [m ] == 0 :
102126 queue .append (m )
103127 if len (L ) != len (indegree ):
104128 return None
105129 return L
130+
131+ def assemble (self ) -> None | list [tuple [bytes , bytes ]]:
132+ leximin : list [bytes ] = self ._bound [0 ]
133+ leximax : list [bytes ] = self ._bound [1 ]
134+ if leximin == [] or leximax == []:
135+ return None
136+ anc : dict [bytes , list [None | bytes ]] = {b : [None , None ] for b in leximin }
137+ # Boolean markers for validation.
138+ hv : dict [bytes , bool ] = {b : False for b in leximin }
139+ stack : list [bytes ] = []
140+
141+ # First pass: based on leximin.
142+ for b in leximin :
143+ while stack and stack [- 1 ] < b :
144+ _ = stack .pop ()
145+ if stack :
146+ anc [b ][0 ] = stack [- 1 ]
147+ stack .append (b )
148+
149+ stack .clear ()
150+
151+ # Second pass: based on leximax.
152+ for b in leximax :
153+ while stack and stack [- 1 ] > b :
154+ _ = stack .pop ()
155+ hv [b ] = True
156+ # If a parent was set in first pass but hasn't been marked yet, it's inconsistent.
157+ r = anc [b ][0 ]
158+ if r is not None and not hv .get (r , False ):
159+ return None
160+ if stack :
161+ if anc [b ][0 ] is not None :
162+ anc [b ][1 ] = stack [- 1 ]
163+ else :
164+ anc [b ][0 ] = stack [- 1 ]
165+ stack .append (b )
166+
167+ # Reset hv markers.
168+ for b in leximin :
169+ hv [b ] = False
170+
171+ # Third pass: validate that all assigned parents appear earlier.
172+ for b in leximin :
173+ hv [b ] = True
174+ for parent in anc [b ]:
175+ if parent is not None and not hv [parent ]:
176+ return None
177+
178+ # Build the edge list: for every byte b, each non-None parent becomes an edge (parent -> b).
179+ elist : list [tuple [bytes , bytes ]] = []
180+ for b in leximin :
181+ for parent in anc [b ]:
182+ if parent is not None :
183+ elist .append ((parent , b ))
184+ return elist
106185
107186class MDB :
108- _fname : str = '/db_0 .pkl'
187+ _fname : str = '/db_1 .pkl'
109188
110189 def __init__ (self ) -> None :
111190 self .mod : MSet [Mod ] = MSet [Mod ]()
0 commit comments