Skip to content

Commit 227bd54

Browse files
authored
Improvements&bugfixing (#118)
* create serde folder if it is missing * add generate random seed * cleaunp
1 parent d3b1014 commit 227bd54

File tree

4 files changed

+23
-1
lines changed

4 files changed

+23
-1
lines changed

src/synthcity/plugins/core/plugin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) ->
211211
if "cond" in kwargs and kwargs["cond"] is not None:
212212
self.expecting_conditional = True
213213

214+
enable_reproducible_results(self.random_state)
215+
214216
self.data_info = X.info()
215217

216218
self._schema = Schema(
@@ -262,6 +264,7 @@ def generate(
262264
self,
263265
count: Optional[int] = None,
264266
constraints: Optional[Constraints] = None,
267+
random_state: Optional[int] = None,
265268
**kwargs: Any,
266269
) -> DataLoader:
267270
"""Synthetic data generation method.
@@ -301,6 +304,9 @@ def generate(
301304
>>>
302305
>>> assert (syn_data["InterestingFeature"] == 0).all()
303306
307+
random_state: optional int.
308+
Optional random seed to use.
309+
304310
Returns:
305311
<count> synthetic samples
306312
"""
@@ -310,6 +316,9 @@ def generate(
310316
if self._schema is None:
311317
raise RuntimeError("Fit the model first")
312318

319+
if random_state is not None:
320+
enable_reproducible_results(random_state)
321+
313322
has_gen_cond = "cond" in kwargs and kwargs["cond"] is not None
314323
if has_gen_cond and not self.expecting_conditional:
315324
raise RuntimeError(

src/synthcity/utils/serialization.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ def load(buff: bytes) -> Any:
1717

1818

1919
def save_to_file(path: Union[str, Path], model: Any) -> Any:
20+
path = Path(path)
21+
ppath = path.absolute().parent
22+
23+
if not ppath.exists():
24+
ppath.mkdir(parents=True, exist_ok=True)
25+
2026
with open(path, "wb") as f:
2127
return cloudpickle.dump(model, f)
2228

src/synthcity/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.8"
1+
__version__ = "0.1.9"
22

33
MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
44
MINOR_VERSION = __version__.split(".")[-1]

tests/plugins/generic/test_ctgan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
7575
assert len(X_gen) == 50
7676
assert test_plugin.schema_includes(X_gen)
7777

78+
# generate with random seed
79+
X_gen1 = test_plugin.generate(50, random_state=0)
80+
X_gen2 = test_plugin.generate(50, random_state=0)
81+
X_gen3 = test_plugin.generate(50)
82+
assert (X_gen1.numpy() == X_gen2.numpy()).all()
83+
assert (X_gen1.numpy() != X_gen3.numpy()).any()
84+
7885

7986
@pytest.mark.parametrize(
8087
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)

0 commit comments

Comments
 (0)