diff --git a/src/synthcity/plugins/privacy/plugin_privbayes.py b/src/synthcity/plugins/privacy/plugin_privbayes.py index fd4ab8bb..d4bb6da4 100644 --- a/src/synthcity/plugins/privacy/plugin_privbayes.py +++ b/src/synthcity/plugins/privacy/plugin_privbayes.py @@ -153,15 +153,26 @@ def _encode(self, data: pd.DataFrame) -> Any: encoders = {} for col in data.columns: - if len(data[col].unique()) < self.n_bins or data[col].dtype.name not in [ - "object", - "category", - ]: + dtype = data[col].dtype.name + n_unique = data[col].nunique() + + # Case 1 : categorical variables (string or category) + if dtype in ["object", "category"]: encoders[col] = { "type": "categorical", "model": LabelEncoder().fit(data[col]), } data[col] = encoders[col]["model"].transform(data[col]) + + # Case 2 : discrete numerical variables with few distinct values + elif np.issubdtype(data[col].dtype, np.integer) and n_unique <= self.n_bins: + encoders[col] = { + "type": "categorical", + "model": LabelEncoder().fit(data[col]), + } + data[col] = encoders[col]["model"].transform(data[col]) + + # Case 3 : continuous variables (float or many unique values) else: col_data = pd.cut(data[col], bins=self.n_bins) encoders[col] = {