@@ -111,17 +111,25 @@ def as_constraint(self) -> Constraints:
111111
112112 @abstractmethod
113113 def min (self ) -> Any :
114- "Get the min value of the distribution"
114+ """ Get the min value of the distribution."" "
115115 ...
116116
117117 @abstractmethod
118118 def max (self ) -> Any :
119- "Get the max value of the distribution"
119+ """ Get the max value of the distribution."" "
120120 ...
121121
122- @abstractmethod
123122 def __eq__ (self , other : Any ) -> bool :
124- ...
123+ return type (self ) == type (other ) and self .get () == other .get ()
124+
125+ def __contains__ (self , item : Any ) -> bool :
126+ """
127+ Example:
128+ >>> dist = CategoricalDistribution(name="foo", choices=["a", "b", "c"])
129+ >>> "a" in dist
130+ True
131+ """
132+ return self .has (item )
125133
126134 @abstractmethod
127135 def dtype (self ) -> str :
@@ -146,7 +154,7 @@ def _validate_choices(cls: Any, v: List, values: Dict) -> List:
146154 raise ValueError (
147155 "Invalid choices for CategoricalDistribution. Provide data or choices params"
148156 )
149- return v
157+ return sorted ( set ( v ))
150158
151159 def get (self ) -> List [Any ]:
152160 return [self .name , self .choices ]
@@ -176,12 +184,6 @@ def min(self) -> Any:
176184 def max (self ) -> Any :
177185 return max (self .choices )
178186
179- def __eq__ (self , other : Any ) -> bool :
180- if not isinstance (other , CategoricalDistribution ):
181- return False
182-
183- return self .name == other .name and set (self .choices ) == set (other .choices )
184-
185187 def dtype (self ) -> str :
186188 types = {
187189 "object" : 0 ,
@@ -259,33 +261,24 @@ def min(self) -> Any:
259261 def max (self ) -> Any :
260262 return self .high
261263
262- def __eq__ (self , other : Any ) -> bool :
263- if not isinstance (other , type (self )):
264- return False
265-
266- return (
267- self .name == other .name
268- and self .low == other .low
269- and self .high == other .high
270- )
271-
272264 def dtype (self ) -> str :
273265 return "float"
274266
275267
276268class LogDistribution (FloatDistribution ):
277269 low : float = np .finfo (np .float64 ).tiny
278270 high : float = np .finfo (np .float64 ).max
279- base : float = 2.0
271+
272+ def get (self ) -> List [Any ]:
273+ return [self .name , self .low , self .high ]
280274
281275 def sample (self , count : int = 1 ) -> Any :
282276 np .random .seed (self .random_state )
283277 msamples = self .sample_marginal (count )
284278 if msamples is not None :
285279 return msamples
286- lo = np .log2 (self .low ) / np .log2 (self .base )
287- hi = np .log2 (self .high ) / np .log2 (self .base )
288- return self .base ** np .random .uniform (lo , hi , count )
280+ lo , hi = np .log2 (self .low ), np .log2 (self .high )
281+ return 2.0 ** np .random .uniform (lo , hi , count )
289282
290283
291284class IntegerDistribution (Distribution ):
@@ -313,6 +306,12 @@ def _validate_high_thresh(cls: Any, v: int, values: Dict) -> int:
313306 return int (values [mkey ].index .max ())
314307 return v
315308
309+ @validator ("step" , always = True )
310+ def _validate_step (cls : Any , v : int , values : Dict ) -> int :
311+ if v < 1 :
312+ raise ValueError ("Step must be greater than 0" )
313+ return v
314+
316315 def get (self ) -> List [Any ]:
317316 return [self .name , self .low , self .high , self .step ]
318317
@@ -322,9 +321,9 @@ def sample(self, count: int = 1) -> Any:
322321 if msamples is not None :
323322 return msamples
324323
325- high = (self .high + 1 - self .low ) // self .step
326- s = np .random .choice (high , count )
327- return s * self .step + self .low
324+ steps = (self .high - self .low ) // self .step
325+ samples = np .random .choice (steps + 1 , count )
326+ return samples * self .step + self .low
328327
329328 def has (self , val : Any ) -> bool :
330329 return self .low <= val and val <= self .high
@@ -347,34 +346,31 @@ def min(self) -> Any:
347346 def max (self ) -> Any :
348347 return self .high
349348
350- def __eq__ (self , other : Any ) -> bool :
351- if not isinstance (other , IntegerDistribution ):
352- return False
353-
354- return (
355- self .name == other .name
356- and self .low == other .low
357- and self .high == other .high
358- )
359-
360349 def dtype (self ) -> str :
361350 return "int"
362351
363352
364- class LogIntDistribution (FloatDistribution ):
365- low : float = 1.0
366- high : float = float (np .iinfo (np .int64 ).max )
367- base : float = 2.0
353+ class IntLogDistribution (IntegerDistribution ):
354+ low : int = 1
355+ high : int = np .iinfo (np .int64 ).max
356+
357+ @validator ("step" , always = True )
358+ def _validate_step (cls : Any , v : int , values : Dict ) -> int :
359+ if v != 1 :
360+ raise ValueError ("Step must be 1 for IntLogDistribution" )
361+ return v
362+
363+ def get (self ) -> List [Any ]:
364+ return [self .name , self .low , self .high ]
368365
369366 def sample (self , count : int = 1 ) -> Any :
370367 np .random .seed (self .random_state )
371368 msamples = self .sample_marginal (count )
372369 if msamples is not None :
373370 return msamples
374- lo = np .log2 (self .low ) / np .log2 (self .base )
375- hi = np .log2 (self .high ) / np .log2 (self .base )
376- s = self .base ** np .random .uniform (lo , hi , count )
377- return s .astype (int )
371+ lo , hi = np .log2 (self .low ), np .log2 (self .high )
372+ samples = 2.0 ** np .random .uniform (lo , hi , count )
373+ return samples .astype (int )
378374
379375
380376class DatetimeDistribution (Distribution ):
@@ -383,49 +379,46 @@ class DatetimeDistribution(Distribution):
383379 :parts: 1
384380 """
385381
386- offset : int = 120
387382 low : datetime = datetime .utcfromtimestamp (0 )
388383 high : datetime = datetime .now ()
389-
390- @validator ("offset" , always = True )
391- def _validate_offset (cls : Any , v : int ) -> int :
392- if v < 0 :
393- raise ValueError ("offset must be greater than 0" )
394- return v
384+ step : timedelta = timedelta (microseconds = 1 )
385+ offset : timedelta = timedelta (seconds = 120 )
395386
396387 @validator ("low" , always = True )
397388 def _validate_low_thresh (cls : Any , v : datetime , values : Dict ) -> datetime :
398389 mkey = "marginal_distribution"
399390 if mkey in values and values [mkey ] is not None :
400391 v = values [mkey ].index .min ()
401- return v - timedelta ( seconds = values [ "offset" ])
392+ return v
402393
403394 @validator ("high" , always = True )
404395 def _validate_high_thresh (cls : Any , v : datetime , values : Dict ) -> datetime :
405396 mkey = "marginal_distribution"
406397 if mkey in values and values [mkey ] is not None :
407398 v = values [mkey ].index .max ()
408- return v + timedelta ( seconds = values [ "offset" ])
399+ return v
409400
410401 def get (self ) -> List [Any ]:
411- return [self .name , self .low , self .high ]
402+ return [self .name , self .low , self .high , self . step , self . offset ]
412403
413404 def sample (self , count : int = 1 ) -> Any :
414405 np .random .seed (self .random_state )
415406 msamples = self .sample_marginal (count )
416407 if msamples is not None :
417408 return msamples
418409
419- delta = self .high - self .low
420- return self .low + delta * np .random .rand (count )
410+ n = (self .high - self .low ) // self .step + 1
411+ samples = np .round (np .random .rand (count ) * n - 0.5 )
412+ return self .low + samples * self .step
421413
422414 def has (self , val : datetime ) -> bool :
423415 return self .low <= val and val <= self .high
424416
425417 def includes (self , other : "Distribution" ) -> bool :
426- return self .min () - timedelta (
427- seconds = self .offset
428- ) <= other .min () and other .max () <= self .max () + timedelta (seconds = self .offset )
418+ return (
419+ self .min () - self .offset <= other .min ()
420+ and other .max () <= self .max () + self .offset
421+ )
429422
430423 def as_constraint (self ) -> Constraints :
431424 return Constraints (
@@ -442,16 +435,6 @@ def min(self) -> Any:
442435 def max (self ) -> Any :
443436 return self .high
444437
445- def __eq__ (self , other : Any ) -> bool :
446- if not isinstance (other , DatetimeDistribution ):
447- return False
448-
449- return (
450- self .name == other .name
451- and self .low == other .low
452- and self .high == other .high
453- )
454-
455438 def dtype (self ) -> str :
456439 return "datetime"
457440
0 commit comments