diff --git a/src/hippopt/base/opti_callback.py b/src/hippopt/base/opti_callback.py index 0481760c..3cfc2576 100644 --- a/src/hippopt/base/opti_callback.py +++ b/src/hippopt/base/opti_callback.py @@ -302,7 +302,8 @@ def __init__( self, criterion: CallbackCriterion, opti: cs.Opti, - optimization_objects: list[cs.MX], + variables: list[cs.MX], + parameters: list[cs.MX], costs: list[cs.MX], constraints: list[cs.MX], ) -> None: @@ -315,7 +316,8 @@ def __init__( # so the weakref is to avoid circular references self.opti = weakref.proxy(opti) self.criterion.set_opti(opti) - self.optimization_objects = optimization_objects + self.variables = variables + self.parameters = parameters self.cost = costs self.constraints = constraints @@ -324,7 +326,7 @@ def __init__( self.best_cost = None self.best_cost_values = {} self.best_constraint_multipliers = {} - self.ignore_map = {obj: False for obj in self.optimization_objects} + self.ignore_map = {obj: False for obj in self.variables + self.parameters} def call(self, i: int) -> None: """""" @@ -337,16 +339,18 @@ def call(self, i: int) -> None: self.best_iteration = i self.best_cost = self.opti.debug.value(self.opti.f) - self.best_objects = {} - for optimization_object in self.optimization_objects: - if self.ignore_map[optimization_object]: + for variable in self.variables: + if self.ignore_map[variable]: continue try: - self.best_objects[optimization_object] = self.opti.debug.value( - optimization_object - ) + self.best_objects[variable] = self.opti.debug.value(variable) except Exception as err: # noqa - self.ignore_map[optimization_object] = True + self.ignore_map[variable] = True + for parameter in self.parameters: + if self.ignore_map[parameter]: + continue + self.best_objects[parameter] = self.opti.debug.value(parameter) + self.ignore_map[parameter] = True # Parameters are saved only once self.best_cost_values = { cost: self.opti.debug.value(cost) for cost in self.cost diff --git a/src/hippopt/base/opti_solver.py b/src/hippopt/base/opti_solver.py index 955496e1..715c974a 100644 --- a/src/hippopt/base/opti_solver.py +++ b/src/hippopt/base/opti_solver.py @@ -666,10 +666,19 @@ def solve(self) -> None: ) use_callback = self._callback_criterion is not None if use_callback: + variables = [] + parameters = [] + for obj in self._objects_type_map: + if self._objects_type_map[obj] is Variable.StorageTypeValue: + variables.append(obj) + elif self._objects_type_map[obj] is Parameter.StorageTypeValue: + parameters.append(obj) + self._callback = SaveBestUnsolvedVariablesCallback( criterion=self._callback_criterion, opti=self._solver, - optimization_objects=list(self._objects_type_map.keys()), + variables=variables, + parameters=parameters, costs=list(self._cost_expressions.values()) if self._callback_save_costs else [],