Skip to content

Commit 46f83ad

Browse files
Tianzhang Cai, bcebere, robsdavis
authored
Optuna hyperparameter optimization tutorial (#178)
* first commit for the addition of the TabDDPM plugin * Add DDPM test script and update DDPM plugin * add TabDDPM class and refactor * handle discrete cols and label generation * add hparam space and update tests of DDPM * debug and test DDPM * update TensorDataLoader and training loop * clear bugs * debug for regression tasks * debug for regression tasks; ALL TESTS PASSED * remove the official repo of TabDDPM * passed all pre-commit checks * convert assert to conditional AssertionErrors * added an auto annotation tool * update auto-anno and generate annotations * remove auto-anno and flake8 noqa * add python<3.9 compatible annotations * remove star import * replace builtin type annos to typing annos * resolve py38 compatibility issue * tests/plugins/generic/test_ddpm.py * change TabDDPM method signatures * remove Iterator subscription * update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging * remove TensorDataLoader, update test_ddpm * update EarlyStopping * add TabDDPM tutorial, update TabDDPM plugin and encoders * add TabDDPM tutorial * major update of FeatureEncoder and TabularEncoder * add LogDistribution and LogIntDistribution * update DDPM to use TabularEncoder * update test_tabular_encoder and debug * debug and DDPM tutorial OK * debug LogDistribution and LogIntDistribution * change discrete encoding of BinEncoder to passthrough; passed all tests in test_tabular_encoder * add tabnet to plugins/core/models * add factory.py, let DDPM use TabNet, refactor * update docstrings and refactor * fix type annotation compatibility * make SkipConnection serializable * fix TabularEncoder.activation_layout * remove unnecessary code * fix minor bug and add more nn models in factory * update pandas and torch version requirement * update pandas and torch version requirement * update ddpm tutorial * restore setup.cfg * restore setup.cfg * replace LabelEncoder with OrdinalEncoder * update setup.cfg * update setup.cfg * debug 
datetimeDistribution * clean * update setup.cfg and goggle test * move DDPM tutorial to tutorials/plugins * update tabnet.py reference * update tab_ddpm * update distribution, add optuna utils and tutorial * update * Fix plugin type of static_model of fflows * update intlogdistribution and tutorial * try fixing goggle * add more activations * minor fix * update * update * update * update * Update tabular_encoder.py * Update test_goggle.py * Update tabular_encoder.py * update * update tutorial 8 * update * default cat nonlin of goggle <- gumbel_softmax * get_nonlin('softmax') <- GumbelSoftmax() * remove debug logging * update * update * fix merge * fix merge * update pip upgrade commands in workflows * update pip upgrade commands in workflows * keep pip version to 23.0.1 in workflows * keep pip version to 23.0.1 in workflows * update * update * update * update * update * update * fix distribution * update * move upgrading of wheel to prereq.txt * update --------- Co-authored-by: Bogdan Cebere <[email protected]> Co-authored-by: Rob <[email protected]>
1 parent a4190e6 commit 46f83ad

File tree

10 files changed

+440
-83
lines changed

10 files changed

+440
-83
lines changed

.github/workflows/test_full.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
if: ${{ matrix.os == 'macos-latest' }}
2828
- name: Install dependencies
2929
run: |
30-
pip install pip==23.0.1
30+
python -m pip install -U pip
3131
pip install -r prereq.txt
3232
- name: Test Core
3333
run: |

.github/workflows/test_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
if: ${{ matrix.os == 'macos-latest' }}
5555
- name: Install dependencies
5656
run: |
57-
pip install pip==23.0.1
57+
python -m pip install -U pip
5858
pip install -r prereq.txt
5959
- name: Test Core
6060
run: |

.github/workflows/test_tutorials.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
if: ${{ matrix.os == 'macos-latest' }}
3333
- name: Install dependencies
3434
run: |
35-
pip install pip==23.0.1
35+
python -m pip install -U pip
3636
pip install -r prereq.txt
3737
3838
pip install .[all]

prereq.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy
2-
torch<2.0
2+
torch>=1.10.0,<2.0
33
tsai
4+
wheel>=0.40

src/synthcity/plugins/core/distribution.py

Lines changed: 54 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,25 @@ def as_constraint(self) -> Constraints:
111111

112112
@abstractmethod
113113
def min(self) -> Any:
114-
"Get the min value of the distribution"
114+
"""Get the min value of the distribution."""
115115
...
116116

117117
@abstractmethod
118118
def max(self) -> Any:
119-
"Get the max value of the distribution"
119+
"""Get the max value of the distribution."""
120120
...
121121

122-
@abstractmethod
123122
def __eq__(self, other: Any) -> bool:
124-
...
123+
return type(self) == type(other) and self.get() == other.get()
124+
125+
def __contains__(self, item: Any) -> bool:
126+
"""
127+
Example:
128+
>>> dist = CategoricalDistribution(name="foo", choices=["a", "b", "c"])
129+
>>> "a" in dist
130+
True
131+
"""
132+
return self.has(item)
125133

126134
@abstractmethod
127135
def dtype(self) -> str:
@@ -146,7 +154,7 @@ def _validate_choices(cls: Any, v: List, values: Dict) -> List:
146154
raise ValueError(
147155
"Invalid choices for CategoricalDistribution. Provide data or choices params"
148156
)
149-
return v
157+
return sorted(set(v))
150158

151159
def get(self) -> List[Any]:
152160
return [self.name, self.choices]
@@ -176,12 +184,6 @@ def min(self) -> Any:
176184
def max(self) -> Any:
177185
return max(self.choices)
178186

179-
def __eq__(self, other: Any) -> bool:
180-
if not isinstance(other, CategoricalDistribution):
181-
return False
182-
183-
return self.name == other.name and set(self.choices) == set(other.choices)
184-
185187
def dtype(self) -> str:
186188
types = {
187189
"object": 0,
@@ -259,33 +261,24 @@ def min(self) -> Any:
259261
def max(self) -> Any:
260262
return self.high
261263

262-
def __eq__(self, other: Any) -> bool:
263-
if not isinstance(other, type(self)):
264-
return False
265-
266-
return (
267-
self.name == other.name
268-
and self.low == other.low
269-
and self.high == other.high
270-
)
271-
272264
def dtype(self) -> str:
273265
return "float"
274266

275267

276268
class LogDistribution(FloatDistribution):
277269
low: float = np.finfo(np.float64).tiny
278270
high: float = np.finfo(np.float64).max
279-
base: float = 2.0
271+
272+
def get(self) -> List[Any]:
273+
return [self.name, self.low, self.high]
280274

281275
def sample(self, count: int = 1) -> Any:
282276
np.random.seed(self.random_state)
283277
msamples = self.sample_marginal(count)
284278
if msamples is not None:
285279
return msamples
286-
lo = np.log2(self.low) / np.log2(self.base)
287-
hi = np.log2(self.high) / np.log2(self.base)
288-
return self.base ** np.random.uniform(lo, hi, count)
280+
lo, hi = np.log2(self.low), np.log2(self.high)
281+
return 2.0 ** np.random.uniform(lo, hi, count)
289282

290283

291284
class IntegerDistribution(Distribution):
@@ -313,6 +306,12 @@ def _validate_high_thresh(cls: Any, v: int, values: Dict) -> int:
313306
return int(values[mkey].index.max())
314307
return v
315308

309+
@validator("step", always=True)
310+
def _validate_step(cls: Any, v: int, values: Dict) -> int:
311+
if v < 1:
312+
raise ValueError("Step must be greater than 0")
313+
return v
314+
316315
def get(self) -> List[Any]:
317316
return [self.name, self.low, self.high, self.step]
318317

@@ -322,9 +321,9 @@ def sample(self, count: int = 1) -> Any:
322321
if msamples is not None:
323322
return msamples
324323

325-
high = (self.high + 1 - self.low) // self.step
326-
s = np.random.choice(high, count)
327-
return s * self.step + self.low
324+
steps = (self.high - self.low) // self.step
325+
samples = np.random.choice(steps + 1, count)
326+
return samples * self.step + self.low
328327

329328
def has(self, val: Any) -> bool:
330329
return self.low <= val and val <= self.high
@@ -347,34 +346,31 @@ def min(self) -> Any:
347346
def max(self) -> Any:
348347
return self.high
349348

350-
def __eq__(self, other: Any) -> bool:
351-
if not isinstance(other, IntegerDistribution):
352-
return False
353-
354-
return (
355-
self.name == other.name
356-
and self.low == other.low
357-
and self.high == other.high
358-
)
359-
360349
def dtype(self) -> str:
361350
return "int"
362351

363352

364-
class LogIntDistribution(FloatDistribution):
365-
low: float = 1.0
366-
high: float = float(np.iinfo(np.int64).max)
367-
base: float = 2.0
353+
class IntLogDistribution(IntegerDistribution):
354+
low: int = 1
355+
high: int = np.iinfo(np.int64).max
356+
357+
@validator("step", always=True)
358+
def _validate_step(cls: Any, v: int, values: Dict) -> int:
359+
if v != 1:
360+
raise ValueError("Step must be 1 for IntLogDistribution")
361+
return v
362+
363+
def get(self) -> List[Any]:
364+
return [self.name, self.low, self.high]
368365

369366
def sample(self, count: int = 1) -> Any:
370367
np.random.seed(self.random_state)
371368
msamples = self.sample_marginal(count)
372369
if msamples is not None:
373370
return msamples
374-
lo = np.log2(self.low) / np.log2(self.base)
375-
hi = np.log2(self.high) / np.log2(self.base)
376-
s = self.base ** np.random.uniform(lo, hi, count)
377-
return s.astype(int)
371+
lo, hi = np.log2(self.low), np.log2(self.high)
372+
samples = 2.0 ** np.random.uniform(lo, hi, count)
373+
return samples.astype(int)
378374

379375

380376
class DatetimeDistribution(Distribution):
@@ -383,49 +379,46 @@ class DatetimeDistribution(Distribution):
383379
:parts: 1
384380
"""
385381

386-
offset: int = 120
387382
low: datetime = datetime.utcfromtimestamp(0)
388383
high: datetime = datetime.now()
389-
390-
@validator("offset", always=True)
391-
def _validate_offset(cls: Any, v: int) -> int:
392-
if v < 0:
393-
raise ValueError("offset must be greater than 0")
394-
return v
384+
step: timedelta = timedelta(microseconds=1)
385+
offset: timedelta = timedelta(seconds=120)
395386

396387
@validator("low", always=True)
397388
def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
398389
mkey = "marginal_distribution"
399390
if mkey in values and values[mkey] is not None:
400391
v = values[mkey].index.min()
401-
return v - timedelta(seconds=values["offset"])
392+
return v
402393

403394
@validator("high", always=True)
404395
def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
405396
mkey = "marginal_distribution"
406397
if mkey in values and values[mkey] is not None:
407398
v = values[mkey].index.max()
408-
return v + timedelta(seconds=values["offset"])
399+
return v
409400

410401
def get(self) -> List[Any]:
411-
return [self.name, self.low, self.high]
402+
return [self.name, self.low, self.high, self.step, self.offset]
412403

413404
def sample(self, count: int = 1) -> Any:
414405
np.random.seed(self.random_state)
415406
msamples = self.sample_marginal(count)
416407
if msamples is not None:
417408
return msamples
418409

419-
delta = self.high - self.low
420-
return self.low + delta * np.random.rand(count)
410+
n = (self.high - self.low) // self.step + 1
411+
samples = np.round(np.random.rand(count) * n - 0.5)
412+
return self.low + samples * self.step
421413

422414
def has(self, val: datetime) -> bool:
423415
return self.low <= val and val <= self.high
424416

425417
def includes(self, other: "Distribution") -> bool:
426-
return self.min() - timedelta(
427-
seconds=self.offset
428-
) <= other.min() and other.max() <= self.max() + timedelta(seconds=self.offset)
418+
return (
419+
self.min() - self.offset <= other.min()
420+
and other.max() <= self.max() + self.offset
421+
)
429422

430423
def as_constraint(self) -> Constraints:
431424
return Constraints(
@@ -442,16 +435,6 @@ def min(self) -> Any:
442435
def max(self) -> Any:
443436
return self.high
444437

445-
def __eq__(self, other: Any) -> bool:
446-
if not isinstance(other, DatetimeDistribution):
447-
return False
448-
449-
return (
450-
self.name == other.name
451-
and self.low == other.low
452-
and self.high == other.high
453-
)
454-
455438
def dtype(self) -> str:
456439
return "datetime"
457440

src/synthcity/plugins/core/models/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
DatetimeEncoder,
2121
FeatureEncoder,
2222
GaussianQuantileTransformer,
23+
LabelEncoder,
2324
MinMaxScaler,
2425
OneHotEncoder,
2526
OrdinalEncoder,
@@ -75,6 +76,7 @@
7576
datetime=DatetimeEncoder,
7677
onehot=OneHotEncoder,
7778
ordinal=OrdinalEncoder,
79+
label=LabelEncoder,
7880
standard=StandardScaler,
7981
minmax=MinMaxScaler,
8082
robust=RobustScaler,

src/synthcity/plugins/generic/plugin_ddpm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from synthcity.plugins.core.distribution import (
1919
Distribution,
2020
IntegerDistribution,
21+
IntLogDistribution,
2122
LogDistribution,
22-
LogIntDistribution,
2323
)
2424
from synthcity.plugins.core.models.tabular_ddpm import TabDDPM
2525
from synthcity.plugins.core.models.tabular_encoder import TabularEncoder
@@ -180,11 +180,11 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
180180
"""
181181
return [
182182
LogDistribution(name="lr", low=1e-5, high=1e-1),
183-
LogIntDistribution(name="batch_size", low=256, high=4096),
183+
IntLogDistribution(name="batch_size", low=256, high=4096),
184184
IntegerDistribution(name="num_timesteps", low=10, high=1000),
185-
LogIntDistribution(name="n_iter", low=1000, high=10000),
185+
IntLogDistribution(name="n_iter", low=1000, high=10000),
186186
# IntegerDistribution(name="n_layers_hidden", low=2, high=8),
187-
# LogIntDistribution(name="dim_hidden", low=128, high=1024),
187+
# IntLogDistribution(name="dim_hidden", low=128, high=1024),
188188
]
189189

190190
def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin":

src/synthcity/plugins/time_series/plugin_fflows.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fflows import FourierFlow
1212

1313
# synthcity absolute
14+
from synthcity.plugins import Plugins
1415
from synthcity.plugins.core.dataloader import DataLoader
1516
from synthcity.plugins.core.distribution import (
1617
CategoricalDistribution,
@@ -24,7 +25,6 @@
2425
from synthcity.plugins.core.models.ts_model import TimeSeriesModel
2526
from synthcity.plugins.core.plugin import Plugin
2627
from synthcity.plugins.core.schema import Schema
27-
from synthcity.plugins.generic import GenericPlugins
2828
from synthcity.utils.constants import DEVICE
2929

3030

@@ -134,9 +134,7 @@ def __init__(
134134
normalize=normalize,
135135
).to(device)
136136

137-
self.static_model = GenericPlugins().get(
138-
self.static_model_name, device=self.device
139-
)
137+
self.static_model = Plugins().get(self.static_model_name, device=self.device)
140138

141139
self.temporal_encoder = TimeSeriesTabularEncoder(
142140
max_clusters=encoder_max_clusters
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# stdlib
2+
from typing import Any, Dict, List
3+
4+
# third party
5+
import optuna
6+
7+
# synthcity absolute
8+
import synthcity.plugins.core.distribution as D
9+
10+
11+
def suggest(trial: optuna.Trial, dist: D.Distribution) -> Any:
    """Sample one hyperparameter value for ``dist`` from an Optuna trial.

    Args:
        trial: The Optuna trial used to draw the value.
        dist: The distribution describing the hyperparameter's search space.

    Returns:
        The value returned by the matching ``trial.suggest_*`` call.

    Raises:
        ValueError: If ``dist`` is not one of the supported distribution types.
    """
    # LogDistribution subclasses FloatDistribution and IntLogDistribution
    # subclasses IntegerDistribution, so the log-scaled variants must be
    # tested before their parents; otherwise their branches are unreachable
    # and log-scaled parameters would be sampled on a linear scale.
    if isinstance(dist, D.LogDistribution):
        return trial.suggest_float(dist.name, dist.low, dist.high, log=True)
    if isinstance(dist, D.FloatDistribution):
        return trial.suggest_float(dist.name, dist.low, dist.high)
    if isinstance(dist, D.IntLogDistribution):
        return trial.suggest_int(dist.name, dist.low, dist.high, log=True)
    if isinstance(dist, D.IntegerDistribution):
        # Pass ``step`` by keyword so the call does not depend on the
        # positional order of Optuna's optional arguments.
        return trial.suggest_int(dist.name, dist.low, dist.high, step=dist.step)
    if isinstance(dist, D.CategoricalDistribution):
        return trial.suggest_categorical(dist.name, dist.choices)
    raise ValueError(f"Unknown dist: {dist}")
24+
25+
26+
def suggest_all(trial: optuna.Trial, distributions: List[D.Distribution]) -> Dict:
    """Draw a value for every distribution and return them keyed by name.

    Args:
        trial: The Optuna trial used to draw the values.
        distributions: The distributions to sample, one per hyperparameter.

    Returns:
        A dict mapping each distribution's ``name`` to its suggested value.
    """
    params: Dict = {}
    for entry in distributions:
        params[entry.name] = suggest(trial, entry)
    return params

0 commit comments

Comments
 (0)