Skip to content

Commit

Permalink
add aggregate by mean (#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt authored Aug 26, 2023
1 parent d69b295 commit 1ebc265
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
10 changes: 8 additions & 2 deletions bofire/data_models/domain/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,11 @@ def coerce_invalids(self, experiments: pd.DataFrame) -> pd.DataFrame:
return experiments

def aggregate_by_duplicates(
self, experiments: pd.DataFrame, prec: int, delimiter: str = "-"
self,
experiments: pd.DataFrame,
prec: int,
delimiter: str = "-",
method: Literal["mean", "median"] = "mean",
) -> Tuple[pd.DataFrame, list]:
"""Aggregate the dataframe by duplicate experiments
Expand All @@ -417,6 +421,8 @@ def aggregate_by_duplicates(
Tuple[pd.DataFrame, list]: Dataframe holding the aggregated experiments, list of lists holding the labcodes of the duplicates
"""
# prepare the parent frame
if method not in ["mean", "median"]:
raise ValueError(f"Unknown aggregation type provided: {method}")

preprocessed = self.outputs.preprocess_experiments_any_valid_output(experiments)
assert preprocessed is not None
Expand All @@ -437,7 +443,7 @@ def aggregate_by_duplicates(

# group and aggregate
agg: Dict[str, Any] = {
feat: "mean" for feat in self.get_feature_keys(ContinuousOutput)
feat: method for feat in self.get_feature_keys(ContinuousOutput)
}
agg["labcode"] = lambda x: delimiter.join(sorted(x.tolist()))
for feat in self.get_feature_keys(Output):
Expand Down
29 changes: 26 additions & 3 deletions tests/bofire/data_models/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,8 @@ def test_coerce_invalids():
assert_frame_equal(experiments, expected, check_dtype=False)


def test_aggregate_by_duplicates():
@pytest.mark.parametrize("method", ["mean", "median"])
def test_aggregate_by_duplicates(method):
# dataframe with duplicates
full = pd.DataFrame.from_dict(
{
Expand All @@ -625,7 +626,9 @@ def test_aggregate_by_duplicates():
domain = Domain(
inputs=Inputs(features=[if1, if2]), outputs=Outputs(features=[of1, of2])
)
aggregated, duplicated_labcodes = domain.aggregate_by_duplicates(full, prec=2)
aggregated, duplicated_labcodes = domain.aggregate_by_duplicates(
full, prec=2, method=method
)
assert duplicated_labcodes == [["1", "4"]]
assert_frame_equal(
aggregated, expected_aggregated, check_dtype=False, check_like=True
Expand Down Expand Up @@ -655,10 +658,30 @@ def test_aggregate_by_duplicates():
domain = Domain(
inputs=Inputs(features=[if1, if2]), outputs=Outputs(features=[of1, of2])
)
aggregated, duplicated_labcodes = domain.aggregate_by_duplicates(full, prec=2)
aggregated, duplicated_labcodes = domain.aggregate_by_duplicates(
full, prec=2, method=method
)
assert duplicated_labcodes == []


def test_aggregate_by_duplicates_error():
full = pd.DataFrame.from_dict(
{
"x1": [1.0, 2.0, 3.0, 1.0],
"x2": [1.0, 2.0, 3.0, 1.0],
"out1": [4.0, 5.0, 6.0, 3.0],
"out2": [-4.0, -5.0, -6.0, -3.0],
"valid_out1": [1, 1, 1, 1],
"valid_out2": [1, 1, 1, 1],
}
)
domain = Domain(
inputs=Inputs(features=[if1, if2]), outputs=Outputs(features=[of1, of2])
)
with pytest.raises(ValueError, match="Unknown aggregation type provided: 25"):
domain.aggregate_by_duplicates(full, prec=2, method="25")


domain = Domain(
inputs=Inputs(features=[if1, if2]), outputs=Outputs(features=[of1, of2, of1_, of2_])
)
Expand Down

0 comments on commit 1ebc265

Please sign in to comment.