Skip to content

Commit 59108bf

Browse files
Tianzhang Cai and bcebere authored
Update of data_encoder and tabular_encoder (#159)
* first commit for the addition of the TabDDPM plugin * Add DDPM test script and update DDPM plugin * add TabDDPM class and refactor * handle discrete cols and label generation * add hparam space and update tests of DDPM * debug and test DDPM * update TensorDataLoader and training loop * clear bugs * debug for regression tasks * debug for regression tasks; ALL TESTS PASSED * remove the official repo of TabDDPM * passed all pre-commit checks * convert assert to conditional AssertionErrors * added an auto annotation tool * update auto-anno and generate annotations * remove auto-anno and flake8 noqa * add python<3.9 compatible annotations * remove star import * replace builtin type annos to typing annos * resolve py38 compatibility issue * tests/plugins/generic/test_ddpm.py * change TabDDPM method signatures * remove Iterator subscription * update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging * remove TensorDataLoader, update test_ddpm * update EarlyStopping * add TabDDPM tutorial, update TabDDPM plugin and encoders * add TabDDPM tutorial * major update of FeatureEncoder and TabularEncoder * add LogDistribution and LogIntDistribution * update DDPM to use TabularEncoder * update test_tabular_encoder and debug * debug and DDPM tutorial OK * debug LogDistribution and LogIntDistribution * change discrete encoding of BinEncoder to passthrough; passed all tests in test_tabular_encoder * add tabnet to plugins/core/models * add factory.py, let DDPM use TabNet, refactor * update docstrings and refactor * fix type annotation compatibility * make SkipConnection serializable * fix TabularEncoder.activation_layout * remove unnecessary code * fix minor bug and add more nn models in factory * update pandas and torch version requirement * update ddpm tutorial * restore setup.cfg * update setup.cfg * debug datetimeDistribution * clean * update setup.cfg and goggle test * move DDPM tutorial to tutorials/plugins * update tab_ddpm * update * try fixing 
goggle * add more activations * minor fix * update * update * update * Update tabular_encoder.py * Update test_goggle.py * Update tabular_encoder.py * update * update * default cat nonlin of goggle <- gumbel_softmax * get_nonlin('softmax') <- GumbelSoftmax() * remove debug logging * update --------- Co-authored-by: Bogdan Cebere <[email protected]>
1 parent da12d59 commit 59108bf

File tree

24 files changed

+3441
-715
lines changed

24 files changed

+3441
-715
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,4 @@ lightning_logs
6767
generated
6868
MNIST
6969
cifar-10*
70-
src/test.py
70+
local_test*.py

setup.cfg

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ install_requires =
3535
scikit-learn>=1.0
3636
nflows>=0.14
3737
pandas>=1.3,<2.0
38-
torch>=1.10,<2.0
38+
torch>=1.10.0,<2.0
3939
numpy>=1.20
4040
lifelines>=0.27
4141
opacus>=1.3
@@ -59,7 +59,6 @@ install_requires =
5959
tsai; python_version>"3.7"
6060
importlib-metadata; python_version<"3.8"
6161

62-
6362
[options.packages.find]
6463
where = src
6564
exclude =

src/synthcity/plugins/core/dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# synthcity absolute
1717
from synthcity.plugins.core.constraints import Constraints
1818
from synthcity.plugins.core.dataset import FlexibleDataset, TensorDataset
19-
from synthcity.plugins.core.models.data_encoder import DatetimeEncoder
19+
from synthcity.plugins.core.models.feature_encoder import DatetimeEncoder
2020
from synthcity.utils.compression import compress_dataset, decompress_dataset
2121
from synthcity.utils.serialization import dataframe_hash
2222

src/synthcity/plugins/core/distribution.py

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def sample(self, count: int = 1) -> Any:
157157
if msamples is not None:
158158
return msamples
159159

160-
return np.random.choice(self.choices, count).tolist()
160+
return np.random.choice(self.choices, count)
161161

162162
def has(self, val: Any) -> bool:
163163
return val in self.choices
@@ -209,8 +209,8 @@ class FloatDistribution(Distribution):
209209
:parts: 1
210210
"""
211211

212-
low: float = np.iinfo(np.int64).min
213-
high: float = np.iinfo(np.int64).max
212+
low: float = np.finfo(np.float64).min
213+
high: float = np.finfo(np.float64).max
214214

215215
@validator("low", always=True)
216216
def _validate_low_thresh(cls: Any, v: float, values: Dict) -> float:
@@ -260,7 +260,7 @@ def max(self) -> Any:
260260
return self.high
261261

262262
def __eq__(self, other: Any) -> bool:
263-
if not isinstance(other, FloatDistribution):
263+
if not isinstance(other, type(self)):
264264
return False
265265

266266
return (
@@ -273,6 +273,21 @@ def dtype(self) -> str:
273273
return "float"
274274

275275

276+
class LogDistribution(FloatDistribution):
277+
low: float = np.finfo(np.float64).tiny
278+
high: float = np.finfo(np.float64).max
279+
base: float = 2.0
280+
281+
def sample(self, count: int = 1) -> Any:
282+
np.random.seed(self.random_state)
283+
msamples = self.sample_marginal(count)
284+
if msamples is not None:
285+
return msamples
286+
lo = np.log2(self.low) / np.log2(self.base)
287+
hi = np.log2(self.high) / np.log2(self.base)
288+
return self.base ** np.random.uniform(lo, hi, count)
289+
290+
276291
class IntegerDistribution(Distribution):
277292
"""
278293
.. inheritance-diagram:: synthcity.plugins.core.distribution.IntegerDistribution
@@ -307,8 +322,9 @@ def sample(self, count: int = 1) -> Any:
307322
if msamples is not None:
308323
return msamples
309324

310-
choices = [val for val in range(self.low, self.high + 1, self.step)]
311-
return np.random.choice(choices, count).tolist()
325+
high = (self.high + 1 - self.low) // self.step
326+
s = np.random.choice(high, count)
327+
return s * self.step + self.low
312328

313329
def has(self, val: Any) -> bool:
314330
return self.low <= val and val <= self.high
@@ -345,7 +361,20 @@ def dtype(self) -> str:
345361
return "int"
346362

347363

348-
OFFSET = 120
364+
class LogIntDistribution(FloatDistribution):
365+
low: float = 1.0
366+
high: float = float(np.iinfo(np.int64).max)
367+
base: float = 2.0
368+
369+
def sample(self, count: int = 1) -> Any:
370+
np.random.seed(self.random_state)
371+
msamples = self.sample_marginal(count)
372+
if msamples is not None:
373+
return msamples
374+
lo = np.log2(self.low) / np.log2(self.base)
375+
hi = np.log2(self.high) / np.log2(self.base)
376+
s = self.base ** np.random.uniform(lo, hi, count)
377+
return s.astype(int)
349378

350379

351380
class DatetimeDistribution(Distribution):
@@ -354,24 +383,29 @@ class DatetimeDistribution(Distribution):
354383
:parts: 1
355384
"""
356385

386+
offset: int = 120
357387
low: datetime = datetime.utcfromtimestamp(0)
358388
high: datetime = datetime.now()
359389

390+
@validator("offset", always=True)
391+
def _validate_offset(cls: Any, v: int) -> int:
392+
if v < 0:
393+
raise ValueError("offset must be greater than 0")
394+
return v
395+
360396
@validator("low", always=True)
361397
def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
362398
mkey = "marginal_distribution"
363399
if mkey in values and values[mkey] is not None:
364400
v = values[mkey].index.min()
365-
366-
return v - timedelta(seconds=OFFSET)
401+
return v - timedelta(seconds=values["offset"])
367402

368403
@validator("high", always=True)
369404
def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
370405
mkey = "marginal_distribution"
371406
if mkey in values and values[mkey] is not None:
372407
v = values[mkey].index.max()
373-
374-
return v + timedelta(seconds=OFFSET)
408+
return v + timedelta(seconds=values["offset"])
375409

376410
def get(self) -> List[Any]:
377411
return [self.name, self.low, self.high]
@@ -382,23 +416,16 @@ def sample(self, count: int = 1) -> Any:
382416
if msamples is not None:
383417
return msamples
384418

385-
samples = np.random.uniform(
386-
datetime.timestamp(self.low), datetime.timestamp(self.high), count
387-
)
388-
389-
samples_dt = []
390-
for s in samples:
391-
samples_dt.append(datetime.fromtimestamp(s))
392-
393-
return samples_dt
419+
delta = self.high - self.low
420+
return self.low + delta * np.random.rand(count)
394421

395422
def has(self, val: datetime) -> bool:
396423
return self.low <= val and val <= self.high
397424

398425
def includes(self, other: "Distribution") -> bool:
399426
return self.min() - timedelta(
400-
seconds=OFFSET
401-
) <= other.min() and other.max() <= self.max() + timedelta(seconds=OFFSET)
427+
seconds=self.offset
428+
) <= other.min() and other.max() <= self.max() + timedelta(seconds=self.offset)
402429

403430
def as_constraint(self) -> Constraints:
404431
return Constraints(

src/synthcity/plugins/core/models/convnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ class ConvNet(nn.Module):
6969
@validate_arguments(config=dict(arbitrary_types_allowed=True))
7070
def __init__(
7171
self,
72-
task_type: str,
73-
model: nn.Module, # classification/regression
72+
task_type: str, # classification/regression
73+
model: nn.Module,
7474
lr: float = 1e-3,
7575
weight_decay: float = 1e-3,
7676
opt_betas: tuple = (0.9, 0.999),

src/synthcity/plugins/core/models/data_encoder.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

0 commit comments

Comments (0)