Description
TabDDPM generates numerical features disproportionately at extremes (while categorical features are fine)
When training a new TabDDPM model, I observe that categorical features are generated correctly, but numerical features often collapse to the minimum or maximum values of their domain. This creates artificial spikes at the extremes that are not present in the real data distribution.
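To make the "artificial spikes at the extremes" concrete, here is a minimal sketch of how the collapse can be quantified: measure what fraction of a numerical column's values sit exactly at its observed min/max. The helper name `edge_mass` and the toy series are illustrative, not part of synthcity.

```python
import pandas as pd

def edge_mass(series: pd.Series) -> float:
    """Fraction of values equal to the series' min or max."""
    lo, hi = series.min(), series.max()
    return ((series == lo) | (series == hi)).mean()

# Toy illustration: a healthy column vs. one that collapsed to its extremes.
real = pd.Series([20, 35, 42, 55, 61, 70, 80])       # spread across the range
collapsed = pd.Series([20, 80, 20, 80, 80, 20, 20])  # only the extremes

print(f"real edge mass:      {edge_mass(real):.2f}")       # small (only true endpoints)
print(f"collapsed edge mass: {edge_mass(collapsed):.2f}")  # 1.00
```

A real column should have an edge mass close to the natural frequency of its boundary values; a value near 1.0 on the synthetic data reproduces the spike pattern described above.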
Code:
Here is a random dataset you can use to try out the script:
import pandas as pd
from synthcity.plugins import Plugins

# Load the example dataset and normalise the target column name.
df = pd.read_csv("random_dataset.csv")
if "mrs_90d" in df.columns and "target" not in df.columns:
    df = df.rename(columns={"mrs_90d": "target"})
is_clf = "target" in df.columns

# Train TabDDPM.
ddpm = Plugins().get(
    "ddpm",
    n_iter=200,
    batch_size=32,
    is_classification=is_clf,
    lr=1e-4,
)
ddpm.fit(df)

# Generate, round, and cast numerical columns back to integers.
syn = ddpm.generate(count=len(df)).dataframe()
syn = syn.round()
num_cols = syn.select_dtypes(include=["number"]).columns
syn[num_cols] = syn[num_cols].astype("int64")
syn.to_csv("tabddpm_synthetic_output.csv", index=False)
print("Saved:", "tabddpm_synthetic_output.csv")
Issue
Example: if age in the real dataset ranges between 20 and 80 years, the synthetic samples contain only the values 20 and 80, with nothing in between.
This happens before any rounding: I also apply a global .round() in my pipeline, but the edge collapse is already present without it.
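One way to confirm the collapse is independent of rounding is to compare quantiles of the raw (pre-round) synthetic column against the real one. The arrays below are illustrative stand-ins for a real and a raw synthetic age column, not actual model output.

```python
import numpy as np

# Stand-in data: real ages spread over [20, 80] vs. raw synthetic floats
# that cluster at the two extremes before any rounding is applied.
real_age = np.array([23.0, 31.0, 44.0, 52.0, 58.0, 67.0, 74.0])
raw_syn_age = np.array([20.1, 79.8, 20.3, 79.9, 20.0, 80.0, 79.7])

qs = [0.25, 0.5, 0.75]
print("real quantiles:", np.quantile(real_age, qs))
print("syn quantiles: ", np.quantile(raw_syn_age, qs))
# A synthetic median pinned near one extreme, with the interquartile range
# spanning almost the whole domain, is the signature of the edge collapse
# and appears on the raw floats, before .round() is ever called.
```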
System Information:
- Synthcity version: 0.2.11
- Python version: 3.10
- Hardware: CPU