Skip to content

Commit

Permalink
Added possibility to set callback in opti to have intermediate soluti…
Browse files Browse the repository at this point in the history
…ons.
  • Loading branch information
S-Dafarra committed Jan 8, 2024
1 parent 81d69c5 commit 856df1f
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/hippopt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hippopt.base.opti_callback as opti_callback

from .base.dynamics import Dynamics, TypedDynamics, dot
from .base.multiple_shooting_solver import MultipleShootingSolver
from .base.opti_solver import OptiFailure, OptiSolver
Expand Down
1 change: 1 addition & 0 deletions src/hippopt/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from . import (
dynamics,
multiple_shooting_solver,
opti_callback,
opti_solver,
optimal_control_problem,
optimal_control_solver,
Expand Down
25 changes: 17 additions & 8 deletions src/hippopt/base/opti_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def set_opti(self, opti: cs.Opti) -> None:
# In theory, the callback is included in opti,
# so the weakref is to avoid circular references
self.opti = weakref.proxy(opti)
self.reset()


class BestCost(CallbackCriterion):
Expand Down Expand Up @@ -242,7 +243,7 @@ def _get_current_primal_infeasibility(self) -> float:
return self.opti.debug.stats()["iterations"]["inf_pr"][-1]


class CombinedCallbackCriterion(abc.ABC, CallbackCriterion):
class CombinedCallbackCriterion(CallbackCriterion, abc.ABC):
""""""

def __init__(self, lhs: CallbackCriterion, rhs: CallbackCriterion) -> None:
Expand Down Expand Up @@ -318,11 +319,12 @@ def __init__(
self.cost = costs
self.constraints = constraints

self.best_stats = None
self.best_variables = {}
self.best_iteration = None
self.best_objects = {}
self.best_cost = None
self.best_cost_values = {}
self.best_constraint_multipliers = {}
self.ignore_map = {obj: False for obj in self.optimization_objects}

def call(self, i: int) -> None:
""""""
Expand All @@ -333,12 +335,19 @@ def call(self, i: int) -> None:
_logger = logging.getLogger(f"[hippopt::{self.__class__.__name__}]")
_logger.info(f"[i={i}] New best intermediate variables")

self.best_stats = self.opti.debug.stats()
self.best_iteration = i
self.best_cost = self.opti.debug.value(self.opti.f)
self.best_variables = {
optimization_object: self.opti.debug.value(optimization_object)
for optimization_object in self.optimization_objects
}
self.best_objects = {}
for optimization_object in self.optimization_objects:
if self.ignore_map[optimization_object]:
continue
try:
self.best_objects[optimization_object] = self.opti.debug.value(
optimization_object
)
except Exception as err: # noqa
self.ignore_map[optimization_object] = True

self.best_cost_values = {
cost: self.opti.debug.value(cost) for cost in self.cost
}
Expand Down
63 changes: 60 additions & 3 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import casadi as cs
import numpy as np

from hippopt.base.opti_callback import (
CallbackCriterion,
SaveBestUnsolvedVariablesCallback,
)
from hippopt.base.optimization_object import (
OptimizationObject,
StorageType,
Expand All @@ -22,8 +26,15 @@


class OptiFailure(Exception):
def __init__(self, message: Exception):
super().__init__("Opti failed to solve the problem. Message: " + str(message))
def __init__(self, message: Exception, callback_used: bool):
callback_info = ""
if callback_used:
callback_info = (
" and the callback did not manage to save an intermediate solution"
)
super().__init__(
f"Opti failed to solve the problem{callback_info}. Message: {str(message)}"
)


class InitialGuessFailure(Exception):
Expand All @@ -50,6 +61,11 @@ class OptiSolver(OptimizationSolver):
options_plugin: dataclasses.InitVar[dict[str, Any]] = dataclasses.field(
default=None
)
_callback_criterion: CallbackCriterion = dataclasses.field(default=None)
callback_criterion: dataclasses.InitVar[CallbackCriterion] = dataclasses.field(
default=None
)
_callback: SaveBestUnsolvedVariablesCallback = dataclasses.field(default=None)

_cost: cs.MX = dataclasses.field(default=None)
_cost_expressions: dict[str, cs.MX] = dataclasses.field(default=None)
Expand Down Expand Up @@ -80,6 +96,7 @@ def __post_init__(
problem_type: str = "nlp",
options_solver: dict[str, Any] = None,
options_plugin: dict[str, Any] = None,
callback_criterion: CallbackCriterion = None,
):
self._solver = cs.Opti(problem_type)
self._inner_solver = (
Expand All @@ -94,6 +111,7 @@ def __post_init__(
self._solver.solver(
self._inner_solver, self._options_plugin, self._options_solver
)
self._callback_criterion = callback_criterion
self._cost_expressions = {}
self._constraint_expressions = {}
self._objects_type_map = {}
Expand Down Expand Up @@ -634,10 +652,49 @@ def solve(self) -> None:
raise ValueError(
"The following parameters are not set: " + str(self._free_parameters)
)
use_callback = self._callback_criterion is not None
if use_callback:
self._callback = SaveBestUnsolvedVariablesCallback(
criterion=self._callback_criterion,
opti=self._solver,
optimization_objects=list(self._objects_type_map.keys()),
costs=list(self._cost_expressions.values()),
constraints=list(self._constraint_expressions.values()),
)
self._solver.callback(self._callback)
try:
opti_solution = self._solver.solve()
except Exception as err: # noqa
raise OptiFailure(message=err)
if use_callback and self._callback.best_iteration is not None:
self._logger.warning(
"Opti failed to solve the problem, but the callback managed to save"
" an intermediate solution at "
f"iteration {self._callback.best_iteration}."
)
self._output_cost = self._callback.best_cost
self._output_solution = self._generate_solution_output(
variables=self._variables,
input_solution=self._callback.best_objects,
)
self._cost_values = {
name: float(
self._callback.best_cost_values[self._cost_expressions[name]]
)
for name in self._cost_expressions
}
self._constraint_values = {
name: np.array(
(
self._callback.best_constraint_multipliers[
self._constraint_expressions[name]
]
)
)
for name in self._constraint_expressions
}
return

raise OptiFailure(message=err, callback_used=use_callback)

self._output_cost = opti_solution.value(self._cost)
self._output_solution = self._generate_solution_output(
Expand Down
18 changes: 18 additions & 0 deletions test/test_optimization_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
OptiFailure,
OptimizationObject,
OptimizationProblem,
OptiSolver,
Parameter,
StorageType,
Variable,
default_storage_field,
opti_callback,
)


Expand Down Expand Up @@ -263,3 +265,19 @@ def test_opti_failure():
print("Received error: ", err)
else:
assert False


def test_opti_callback():
opti_solver = OptiSolver(
callback_criterion=opti_callback.BestCost()
| opti_callback.BestPrimalInfeasibility()
)
problem, variables = OptimizationProblem.create(
input_structure=SwitchVar(), optimization_solver=opti_solver
)

problem.add_constraint(variables.x <= 1)
problem.add_constraint(variables.x >= 0)
problem.add_constraint(variables.x**2 == 10)

problem.solve()

0 comments on commit 856df1f

Please sign in to comment.