Skip to content

Commit 236d4fa

Browse files
committed
feat: add manual generator for generation steps
1 parent c009422 commit 236d4fa

File tree

1 file changed

+73
-2
lines changed

1 file changed

+73
-2
lines changed

src/foambo/core.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,35 @@
99
import pandas as pd
1010
import numpy as np
1111

12-
from omegaconf import OmegaConf, DictConfig, DictKeyType
12+
from omegaconf import OmegaConf, DictConfig, DictKeyType, ListConfig
1313
from ax.core.base_trial import TrialStatus, BaseTrial
1414
from ax.core.trial import Trial
1515
from ax.core.runner import Runner
1616
from ax.core.metric import Metric, MetricFetchResult, MetricFetchE
1717
from ax.core.data import Data
18+
from ax.core.search_space import SearchSpace
19+
from ax.core.parameter import RangeParameter, ChoiceParameter, FixedParameter
1820
from ax.service.utils.instantiation import ObjectiveProperties
21+
from ax.models.random.base import RandomModel
22+
from ax.modelbridge.registry import (
23+
ModelRegistryBase,
24+
ModelSetup,
25+
MODEL_KEY_TO_MODEL_SETUP,
26+
RandomModelBridge,
27+
Cont_X_trans,
28+
)
1929
from ax.utils.notebook.plotting import plot_config_to_html, render
2030
from ax.plot.pareto_frontier import plot_pareto_frontier
2131
from ax.utils.report.render import render_report_elements
2232
from typing import Any, Dict, NamedTuple, Union, Iterable, Set, List
2333
from ax.utils.common.result import Ok, Err
2434
from ax.storage.metric_registry import register_metric
2535
from ax.storage.runner_registry import register_runner
26-
from ax.storage.json_store.registry import CORE_ENCODER_REGISTRY, CORE_DECODER_REGISTRY, CORE_CLASS_DECODER_REGISTRY
36+
from ax.storage.json_store.registry import (
37+
CORE_ENCODER_REGISTRY,
38+
CORE_DECODER_REGISTRY,
39+
CORE_CLASS_DECODER_REGISTRY,
40+
)
2741
from ax.storage.json_store.encoders import metric_to_dict
2842
from ax.storage.json_store.encoders import runner_to_dict
2943

@@ -91,6 +105,63 @@ def plot_frontier(frontier,name, CI_level=0.9):
91105
))
92106
render(plot_config)
93107

108+
class ManualGenerator(RandomModel):
109+
"""
110+
Class to generate trials manually within a Strategy Step
111+
"""
112+
def __init__(self, parameter_sets: ListConfig, search: SearchSpace):
113+
self.index = 0
114+
self.parameter_sets = parameter_sets
115+
self.search_space = search
116+
self.deduplicate = False
117+
self.generated_points = None
118+
119+
def gen(
120+
self,
121+
n: int,
122+
bounds = None,
123+
linear_constraints = None,
124+
fixed_features = None,
125+
model_gen_options = None,
126+
rounding_func = None,
127+
):
128+
params = self.search_space.parameters
129+
if self.index + n > len(self.parameter_sets):
130+
raise ValueError(
131+
f"Not enough parameter sets available. Requested: {n}, Available: {len(self.parameter_sets) - self.index}."
132+
)
133+
generated = self.parameter_sets[self.index:self.index + n]
134+
points = []
135+
for i in range(len(generated)):
136+
param_bounds = []
137+
for k in generated[i].keys():
138+
point = None
139+
if isinstance(params[k], RangeParameter):
140+
point = (generated[i][k] - params[k].lower) / (params[k].upper-params[k].lower)
141+
if isinstance(params[k], ChoiceParameter):
142+
id = params[k].values.index(generated[i][k])
143+
point = id / len(params[k].values)
144+
if isinstance(params[k], FixedParameter):
145+
point = 1.0
146+
param_bounds.append(point)
147+
points.append(param_bounds)
148+
points = np.array(points)
149+
self.index += n
150+
self.generated_points = points
151+
return points, np.ones(len(points))
152+
153+
MODEL_KEY_TO_MODEL_SETUP["MANUAL"] = ModelSetup(
154+
bridge_class=RandomModelBridge,
155+
model_class=ManualGenerator,
156+
transforms=Cont_X_trans,
157+
)
158+
159+
class CustomModels(ModelRegistryBase):
160+
"""
161+
Register custom generators and models
162+
"""
163+
Manual = "MANUAL"
164+
94165
class HPCJob(NamedTuple):
95166
"""
96167
An async OpenFOAM job scheduled on an HPC system.

0 commit comments

Comments
 (0)