Skip to content

Commit

Permalink
add adsg with original api
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul-Saves committed Oct 14, 2024
1 parent 5def554 commit 3f94425
Showing 1 changed file with 164 additions and 2 deletions.
166 changes: 164 additions & 2 deletions smt_design_space/design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from adsg_core.graph.graph_edges import EdgeType
from adsg_core import GraphProcessor, SelectionChoiceNode
from adsg_core.graph.adsg import ADSG
from adsg_core import BasicADSG, NamedNode, DesignVariableNode

HAS_ADSG = True
except ImportError:
Expand Down Expand Up @@ -1303,24 +1304,84 @@ def __init__(
self.adsg = adsg
elif design_variables is not None:
# to do
self.ds_leg = DesignSpace(
design_variables=design_variables, random_state=seed
)
self.adsg = _legacy_to_adsg(self.ds_leg)
pass
else:
raise ValueError("Either design_variables or adsg should be provided.")

self.graph_proc = GraphProcessor(graph=adsg)
self.graph_proc = GraphProcessor(graph=self.adsg)

if not (HAS_ADSG):
raise ImportError("ADSG is not installed")
if not (HAS_CONFIG_SPACE):
raise ImportError("ConfigSpace is not installed")

design_space = ensure_design_space(design_space=adsg)
design_space = ensure_design_space(design_space=self.adsg)
self._design_variables = design_space.design_variables
super().__init__(design_variables=self._design_variables, random_state=seed)
self._cs = design_space._cs
self._cs_cate = design_space._cs_cate
self._is_decreed = design_space._is_decreed

def declare_decreed_var(
self, decreed_var: int, meta_var: int, meta_value: VarValueType
):
"""
Define a conditional (decreed) variable to be active when the meta variable has (one of) the provided values.
Parameters
----------
decreed_var: int
- Index of the conditional variable (the variable that is conditionally active)
meta_var: int
- Index of the meta variable (the variable that determines whether the conditional var is active)
meta_value: int | str | list[int|str]
- The value or list of values that the meta variable can have to activate the decreed var
"""

self.ds_leg.declare_decreed_var(
decreed_var=decreed_var, meta_var=meta_var, meta_value=meta_value
)
self.adsg = _legacy_to_adsg(self.ds_leg)
design_space = ensure_design_space(design_space=self.adsg)
self._design_variables = design_space.design_variables
self._cs = design_space._cs
self._cs_cate = design_space._cs_cate
self._is_decreed = design_space._is_decreed
self.graph_proc = GraphProcessor(graph=self.adsg)

def add_value_constraint(
self, var1: int, value1: VarValueType, var2: int, value2: VarValueType
):
"""
Define a constraint where two variables cannot have the given values at the same time.
Parameters
----------
var1: int
- Index of the first variable
value1: int | str | list[int|str]
- Value or values that the first variable is checked against
var2: int
- Index of the second variable
value2: int | str | list[int|str]
- Value or values that the second variable is checked against
"""

self.ds_leg.add_value_constraint(
var1=var1, value1=value1, var2=var2, value2=value2
)
self.adsg = _legacy_to_adsg(self.ds_leg)
design_space = ensure_design_space(design_space=self.adsg)
self._design_variables = design_space.design_variables
self._cs = design_space._cs
self._cs_cate = design_space._cs_cate
self._is_decreed = design_space._is_decreed
self.graph_proc = GraphProcessor(graph=self.adsg)

def _sample_valid_x(
self,
n: int,
Expand Down Expand Up @@ -1554,3 +1615,104 @@ def remove_symmetry(lst):
) # Forbid more than 35 neurons with ASGD

return design_space


def _legacy_to_adsg(legacy_ds: "DesignSpace") -> BasicADSG:
"""
Interface to turn a legacy DesignSpace back into an ADSG instance.
Parameters:
legacy_ds (DesignSpace): The legacy DesignSpace instance.
Returns:
BasicADSG: The corresponding ADSG graph.
"""
adsg = BasicADSG()

# Create nodes for each variable in the DesignSpace
nodes = {}
value_nodes = {} # This will store decreed value nodes
start_nodes = set()
for i, var in enumerate(legacy_ds._design_variables):
if isinstance(var, FloatVariable) or isinstance(var, IntegerVariable):
# Create a DesignVariableNode with bounds for continuous variables
var_node = DesignVariableNode(f"x{i}", bounds=(var.lower, var.upper))
elif isinstance(var, CategoricalVariable):
# Create a SelectionChoiceNode for categorical variables
var_node = NamedNode(f"x{i}")
choices = [NamedNode(value) for value in var.values]
value_nodes[f"x{i}"] = (
choices # Store decreed value nodes for this variable
)
adsg.add_selection_choice(f"choice_x{i}", var_node, choices)
elif isinstance(var, OrdinalVariable):
# Create a SelectionChoiceNode for ordinal variables (ordinal treated like categorical)
var_node = NamedNode(f"x{i}")
choices = [NamedNode(value) for value in var.values]
value_nodes[f"x{i}"] = (
choices # Store decreed value nodes for this variable
)
adsg.add_selection_choice(f"choice_x{i}", var_node, choices)
else:
raise ValueError(f"Unsupported variable type: {type(var)}")

adsg.add_node(var_node)
nodes[f"x{i}"] = var_node
start_nodes.add(var_node)

# Handle decreed variables (conditional dependencies)
for decreed_var in legacy_ds._cs._conditionals:
decreed_node = nodes[f"{decreed_var}"]
if decreed_node in start_nodes:
start_nodes.remove(decreed_node)
# Get parent condition(s) from the legacy design space
parent_conditions = legacy_ds._cs._parent_conditions_of[decreed_var]
for condition in parent_conditions:
meta_var = condition.parent.name # Parent variable
try:
meta_values = (
condition.values
) # Values that activate the decreed variable
except AttributeError:
meta_values = [condition.value]

# Add conditional decreed edges
for value in meta_values:
meta_nodes = [node for node in value_nodes[f"{meta_var}"]]
meta_node_ind = [
node.name for node in value_nodes[f"{meta_var}"]
].index(str(value)[:])
value_node = meta_nodes[meta_node_ind]

meta_node = nodes[f"x{legacy_ds._cs._hyperparameter_idx[meta_var]}"]
adsg.add_edge(
value_node, decreed_node
) # Linking decreed node to meta node

# Handle value constraints (incompatibilities)
for value_constraint in legacy_ds._cs.forbidden_clauses:
clause1 = value_constraint.components[0]
var1 = clause1.hyperparameter.name
values1 = clause1.value or clause1.values
clause2 = value_constraint.components[1]
var2 = clause2.hyperparameter.name
values2 = clause2.value or clause2.values

for value1 in values1:
for value2 in values2:
# Retrieve decreed value nodes from value_nodes
value_nodes1 = [node for node in value_nodes[f"{var1}"]]
value_node1_ind = [node.name for node in value_nodes[f"{var1}"]].index(
str(value1)[:]
)
value_node1 = value_nodes1[value_node1_ind]
value_nodes2 = [node for node in value_nodes[f"{var2}"]]
value_node2_ind = [node.name for node in value_nodes[f"{var2}"]].index(
str(value2)[:]
)
value_node2 = value_nodes2[value_node2_ind]
if value_node1 and value_node2:
# Add incompatibility constraint between the two value nodes
adsg.add_incompatibility_constraint([value_node1, value_node2])
adsg = adsg.set_start_nodes(start_nodes)
return adsg

0 comments on commit 3f94425

Please sign in to comment.