|
9 | 9 | import pandas as pd |
10 | 10 | import numpy as np |
11 | 11 |
|
12 | | -from omegaconf import OmegaConf, DictConfig, DictKeyType |
| 12 | +from omegaconf import OmegaConf, DictConfig, DictKeyType, ListConfig |
13 | 13 | from ax.core.base_trial import TrialStatus, BaseTrial |
14 | 14 | from ax.core.trial import Trial |
15 | 15 | from ax.core.runner import Runner |
16 | 16 | from ax.core.metric import Metric, MetricFetchResult, MetricFetchE |
17 | 17 | from ax.core.data import Data |
| 18 | +from ax.core.search_space import SearchSpace |
| 19 | +from ax.core.parameter import RangeParameter, ChoiceParameter, FixedParameter |
18 | 20 | from ax.service.utils.instantiation import ObjectiveProperties |
| 21 | +from ax.models.random.base import RandomModel |
| 22 | +from ax.modelbridge.registry import ( |
| 23 | + ModelRegistryBase, |
| 24 | + ModelSetup, |
| 25 | + MODEL_KEY_TO_MODEL_SETUP, |
| 26 | + RandomModelBridge, |
| 27 | + Cont_X_trans, |
| 28 | +) |
19 | 29 | from ax.utils.notebook.plotting import plot_config_to_html, render |
20 | 30 | from ax.plot.pareto_frontier import plot_pareto_frontier |
21 | 31 | from ax.utils.report.render import render_report_elements |
22 | 32 | from typing import Any, Dict, NamedTuple, Union, Iterable, Set, List |
23 | 33 | from ax.utils.common.result import Ok, Err |
24 | 34 | from ax.storage.metric_registry import register_metric |
25 | 35 | from ax.storage.runner_registry import register_runner |
26 | | -from ax.storage.json_store.registry import CORE_ENCODER_REGISTRY, CORE_DECODER_REGISTRY, CORE_CLASS_DECODER_REGISTRY |
| 36 | +from ax.storage.json_store.registry import ( |
| 37 | + CORE_ENCODER_REGISTRY, |
| 38 | + CORE_DECODER_REGISTRY, |
| 39 | + CORE_CLASS_DECODER_REGISTRY, |
| 40 | +) |
27 | 41 | from ax.storage.json_store.encoders import metric_to_dict |
28 | 42 | from ax.storage.json_store.encoders import runner_to_dict |
29 | 43 |
|
@@ -91,6 +105,63 @@ def plot_frontier(frontier,name, CI_level=0.9): |
91 | 105 | )) |
92 | 106 | render(plot_config) |
93 | 107 |
|
| 108 | +class ManualGenerator(RandomModel): |
| 109 | + """ |
| 110 | + Class to generate trials manually within a Strategy Step |
| 111 | + """ |
| 112 | + def __init__(self, parameter_sets: ListConfig, search: SearchSpace): |
| 113 | + self.index = 0 |
| 114 | + self.parameter_sets = parameter_sets |
| 115 | + self.search_space = search |
| 116 | + self.deduplicate = False |
| 117 | + self.generated_points = None |
| 118 | + |
| 119 | + def gen( |
| 120 | + self, |
| 121 | + n: int, |
| 122 | + bounds = None, |
| 123 | + linear_constraints = None, |
| 124 | + fixed_features = None, |
| 125 | + model_gen_options = None, |
| 126 | + rounding_func = None, |
| 127 | + ): |
| 128 | + params = self.search_space.parameters |
| 129 | + if self.index + n > len(self.parameter_sets): |
| 130 | + raise ValueError( |
| 131 | + f"Not enough parameter sets available. Requested: {n}, Available: {len(self.parameter_sets) - self.index}." |
| 132 | + ) |
| 133 | + generated = self.parameter_sets[self.index:self.index + n] |
| 134 | + points = [] |
| 135 | + for i in range(len(generated)): |
| 136 | + param_bounds = [] |
| 137 | + for k in generated[i].keys(): |
| 138 | + point = None |
| 139 | + if isinstance(params[k], RangeParameter): |
| 140 | + point = (generated[i][k] - params[k].lower) / (params[k].upper-params[k].lower) |
| 141 | + if isinstance(params[k], ChoiceParameter): |
| 142 | + id = params[k].values.index(generated[i][k]) |
| 143 | + point = id / len(params[k].values) |
| 144 | + if isinstance(params[k], FixedParameter): |
| 145 | + point = 1.0 |
| 146 | + param_bounds.append(point) |
| 147 | + points.append(param_bounds) |
| 148 | + points = np.array(points) |
| 149 | + self.index += n |
| 150 | + self.generated_points = points |
| 151 | + return points, np.ones(len(points)) |
| 152 | + |
| 153 | +MODEL_KEY_TO_MODEL_SETUP["MANUAL"] = ModelSetup( |
| 154 | + bridge_class=RandomModelBridge, |
| 155 | + model_class=ManualGenerator, |
| 156 | + transforms=Cont_X_trans, |
| 157 | +) |
| 158 | + |
| 159 | +class CustomModels(ModelRegistryBase): |
| 160 | + """ |
| 161 | + Register custom generators and models |
| 162 | + """ |
| 163 | + Manual = "MANUAL" |
| 164 | + |
94 | 165 | class HPCJob(NamedTuple): |
95 | 166 | """ |
96 | 167 | An async OpenFOAM job scheduled on an HPC system. |
|
0 commit comments