Skip to content

Commit 10dbe78

Browse files
authored
Fix full test errors (#255)
* debugging * working fast tests * passing tests * pin lifelines<0.28 as 0.28 does not support python 3.8 * debugging lifelines files error * lifelines==0.27.7 * fix version pin * revert to strict pin * lifelines version constraints as generic as possible * split core tests into fast and slow and increase timeout * split slow tests into two * update version
1 parent a7956c8 commit 10dbe78

36 files changed

+120
-57
lines changed

.github/workflows/test_full.yml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,21 @@ jobs:
3030
run: |
3131
python -m pip install -U pip
3232
pip install -r prereq.txt
33-
- name: Test Core
33+
- name: Test Core - slow part one
34+
timeout-minutes: 1000
3435
run: |
3536
pip install .[testing]
36-
pytest -vvvs --durations=50
37+
pytest -vvvs --durations=50 -m "slow_1"
38+
- name: Test Core - slow part two
39+
timeout-minutes: 1000
40+
run: |
41+
pip install .[testing]
42+
pytest -vvvs --durations=50 -m "slow_2"
43+
- name: Test Core - fast
44+
timeout-minutes: 1000
45+
run: |
46+
pip install .[testing]
47+
pytest -vvvs --durations=50 -m "not slow"
3748
- name: Test GOGGLE
3849
run: |
3950
pip install .[testing,goggle]

setup.cfg

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ install_requires =
3838
scikit-learn>=1.2
3939
nflows>=0.14
4040
numpy>=1.20, <1.24
41-
lifelines>=0.27,!= 0.27.5
41+
lifelines>=0.27,!= 0.27.5, <0.27.8
4242
opacus>=1.3
4343
decaf-synthetic-data>=0.1.6
4444
optuna>=3.1
@@ -117,6 +117,8 @@ testpaths = tests
117117
# Use pytest markers to select/deselect specific tests
118118
markers =
119119
slow: mark tests as slow (deselect with '-m "not slow"')
120+
slow_1: mark tests as slow (deselect with '-m "not slow_1"')
121+
slow_2: mark tests as slow (deselect with '-m "not slow_1"')
120122

121123
[devpi:upload]
122124
# Options for the devpi: PyPI server and packaging tool

src/synthcity/plugins/core/dataloader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -928,12 +928,13 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
928928
self.data["observation_times"],
929929
self.data["outcome"],
930930
)
931+
931932
if as_numpy:
932933
longest_observation_seq = max([len(seq) for seq in temporal_data])
933934
return (
934935
np.asarray(static_data),
935936
np.asarray(
936-
pd.concat(temporal_data)
937+
temporal_data
937938
), # TODO: check this works with time series benchmarks
938939
# masked array to handle variable length sequences
939940
ma.vstack(

src/synthcity/plugins/core/plugin.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,6 @@ class PluginLoader:
560560

561561
@validate_arguments
562562
def __init__(self, plugins: list, expected_type: Type, categories: list) -> None:
563-
# self.reload()
564563
global PLUGIN_CATEGORY_REGISTRY
565564
PLUGIN_CATEGORY_REGISTRY = {cat: [] for cat in categories}
566565
self._refresh()
@@ -639,7 +638,6 @@ def list(self) -> List[str]:
639638
for plugin in all_plugins:
640639
if self.get_type(plugin).type() in self._categories:
641640
plugins.append(plugin)
642-
643641
return list(set(plugins))
644642

645643
def types(self) -> List[Type]:

src/synthcity/plugins/privacy/plugin_dpgan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class DPGANPlugin(Plugin):
101101
>>>
102102
>>> plugin.generate(50)
103103
104+
Note: There is a known issue with the training step for training GANs with conditionals with dp_enabled set to True, as is the case for DPGAN.
105+
104106
"""
105107

106108
@validate_arguments(config=dict(arbitrary_types_allowed=True))

src/synthcity/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.2.9"
1+
__version__ = "0.2.10"
22

33
MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
44
PATCH_VERSION = __version__.split(".")[-1]

tests/metrics/test_detection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def test_detect_synth_timeseries(test_plugin: Plugin, evaluator_t: Type) -> None
154154
assert evaluator.direction() == "minimize"
155155

156156

157+
@pytest.mark.slow_1
157158
@pytest.mark.slow
158159
def test_image_support_detection() -> None:
159160
dataset = datasets.MNIST(".", download=True)

tests/metrics/test_performance.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def test_evaluate_performance_classifier(
9494
@pytest.mark.xfail
9595
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
9696
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
97+
@pytest.mark.slow_1
9798
@pytest.mark.slow
9899
def test_evaluate_feature_importance_rank_dist_clf(
99100
distance: str, test_plugin: Plugin
@@ -183,6 +184,7 @@ def test_evaluate_performance_regression(
183184
@pytest.mark.xfail
184185
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
185186
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
187+
@pytest.mark.slow_1
186188
@pytest.mark.slow
187189
def test_evaluate_feature_importance_rank_dist_reg(
188190
distance: str, test_plugin: Plugin
@@ -211,6 +213,7 @@ def test_evaluate_feature_importance_rank_dist_reg(
211213
assert score["pvalue"] > 0
212214

213215

216+
@pytest.mark.slow_1
214217
@pytest.mark.slow
215218
@pytest.mark.parametrize("test_plugin", [Plugins().get("marginal_distributions")])
216219
@pytest.mark.parametrize(
@@ -296,6 +299,7 @@ def test_evaluate_performance_survival_analysis(
296299
@pytest.mark.xfail
297300
@pytest.mark.skipif(sys.platform != "linux", reason="Linux only for faster results")
298301
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher")
302+
@pytest.mark.slow_1
299303
@pytest.mark.slow
300304
def test_evaluate_feature_importance_rank_dist_surv(
301305
distance: str, test_plugin: Plugin
@@ -362,6 +366,7 @@ def test_evaluate_performance_custom_labels(
362366
assert "syn_ood" in good_score
363367

364368

369+
@pytest.mark.slow_1
365370
@pytest.mark.slow
366371
@pytest.mark.parametrize("test_plugin", [Plugins().get("timegan")])
367372
@pytest.mark.parametrize(
@@ -472,6 +477,7 @@ def test_evaluate_performance_time_series_survival(
472477
assert def_score == good_score["syn_id.c_index"] - good_score["syn_id.brier_score"]
473478

474479

480+
@pytest.mark.slow_1
475481
@pytest.mark.slow
476482
def test_image_support_perf() -> None:
477483
dataset = datasets.MNIST(".", download=True)

tests/plugins/core/models/test_tabular_gan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def test_gan_generation_with_early_stopping(patience_metric: Tuple[str, str]) ->
174174
assert generated.shape == (10, X.shape[1])
175175

176176

177+
@pytest.mark.slow_1
177178
@pytest.mark.slow
178179
def test_gan_sampling_adjustment() -> None:
179180
X = get_airfoil_dataset()

tests/plugins/core/models/test_ts_gan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def test_ts_gan_generation(source: Any) -> None:
129129
assert observation_times_gen.shape == (10, temporal.shape[1])
130130

131131

132+
@pytest.mark.slow_1
132133
@pytest.mark.slow
133134
@pytest.mark.parametrize("source", [GoogleStocksDataloader])
134135
def test_ts_gan_generation_schema(source: Any) -> None:

0 commit comments

Comments
 (0)