diff --git a/src/hippopt/base/opti_callback.py b/src/hippopt/base/opti_callback.py index 3cfc2576..d4c1aa17 100644 --- a/src/hippopt/base/opti_callback.py +++ b/src/hippopt/base/opti_callback.py @@ -27,6 +27,7 @@ class CallbackCriterion(abc.ABC): def __init__(self) -> None: """""" self.opti = None + self.opti_debug = None @abc.abstractmethod def satisfied(self) -> bool: @@ -69,6 +70,10 @@ def set_opti(self, opti: cs.Opti) -> None: self.opti = weakref.proxy(opti) self.reset() + def update_opti_debug(self, opti_debug: cs.OptiAdvanced) -> None: + """""" + self.opti_debug = opti_debug + class BestCost(CallbackCriterion): """""" @@ -107,7 +112,7 @@ def update(self) -> None: def _get_current_cost(self) -> float: """""" - return self.opti.debug.value(self.opti.f) + return self.opti_debug.value(self.opti.f) class AcceptableCost(CallbackCriterion): @@ -151,7 +156,7 @@ def update(self) -> None: def _get_current_cost(self) -> float: """""" - return self.opti.debug.value(self.opti.f) + return self.opti_debug.value(self.opti.f) class AcceptablePrimalInfeasibility(CallbackCriterion): @@ -199,7 +204,7 @@ def update(self) -> None: def _get_current_primal_infeasibility(self) -> float: """""" - return self.opti.debug.stats()["iterations"]["inf_pr"][-1] + return self.opti_debug.stats()["iterations"]["inf_pr"][-1] class BestPrimalInfeasibility(CallbackCriterion): @@ -240,7 +245,7 @@ def update(self) -> None: def _get_current_primal_infeasibility(self) -> float: """""" - return self.opti.debug.stats()["iterations"]["inf_pr"][-1] + return self.opti_debug.stats()["iterations"]["inf_pr"][-1] class CombinedCallbackCriterion(CallbackCriterion, abc.ABC): @@ -274,6 +279,13 @@ def set_opti(self, opti: cs.Opti) -> None: self.lhs.set_opti(opti) self.rhs.set_opti(opti) + @final + def update_opti_debug(self, opti_debug: cs.OptiAdvanced) -> None: + """""" + + self.lhs.update_opti_debug(opti_debug) + self.rhs.update_opti_debug(opti_debug) + class OrCombinedCallbackCriterion(CombinedCallbackCriterion): """""" @@ -330,6 +342,8 @@ def __init__( def call(self, i: int) -> None: """""" + opti_debug = self.opti.debug + self.criterion.update_opti_debug(opti_debug) if self.criterion.satisfied(): self.criterion.update() @@ -338,24 +352,22 @@ def call(self, i: int) -> None: _logger.info(f"[i={i}] New best intermediate variables") self.best_iteration = i - self.best_cost = self.opti.debug.value(self.opti.f) + self.best_cost = opti_debug.value(self.opti.f) for variable in self.variables: if self.ignore_map[variable]: continue try: - self.best_objects[variable] = self.opti.debug.value(variable) + self.best_objects[variable] = opti_debug.value(variable) except Exception as err: # noqa self.ignore_map[variable] = True for parameter in self.parameters: if self.ignore_map[parameter]: continue - self.best_objects[parameter] = self.opti.debug.value(parameter) + self.best_objects[parameter] = opti_debug.value(parameter) self.ignore_map[parameter] = True # Parameters are saved only once - self.best_cost_values = { - cost: self.opti.debug.value(cost) for cost in self.cost - } + self.best_cost_values = {cost: opti_debug.value(cost) for cost in self.cost} self.best_constraint_multipliers = { - constraint: self.opti.debug.value(self.opti.dual(constraint)) + constraint: opti_debug.value(self.opti.dual(constraint)) for constraint in self.constraints }