Skip to content

Commit

Permalink
Added to_function in opti solver
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Dafarra committed Apr 11, 2024
1 parent 9991cbe commit d807aac
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def __init__(self, message: Exception):
)


class ToFunctionSolveNotCalled(Exception):
def __init__(self):
super().__init__(
"The solve() method must be called before converting to function."
)


@dataclasses.dataclass
class OptiSolver(OptimizationSolver):
DefaultSolverType: ClassVar[str] = "ipopt"
Expand Down Expand Up @@ -580,3 +587,24 @@ def get_object_type(self, obj: cs.MX) -> str:

def get_free_parameters_names(self) -> list[str]:
return self._free_parameters

def to_function(
self, name: str = "opti_function", options: dict = None
) -> cs.Function:
if self._output_solution is None:
# to_function does not seem to work without calling solve() first
raise ToFunctionSolveNotCalled

variables_names = list(self._objects_dict.keys())
# Prepend guess to the variable names
guess_names = ["guess." + name for name in variables_names]
variables_values = list(self._objects_dict.values())
options = {} if options is None else options
return self._solver.to_function(
name,
variables_values,
variables_values,
guess_names,
variables_names,
options,
)
49 changes: 49 additions & 0 deletions test/test_optimization_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,52 @@ def test_opti_callback():
problem.add_constraint(variables.x**2 == 10)

problem.solve()


def test_opti_to_function():
opti_solver = OptiSolver()
problem, var = OptimizationProblem.create(
input_structure=MyTestVarAndPar(), optimization_solver=opti_solver
)
initial_guess = MyTestVarAndPar()
np.random.seed(123)
a = 10.0 * np.random.rand(3, 1) + 0.01
b = 20.0 * np.random.rand(3, 1) - 10.0
c = 20.0 * np.random.rand(3, 1) - 10.0

initial_guess.parameter = c

problem.add_expression(
mode=ExpressionType.minimize,
expression=(
a[k] * cs.power(var.composite.variable[k], 2)
+ b[k] * var.composite.variable[k]
for k in range(3)
),
)

problem.add_expression(
mode=ExpressionType.subject_to,
expression=( # noqa
var.composite.variable[k] >= var.parameter[k] for k in range(3)
),
)

problem.solver().set_initial_guess(initial_guess=initial_guess)
problem.solve()

c = 20.0 * np.random.rand(3, 1) - 10.0
initial_guess.parameter = c

opti_function = opti_solver.to_function()
output_dict = opti_function(**initial_guess.to_dict(prefix="guess."))
output = MyTestVarAndPar()
output.from_dict(output_dict)

expected_x = np.zeros((3, 1))
for i in range(3):
expected = -b[i] / (2 * a[i])
expected_x[i] = expected if expected >= c[i] else c[i]

assert output.composite.variable == pytest.approx(expected_x)
assert output.parameter == pytest.approx(c)

0 comments on commit d807aac

Please sign in to comment.