Skip to content

Commit

Permalink
Allow avoiding saving the cost values and constraint multipliers to s…
Browse files Browse the repository at this point in the history
…ave time in the callback
  • Loading branch information
S-Dafarra committed Jan 15, 2024
1 parent f50cec7 commit fd83021
Showing 1 changed file with 42 additions and 16 deletions.
58 changes: 42 additions & 16 deletions src/hippopt/base/opti_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ class OptiSolver(OptimizationSolver):
default=None
)
_callback: SaveBestUnsolvedVariablesCallback = dataclasses.field(default=None)
_callback_save_costs: bool = dataclasses.field(default=True)
_callback_save_constraint_multipliers: bool = dataclasses.field(default=True)
callback_save_costs: dataclasses.InitVar[bool] = dataclasses.field(default=None)
callback_save_constraint_multipliers: dataclasses.InitVar[bool] = dataclasses.field(
default=None
)

_cost: cs.MX = dataclasses.field(default=None)
_cost_expressions: dict[str, cs.MX] = dataclasses.field(default=None)
Expand Down Expand Up @@ -97,6 +103,8 @@ def __post_init__(
options_solver: dict[str, Any] = None,
options_plugin: dict[str, Any] = None,
callback_criterion: CallbackCriterion = None,
callback_save_costs: bool = True,
callback_save_constraint_multipliers: bool = True,
):
self._solver = cs.Opti(problem_type)
self._inner_solver = (
Expand All @@ -112,6 +120,10 @@ def __post_init__(
self._inner_solver, self._options_plugin, self._options_solver
)
self._callback_criterion = callback_criterion
self._callback_save_costs = callback_save_costs
self._callback_save_constraint_multipliers = (
callback_save_constraint_multipliers
)
self._cost_expressions = {}
self._constraint_expressions = {}
self._objects_type_map = {}
Expand Down Expand Up @@ -658,8 +670,12 @@ def solve(self) -> None:
criterion=self._callback_criterion,
opti=self._solver,
optimization_objects=list(self._objects_type_map.keys()),
costs=list(self._cost_expressions.values()),
constraints=list(self._constraint_expressions.values()),
costs=list(self._cost_expressions.values())
if self._callback_save_costs
else [],
constraints=list(self._constraint_expressions.values())
if self._callback_save_constraint_multipliers
else [],
)
self._solver.callback(self._callback)
try:
Expand All @@ -676,22 +692,32 @@ def solve(self) -> None:
variables=self._variables,
input_solution=self._callback.best_objects,
)
self._cost_values = {
name: float(
self._callback.best_cost_values[self._cost_expressions[name]]
)
for name in self._cost_expressions
}
self._constraint_values = {
name: np.array(
(
self._callback.best_constraint_multipliers[
self._constraint_expressions[name]
self._cost_values = (
{
name: float(
self._callback.best_cost_values[
self._cost_expressions[name]
]
)
)
for name in self._constraint_expressions
}
for name in self._cost_expressions
}
if self._callback_save_costs
else {}
)
self._constraint_values = (
{
name: np.array(
(
self._callback.best_constraint_multipliers[
self._constraint_expressions[name]
]
)
)
for name in self._constraint_expressions
}
if self._callback_save_constraint_multipliers
else {}
)
return

raise OptiFailure(message=err, callback_used=use_callback)
Expand Down

0 comments on commit fd83021

Please sign in to comment.