Skip to content

Commit

Permalink
add cat unfold
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Oct 8, 2024
1 parent 9c17f0d commit 82f9465
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
15 changes: 11 additions & 4 deletions smt_design_space/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def correct_get_acting(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:

# Unfold if needed
if x_is_unfolded:
x_corrected, is_acting = self.unfold_x(x_corrected, is_acting)
x_corrected, is_acting, _ = self.unfold_x(x_corrected, is_acting)

return x_corrected, is_acting

Expand Down Expand Up @@ -383,7 +383,7 @@ def sample_valid_x(

# Unfold if needed
if unfolded:
x, is_acting = self.unfold_x(x, is_acting)
x, is_acting, _ = self.unfold_x(x, is_acting)

return x, is_acting

Expand Down Expand Up @@ -491,7 +491,7 @@ def fold_x(

def unfold_x(
self, x: np.ndarray, is_acting: np.ndarray = None, fold_mask: np.ndarray = None
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]:
"""
Unfold x and optionally is_acting. Unfolding creates one extra dimension for each categorical variable using
one-hot encoding.
Expand All @@ -511,6 +511,9 @@ def unfold_x(
- Unfolded samples
is_acting_unfolded: np.ndarray [n, dim_unfolded]
- (Optional) boolean matrix specifying for each unfolded variable whether it is acting or non-acting
is_categorical_unfolded: np.ndarray [n, dim_unfolded]
- (Optional) boolean matrix specifying for each unfolded variable whether it is categorical or not
"""

# Get number of unfolded dimension
Expand All @@ -520,6 +523,7 @@ def unfold_x(
is_acting_unfolded = (
np.ones(x_unfolded.shape, dtype=bool) if is_acting is not None else None
)
is_categorical_unfolded = np.ones(x_unfolded.shape, dtype=bool)

i_x_unfold = 0
for i, dv in enumerate(self.design_variables):
Expand All @@ -541,20 +545,23 @@ def unfold_x(
is_acting_unfolded[:, i_x_unfold : i_x_unfold + n_dim_cat] = (
np.tile(is_acting[:, [i]], (1, n_dim_cat))
)
is_categorical_unfolded[:, i_x_unfold : i_x_unfold + n_dim_cat] *= True

i_x_unfold += n_dim_cat

else:
x_unfolded[:, i_x_unfold] = x[:, i]
if is_acting is not None:
is_acting_unfolded[:, i_x_unfold] = is_acting[:, i]
is_categorical_unfolded[:, i_x_unfold] *= False

i_x_unfold += 1

x_unfolded = x_unfolded[:, :i_x_unfold]
if is_acting is not None:
is_acting_unfolded = is_acting_unfolded[:, :i_x_unfold]

return x_unfolded, is_acting_unfolded
return x_unfolded, is_acting_unfolded, is_categorical_unfolded[0]

def _get_n_dim_unfolded(self) -> int:
return sum(
Expand Down
9 changes: 6 additions & 3 deletions smt_design_space/test_design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_base_design_space(self):
]
)

x_unfolded, is_acting_unfolded = ds.unfold_x(x, is_acting)
x_unfolded, is_acting_unfolded, is_cat_unfolded = ds.unfold_x(x, is_acting)
self.assertTrue(
np.all(
x_unfolded
Expand All @@ -171,12 +171,15 @@ def test_base_design_space(self):
[False, False, True, True],
],
)
self.assertEqual(is_cat_unfolded.dtype, bool)

np.testing.assert_array_equal(is_cat_unfolded, [True, True, False, False])

x_folded, is_acting_folded = ds.fold_x(x_unfolded, is_acting_unfolded)
np.testing.assert_array_equal(x_folded, x)
np.testing.assert_array_equal(is_acting_folded, is_acting)

x_unfold_mask, is_act_unfold_mask = ds.unfold_x(
x_unfold_mask, is_act_unfold_mask, _ = ds.unfold_x(
x, is_acting, fold_mask=np.array([False] * 3)
)
np.testing.assert_array_equal(x_unfold_mask, x)
Expand Down Expand Up @@ -308,7 +311,7 @@ def test_folding_mask(self):

self.assertEqual(ds._get_n_dim_unfolded(), 6)

x_unfolded, is_act_unfolded = ds.unfold_x(x, is_act, np.array([True, False]))
x_unfolded, is_act_unfolded, _ = ds.unfold_x(x, is_act, np.array([True, False]))
self.assertTrue(np.all(x_unfolded == np.array([[0, 1, 0, 2]])))
self.assertTrue(
np.all(is_act_unfolded == np.array([[True, True, True, False]]))
Expand Down

0 comments on commit 82f9465

Please sign in to comment.