diff --git a/cooper/formulation/augmented_lagrangian.py b/cooper/formulation/augmented_lagrangian.py index a6746e33..95971701 100644 --- a/cooper/formulation/augmented_lagrangian.py +++ b/cooper/formulation/augmented_lagrangian.py @@ -25,7 +25,7 @@ class AugmentedLagrangianFormulation(LagrangianFormulation): def __init__( self, - cmp: ConstrainedMinimizationProblem, + cmp: Optional[ConstrainedMinimizationProblem] = None, ineq_init: Optional[torch.Tensor] = None, eq_init: Optional[torch.Tensor] = None, ): @@ -104,7 +104,7 @@ def composite_objective( closure: Callable[..., CMPState] = None, *closure_args, pre_computed_state: Optional[CMPState] = None, - write_state: bool = True, + write_state: Optional[bool] = True, **closure_kwargs ) -> torch.Tensor: """ @@ -146,8 +146,8 @@ def composite_objective( else: cmp_state = closure(*closure_args, **closure_kwargs) - if write_state: - self.cmp.state = cmp_state + if write_state and self.cmp is not None: + self.write_cmp_state(cmp_state) # Extract values from ProblemState object loss = cmp_state.loss diff --git a/cooper/formulation/formulation.py b/cooper/formulation/formulation.py index 553f9b16..4e08aeea 100644 --- a/cooper/formulation/formulation.py +++ b/cooper/formulation/formulation.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional import torch @@ -13,8 +13,8 @@ class Formulation(abc.ABC): """Base class for formulations of CMPs.""" - def __init__(self): - self.cmp = None + def __init__(self, cmp: Optional[ConstrainedMinimizationProblem] = None): + self.cmp = cmp @abc.abstractmethod def state_dict(self) -> Dict[str, Any]: @@ -67,6 +67,18 @@ def custom_backward(self, *args, **kwargs): """ self._populate_gradients(*args, **kwargs) + def write_cmp_state(self, cmp_state: CMPState): + """Provided that the formulation is linked to a + `ConstrainedMinimizationProblem`, writes a CMPState to the CMP.""" + + if self.cmp is None: + raise RuntimeError( + """Cannot write state to a formulation which is not linked to a + ConstrainedMinimizationProblem""" + ) + + self.cmp.state = cmp_state + class UnconstrainedFormulation(Formulation): """ @@ -77,7 +89,7 @@ class UnconstrainedFormulation(Formulation): to solve and which gives rise to the Lagrangian. """ - def __init__(self, cmp: ConstrainedMinimizationProblem): + def __init__(self, cmp: Optional[ConstrainedMinimizationProblem] = None): """Construct new `UnconstrainedFormulation`""" self.cmp = cmp @@ -119,7 +131,7 @@ def composite_objective( self, closure: Callable[..., CMPState], *closure_args, - write_state: bool = True, + write_state: Optional[bool] = True, **closure_kwargs ) -> torch.Tensor: """ @@ -138,8 +150,9 @@ def composite_objective( """ cmp_state = closure(*closure_args, **closure_kwargs) - if write_state: - self.cmp.state = cmp_state + + if write_state and self.cmp is not None: + self.write_cmp_state(cmp_state) return cmp_state.loss diff --git a/cooper/formulation/lagrangian.py b/cooper/formulation/lagrangian.py index 0833aac9..52458379 100644 --- a/cooper/formulation/lagrangian.py +++ b/cooper/formulation/lagrangian.py @@ -28,7 +28,7 @@ class BaseLagrangianFormulation(Formulation, metaclass=abc.ABCMeta): def __init__( self, - cmp: ConstrainedMinimizationProblem, + cmp: Optional[ConstrainedMinimizationProblem] = None, ineq_init: Optional[torch.Tensor] = None, eq_init: Optional[torch.Tensor] = None, ): @@ -240,7 +240,7 @@ def composite_objective( closure: Callable[..., CMPState] = None, *closure_args, pre_computed_state: Optional[CMPState] = None, - write_state: bool = True, + write_state: Optional[bool] = True, **closure_kwargs ) -> torch.Tensor: """ @@ -281,8 +281,8 @@ def composite_objective( else: cmp_state = closure(*closure_args, **closure_kwargs) - if write_state: - self.cmp.state = cmp_state + if write_state and self.cmp is not None: + self.write_cmp_state(cmp_state) # Extract values from ProblemState object loss = cmp_state.loss diff --git a/cooper/optim/constrained_optimizers/alternating_optimizer.py b/cooper/optim/constrained_optimizers/alternating_optimizer.py index a44bd5d2..d9d457a0 100644 --- a/cooper/optim/constrained_optimizers/alternating_optimizer.py +++ b/cooper/optim/constrained_optimizers/alternating_optimizer.py @@ -23,7 +23,6 @@ def __init__( dual_restarts: bool = False, ): self.formulation = formulation - self.cmp = self.formulation.cmp if isinstance(primal_optimizers, torch.optim.Optimizer): self.primal_optimizers = [primal_optimizers] diff --git a/cooper/optim/constrained_optimizers/extrapolation_optimizer.py b/cooper/optim/constrained_optimizers/extrapolation_optimizer.py index 4aede18d..1fa472cd 100644 --- a/cooper/optim/constrained_optimizers/extrapolation_optimizer.py +++ b/cooper/optim/constrained_optimizers/extrapolation_optimizer.py @@ -56,7 +56,6 @@ def __init__( dual_restarts: bool = False, ): self.formulation = formulation - self.cmp = self.formulation.cmp if isinstance(primal_optimizers, ExtragradientOptimizer): self.primal_optimizers = [primal_optimizers] diff --git a/cooper/optim/constrained_optimizers/simultaneous_optimizer.py b/cooper/optim/constrained_optimizers/simultaneous_optimizer.py index 62269f4a..04577ec5 100644 --- a/cooper/optim/constrained_optimizers/simultaneous_optimizer.py +++ b/cooper/optim/constrained_optimizers/simultaneous_optimizer.py @@ -55,7 +55,6 @@ def __init__( dual_restarts: bool = False, ): self.formulation = formulation - self.cmp = self.formulation.cmp if isinstance(primal_optimizers, torch.optim.Optimizer): self.primal_optimizers = [primal_optimizers] diff --git a/cooper/optim/unconstrained_optimizer.py b/cooper/optim/unconstrained_optimizer.py index d5d1815a..1dc0dfb2 100644 --- a/cooper/optim/unconstrained_optimizer.py +++ b/cooper/optim/unconstrained_optimizer.py @@ -36,7 +36,6 @@ def __init__( ) self.formulation = formulation - self.cmp = self.formulation.cmp if isinstance(primal_optimizers, torch.optim.Optimizer): self.primal_optimizers = [primal_optimizers] diff --git a/tests/test_simplest_pipeline.py b/tests/test_simplest_pipeline.py new file mode 100644 index 00000000..9a88499e --- /dev/null +++ b/tests/test_simplest_pipeline.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +"""Tests for Constrained Optimizer class. This test already verifies that the +code behaves as expected for an unconstrained setting.""" + +import pytest +import torch + +import cooper + + +@pytest.fixture() +def params(): + return torch.nn.Parameter(torch.tensor([0.0, -1.0])) + + +@pytest.fixture() +def formulation(): + return cooper.LagrangianFormulation() + + +@pytest.fixture() +def constrained_optimizer(params, formulation): + primal_optim = torch.optim.SGD([params], lr=1e-2, momentum=0.3) + dual_optim = cooper.optim.partial_optimizer(torch.optim.SGD, lr=1e-2) + + return cooper.SimultaneousConstrainedOptimizer( + formulation, primal_optim, dual_optim, dual_restarts=True + ) + + +def loss_fn(params): + param_x, param_y = params + + return param_x**2 + 2 * param_y**2 + + +def defect_fn(params): + + param_x, param_y = params + + # Two inequality constraints + defect = torch.stack( + [ + -param_x - param_y + 1.0, # x + y \ge 1 + param_x**2 + param_y - 1.0, # x**2 + y \le 1.0 + ] + ) + + return defect + + +def test_simplest_pipeline(params, formulation, constrained_optimizer): + + for step_id in range(1500): + constrained_optimizer.zero_grad() + + loss = loss_fn(params) + defect = defect_fn(params) + + # Create a CMPState object to hold the loss and defect values + cmp_state = cooper.CMPState(loss=loss, ineq_defect=defect) + + lagrangian = formulation.composite_objective(pre_computed_state=cmp_state) + formulation.custom_backward(lagrangian) + + constrained_optimizer.step() + + assert torch.allclose(params[0], torch.tensor(2.0 / 3.0)) + assert torch.allclose(params[1], torch.tensor(1.0 / 3.0))