Skip to content

Commit a4190e6

Browse files
Tianzhang Caibcebererobsdavis
authored
Add Tabnet support (#168)
* first commit for the addition of the TabDDPM plugin * Add DDPM test script and update DDPM plugin * add TabDDPM class and refactor * handle discrete cols and label generation * add hparam space and update tests of DDPM * debug and test DDPM * update TensorDataLoader and training loop * clear bugs * debug for regression tasks * debug for regression tasks; ALL TESTS PASSED * remove the official repo of TabDDPM * passed all pre-commit checks * convert assert to conditional AssertionErrors * added an auto annotation tool * update auto-anno and generate annotations * remove auto-anno and flake8 noqa * add python<3.9 compatible annotations * remove star import * replace builtin type annos to typing annos * resolve py38 compatibility issue * tests/plugins/generic/test_ddpm.py * change TabDDPM method signatures * remove Iterator subscription * update AssertionErrors, add EarlyStop callback, removed additional MLP, update logging * remove TensorDataLoader, update test_ddpm * update EarlyStopping * add TabDDPM tutorial, update TabDDPM plugin and encoders * add TabDDPM tutorial * major update of FeatureEncoder and TabularEncoder * add LogDistribution and LogIntDistribution * update DDPM to use TabularEncoder * update test_tabular_encoder and debug * debug and DDPM tutorial OK * debug LogDistribution and LogIntDistribution * change discrete encoding of BinEncoder to passthrough; passed all tests in test_tabular_encoder * add tabnet to plugins/core/models * add factory.py, let DDPM use TabNet, refactor * update docstrings and refactor * fix type annotation compatibility * make SkipConnection serializable * fix TabularEncoder.activation_layout * remove unnecessary code * fix minor bug and add more nn models in factory * update pandas and torch version requirement * update pandas and torch version requirement * update ddpm tutorial * restore setup.cfg * restore setup.cfg * replace LabelEncoder with OrdinalEncoder * update setup.cfg * update setup.cfg * debug datetimeDistribution * clean * update setup.cfg and goggle test * move DDPM tutorial to tutorials/plugins * update tabnet.py reference * update tab_ddpm * update * try fixing goggle * add more activations * minor fix * update * update * update * update * Update tabular_encoder.py * Update test_goggle.py * Update tabular_encoder.py * update * update * default cat nonlin of goggle <- gumbel_softmax * get_nonlin('softmax') <- GumbelSoftmax() * remove debug logging * update * update * fix merge * update pip upgrade commands in workflows * keep pip version to 23.0.1 in workflows --------- Co-authored-by: Bogdan Cebere <[email protected]> Co-authored-by: Rob <[email protected]>
1 parent 59108bf commit a4190e6

File tree

5 files changed

+10
-5
lines changed

5 files changed

+10
-5
lines changed

.github/workflows/test_full.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ jobs:
2727
if: ${{ matrix.os == 'macos-latest' }}
2828
- name: Install dependencies
2929
run: |
30+
pip install pip==23.0.1
3031
pip install -r prereq.txt
31-
pip install --upgrade pip
3232
- name: Test Core
3333
run: |
3434
pip install .[testing]

.github/workflows/test_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ jobs:
5454
if: ${{ matrix.os == 'macos-latest' }}
5555
- name: Install dependencies
5656
run: |
57+
pip install pip==23.0.1
5758
pip install -r prereq.txt
58-
pip install --upgrade pip
5959
- name: Test Core
6060
run: |
6161
pip install .[testing]

.github/workflows/test_tutorials.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ jobs:
3232
if: ${{ matrix.os == 'macos-latest' }}
3333
- name: Install dependencies
3434
run: |
35+
pip install pip==23.0.1
3536
pip install -r prereq.txt
36-
pip install --upgrade pip
3737
3838
pip install .[all]
3939

src/synthcity/plugins/core/models/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
DatetimeEncoder,
2121
FeatureEncoder,
2222
GaussianQuantileTransformer,
23-
LabelEncoder,
2423
MinMaxScaler,
2524
OneHotEncoder,
25+
OrdinalEncoder,
2626
RobustScaler,
2727
StandardScaler,
2828
)
@@ -74,7 +74,7 @@
7474
FEATURE_ENCODERS = dict(
7575
datetime=DatetimeEncoder,
7676
onehot=OneHotEncoder,
77-
label=LabelEncoder,
77+
ordinal=OrdinalEncoder,
7878
standard=StandardScaler,
7979
minmax=MinMaxScaler,
8080
robust=RobustScaler,

src/synthcity/plugins/core/models/tabnet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# TabNet: Attentive Interpretable Tabular Learning
2+
# Reference:
3+
# - https://arxiv.org/pdf/1908.07442.pdf
4+
# - https://github.com/dreamquark-ai/tabnet
5+
16
# stdlib
27
from typing import List, Optional, Tuple
38

0 commit comments

Comments
 (0)