Skip to content

Commit cf6ea56

Browse files
authored
Augmentation benchmark (#150)
* Add new benchmark code * Merge main into branch * Augmentation benchmark added * Clean up * Clean up * Remove unnecessary tutorial file * Clean up * clean up * Debug test and clean up * Added new tests for augmentation benchmark * Added new metric api tests for augmentation * clean up * clean up * version bumped and clean up * clean up docstrings
1 parent b82baca commit cf6ea56

File tree

13 files changed

+610
-45
lines changed

13 files changed

+610
-45
lines changed

src/synthcity/benchmark/__init__.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import platform
55
import random
6+
from copy import copy
67
from pathlib import Path
78
from typing import Any, Dict, List, Optional, Tuple
89

@@ -14,6 +15,7 @@
1415

1516
# synthcity absolute
1617
import synthcity.logger as log
18+
from synthcity.benchmark.utils import augment_data
1719
from synthcity.metrics import Metrics
1820
from synthcity.metrics.scores import ScoreEvaluator
1921
from synthcity.plugins import Plugins
@@ -48,10 +50,16 @@ def evaluate(
4850
synthetic_constraints: Optional[Constraints] = None,
4951
synthetic_cache: bool = True,
5052
synthetic_reuse_if_exists: bool = True,
53+
augmented_reuse_if_exists: bool = True,
5154
task_type: str = "classification", # classification, regression, survival_analysis, time_series
5255
workspace: Path = Path("workspace"),
56+
augmentation_rule: str = "equal",
57+
strict_augmentation: bool = False,
58+
ad_hoc_augment_vals: Optional[Dict] = None,
59+
use_metric_cache: bool = True,
5360
**generate_kwargs: Any,
5461
) -> pd.DataFrame:
62+
5563
"""Benchmark the performance of several algorithms.
5664
5765
Args:
@@ -80,11 +88,21 @@ def evaluate(
8088
synthetic_cache: bool
8189
Enable experiment caching
8290
synthetic_reuse_if_exists: bool
83-
If the current synthetic dataset is cached, it will be reused for the experiments.
91+
If the current synthetic dataset is cached, it will be reused for the experiments. Defaults to True.
92+
augmented_reuse_if_exists: bool
93+
If the current augmented dataset is cached, it will be reused for the experiments. Defaults to True.
8494
task_type: str
8595
The type of problem. Relevant for evaluating the downstream models with the correct metrics. Valid tasks are: "classification", "regression", "survival_analysis", "time_series", "time_series_survival".
8696
workspace: Path
8797
Path for caching experiments. Default: "workspace".
98+
augmentation_rule: str
99+
The rule used to achieve the desired proportion of records with each value in the fairness column. Possible values are: 'equal', 'log', and 'ad-hoc'. Defaults to "equal".
100+
strict_augmentation: bool
101+
Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False.
102+
ad_hoc_augment_vals: Dict
103+
A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
104+
use_metric_cache: bool
105+
If the current metric has been previously run and is cached, it will be reused for the experiments. Defaults to True.
88106
plugin_kwargs:
89107
Optional kwargs for each algorithm. Example {"adsgan": {"n_iter": 10}},
90108
"""
@@ -115,6 +133,17 @@ def evaluate(
115133
hash_object = hashlib.sha256(kwargs_hash_raw)
116134
kwargs_hash = hash_object.hexdigest()
117135

136+
augmentation_arguments = {
137+
"augmentation_rule": augmentation_rule,
138+
"strict_augmentation": strict_augmentation,
139+
"ad_hoc_augment_vals": ad_hoc_augment_vals,
140+
}
141+
augmentation_arguments_hash_raw = json.dumps(
142+
copy(augmentation_arguments), sort_keys=True
143+
).encode()
144+
augmentation_hash_object = hashlib.sha256(augmentation_arguments_hash_raw)
145+
augmentation_hash = augmentation_hash_object.hexdigest()
146+
118147
repeats_list = list(range(repeats))
119148
random.shuffle(repeats_list)
120149

@@ -126,14 +155,22 @@ def evaluate(
126155

127156
clear_cache()
128157

129-
cache_file = (
158+
X_syn_cache_file = (
130159
workspace
131160
/ f"{experiment_name}_{testcase}_{plugin}_{kwargs_hash}_{platform.python_version()}_{repeat}.bkp"
132161
)
133162
generator_file = (
134163
workspace
135164
/ f"{experiment_name}_{testcase}_{plugin}_{kwargs_hash}_{platform.python_version()}_generator_{repeat}.bkp"
136165
)
166+
X_augment_cache_file = (
167+
workspace
168+
/ f"{experiment_name}_{testcase}_{plugin}_augmentation_{augmentation_hash}_{kwargs_hash}_{platform.python_version()}_{repeat}.bkp"
169+
)
170+
augment_generator_file = (
171+
workspace
172+
/ f"{experiment_name}_{testcase}_{plugin}_augmentation_{augmentation_hash}_{kwargs_hash}_{platform.python_version()}_generator_{repeat}.bkp"
173+
)
137174

138175
log.info(
139176
f"[testcase] Experiment repeat: {repeat} task type: {task_type} Train df hash = {experiment_name}"
@@ -152,8 +189,8 @@ def evaluate(
152189
if synthetic_cache:
153190
save_to_file(generator_file, generator)
154191

155-
if cache_file.exists() and synthetic_reuse_if_exists:
156-
X_syn = load_from_file(cache_file)
192+
if X_syn_cache_file.exists() and synthetic_reuse_if_exists:
193+
X_syn = load_from_file(X_syn_cache_file)
157194
else:
158195
try:
159196
X_syn = generator.generate(
@@ -168,13 +205,68 @@ def evaluate(
168205
continue
169206

170207
if synthetic_cache:
171-
save_to_file(cache_file, X_syn)
208+
save_to_file(X_syn_cache_file, X_syn)
209+
210+
# Augmentation
211+
if metrics and any(
212+
"augmentation" in metric
213+
for metric in [x for v in metrics.values() for x in v]
214+
):
215+
if augment_generator_file.exists() and augmented_reuse_if_exists:
216+
augment_generator = load_from_file(augment_generator_file)
217+
else:
218+
augment_generator = Plugins(categories=plugin_cats).get(
219+
plugin,
220+
**kwargs,
221+
)
222+
try:
223+
if not X.get_fairness_column():
224+
raise ValueError(
225+
"To use the augmentation metrics, `fairness_column` must be set to a string representing the name of a column in the DataLoader."
226+
)
227+
augment_generator.fit(
228+
X.train(),
229+
cond=X.train()[X.get_fairness_column()],
230+
)
231+
except BaseException as e:
232+
log.critical(
233+
f"[{plugin}][take {repeat}] failed to fit augmentation generator: {e}"
234+
)
235+
continue
236+
if synthetic_cache:
237+
save_to_file(augment_generator_file, augment_generator)
238+
239+
if X_augment_cache_file.exists() and augmented_reuse_if_exists:
240+
X_augmented = load_from_file(X_augment_cache_file)
241+
else:
242+
try:
243+
X_augmented = augment_data(
244+
X.train(),
245+
augment_generator,
246+
rule=augmentation_rule,
247+
strict=strict_augmentation,
248+
ad_hoc_augment_vals=ad_hoc_augment_vals,
249+
**generate_kwargs,
250+
)
251+
if len(X_augmented) == 0:
252+
raise RuntimeError("Plugin failed to generate data")
253+
except BaseException as e:
254+
log.critical(
255+
f"[{plugin}][take {repeat}] failed to generate augmentation data: {e}"
256+
)
257+
continue
258+
if synthetic_cache:
259+
save_to_file(X_augment_cache_file, X_augmented)
260+
else:
261+
X_augmented = None
172262
evaluation = Metrics.evaluate(
173-
X_test if X_test is not None else X,
263+
X_test if X_test is not None else X.test(),
174264
X_syn,
265+
X_augmented,
175266
metrics=metrics,
176267
task_type=task_type,
177268
workspace=workspace,
269+
use_cache=use_metric_cache,
178270
)
179271

180272
mean_score = evaluation["mean"].to_dict()

src/synthcity/benchmark/utils.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# stdlib
2+
import math
3+
from copy import copy
4+
from typing import Any, Dict, Optional
5+
6+
# third party
7+
import numpy as np
8+
import pandas as pd
9+
from pydantic import validate_arguments
10+
from typing_extensions import Literal
11+
12+
# synthcity absolute
13+
from synthcity.plugins.core.constraints import Constraints
14+
from synthcity.plugins.core.dataloader import DataLoader
15+
16+
17+
def calculate_fair_aug_sample_size(
    X_train: pd.DataFrame,
    fairness_column: Optional[str],  # a categorical column of K levels
    rule: Literal["equal", "log", "ad-hoc"],
    ad_hoc_augment_vals: Optional[
        Dict[Any, int]
    ] = None,  # Only required for rule == "ad-hoc"
) -> Dict:
    """Calculate how many records of each fairness-column value to augment.

    Args:
        X_train (pd.DataFrame): The real dataset to be augmented.
        fairness_column (Optional[str]): The name of the (categorical) column to test the fairness of a downstream model with respect to.
        rule (Literal["equal", "log", "ad-hoc"]): The rule used to achieve the desired proportion of records with each value in the fairness column.
        ad_hoc_augment_vals (Optional[Dict[Any, int]]): A dictionary containing the number of each class to augment the real data with. If rule="ad-hoc" this function returns ad_hoc_augment_vals, otherwise this parameter is ignored. Defaults to None.

    Returns:
        Dict: A dictionary mapping each fairness-column value to the number of records to augment the real data with.

    Raises:
        ValueError: If rule="ad-hoc" and ad_hoc_augment_vals is missing or contains
            keys that are not values of the fairness column, or if rule is not one
            of the supported values.
    """
    if rule == "equal":
        # The number of samples will be the same for each value in the fairness
        # column after augmentation: N_aug(i) == N_aug(j) for all i, j. The
        # majority class is unchanged (it receives 0 augmented records).
        fairness_col_counts = X_train[fairness_column].value_counts()
        majority_size = fairness_col_counts.max()
        augmentation_counts = {
            fair_col_val: (majority_size - fairness_col_counts.loc[fair_col_val])
            for fair_col_val in fairness_col_counts.index
        }
    elif rule == "log":
        # The number of samples in the augmented data will be proportional to the
        # log frequency in the real data: N_aug(i) ∝ log(N_real(i)). Taking the
        # log makes the distribution more even without forcing exact equality.
        fairness_col_counts = X_train[fairness_column].value_counts()
        majority_size = fairness_col_counts.max()
        # NOTE(review): assumes majority_size > 1 — log(1) == 0 would divide by zero.
        log_coefficient = majority_size / math.log(majority_size)

        augmentation_counts = {
            fair_col_val: (
                majority_size - round(math.log(fair_col_count) * log_coefficient)
            )
            for fair_col_val, fair_col_count in fairness_col_counts.items()
        }
    elif rule == "ad-hoc":
        # Use user-specified values to augment.
        if not ad_hoc_augment_vals:
            raise ValueError(
                "When augmenting with an `ad-hoc` method, ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment."
            )
        if not set(ad_hoc_augment_vals.keys()).issubset(
            set(X_train[fairness_column].values)
        ):
            raise ValueError(
                "ad_hoc_augment_vals must be a dictionary, where the dictionary keys are the values of the fairness_column and the dictionary values are the number of records to augment."
            )
        # Keys may be a strict subset of the fairness values: classes without an
        # entry simply receive no augmented records.
        augmentation_counts = ad_hoc_augment_vals
    else:
        # This function is not wrapped in @validate_arguments, so the Literal hint
        # on `rule` is not enforced at runtime; without this guard an unsupported
        # rule would surface as a confusing NameError below.
        raise ValueError(f"Unknown augmentation rule: {rule}")

    return augmentation_counts
88+
89+
90+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _generate_synthetic_data(
    X_train: DataLoader,
    augment_generator: Any,
    strict: bool = True,
    rule: Literal["equal", "log", "ad-hoc"] = "equal",
    ad_hoc_augment_vals: Optional[
        Dict[Any, int]
    ] = None,  # Only required for rule == "ad-hoc"
    synthetic_constraints: Optional[Constraints] = None,
    **generate_kwargs: Any,
) -> pd.DataFrame:
    """Generate the synthetic portion of an augmented dataset.

    Per-class generation counts are decided by ``calculate_fair_aug_sample_size``.

    Args:
        X_train (DataLoader): The dataset used to train the downstream model.
        augment_generator (Any): The synthetic model to be used to generate the synthetic portion of the augmented dataset.
        strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met: when True, one constrained `generate` call is made per fairness value; when False, a single conditional `generate` call is made. Defaults to True.
        rule (Literal["equal", "log", "ad-hoc"]): The rule used to achieve the desired proportion of records with each value in the fairness column. Defaults to "equal".
        ad_hoc_augment_vals (Dict[Any, int], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
        synthetic_constraints (Optional[Constraints]): Constraints placed on the generation of the synthetic data. Only forwarded in the non-strict path. Defaults to None.
        **generate_kwargs (Any): Extra keyword arguments forwarded to `generate`. Only forwarded in the non-strict path.

    Returns:
        pd.DataFrame: The generated synthetic data.
    """
    augmentation_counts = calculate_fair_aug_sample_size(
        X_train.dataframe(),
        X_train.get_fairness_column(),
        rule,
        ad_hoc_augment_vals=ad_hoc_augment_vals,
    )
    if not strict:
        # Single generate call: count equals the total number of records required
        # according to calculate_fair_aug_sample_size, and `cond` repeats each
        # fairness value once per record requested for that class.
        count = sum(augmentation_counts.values())
        cond = pd.Series(
            np.repeat(
                list(augmentation_counts.keys()), list(augmentation_counts.values())
            )
        )
        syn_data = augment_generator.generate(
            count=count,
            cond=cond,
            constraints=synthetic_constraints,
            **generate_kwargs,
        ).dataframe()
    else:
        # One generate call per fairness value, constrained so every generated
        # record carries exactly that value; classes needing 0 records are skipped.
        # NOTE(review): this path does not forward synthetic_constraints or
        # generate_kwargs — confirm whether that asymmetry is intentional.
        syn_data_list = []
        for fairness_value, count in augmentation_counts.items():
            if count > 0:
                constraints = Constraints(
                    rules=[(X_train.get_fairness_column(), "==", fairness_value)]
                )
                syn_data_list.append(
                    augment_generator.generate(
                        count=count, constraints=constraints
                    ).dataframe()
                )
        syn_data = pd.concat(syn_data_list)
    return syn_data
148+
149+
150+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def augment_data(
    X_train: DataLoader,
    augment_generator: Any,
    strict: bool = False,
    rule: Literal["equal", "log", "ad-hoc"] = "equal",
    ad_hoc_augment_vals: Optional[
        Dict[Any, int]
    ] = None,  # Only required for rule == "ad-hoc"
    synthetic_constraints: Optional[Constraints] = None,
    **generate_kwargs: Any,
) -> DataLoader:
    """Augment the real data with generated synthetic data.

    Delegates synthetic-record generation to ``_generate_synthetic_data`` and
    returns a shallow copy of ``X_train`` whose ``data`` is the concatenation of
    the real records and the synthetic ones.

    Args:
        X_train (DataLoader): The ground truth DataLoader to augment with synthetic data.
        augment_generator (Any): The synthetic model to be used to generate the synthetic portion of the augmented dataset.
        strict (bool, optional): Flag to ensure that the condition for generating synthetic data is strictly met. Defaults to False.
        rule (Literal["equal", "log", "ad-hoc"]): The rule used to achieve the desired proportion of records with each value in the fairness column. Defaults to "equal".
        ad_hoc_augment_vals (Dict[Any, int], optional): A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
        synthetic_constraints (Optional[Constraints]): Constraints placed on the generation of the synthetic data. Defaults to None.

    Returns:
        DataLoader: The augmented dataset and labels.
    """
    synthetic_records = _generate_synthetic_data(
        X_train,
        augment_generator,
        strict=strict,
        rule=rule,
        ad_hoc_augment_vals=ad_hoc_augment_vals,
        synthetic_constraints=synthetic_constraints,
        **generate_kwargs,
    )

    # Shallow-copy the loader so its configuration (fairness column, etc.) is
    # preserved, then swap in the combined real + synthetic frame.
    combined = pd.concat([X_train.data, synthetic_records])
    augmented_loader = copy(X_train)
    augmented_loader.data = combined
    return augmented_loader

0 commit comments

Comments
 (0)