Refactor syntax for backward call and Lagrangian computation in Formulation (#59)

Closes #58

## Changes

* Removed `_populate_gradients` method from Formulation in favor of `backward`
* Removed `composite_objective` method from Formulation in favor of `compute_lagrangian` (see the usage sketch after this list)
* Applied corresponding refactoring to docs
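
For reference, a minimal sketch of one training step under the renamed API, adapted from the updated README snippet in this diff (the `formulation`, `constrained_optimizer`, `cmp`, and `probs` objects are assumed to be set up as in the README example):

```python
# Assumes `formulation`, `constrained_optimizer`, `cmp` (the constrained
# minimization problem), and `probs` (the primal parameters) already exist,
# as in the README example touched by this commit.
constrained_optimizer.zero_grad()

# Previously: formulation.composite_objective(cmp.closure, probs)
lagrangian = formulation.compute_lagrangian(cmp.closure, probs)

# Previously: formulation.custom_backward(lagrangian)
formulation.backward(lagrangian)

constrained_optimizer.step(cmp.closure, probs)
```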

Co-authored-by: juan43ramirez <[email protected]>
gallego-posada and juan43ramirez authored Nov 4, 2022
1 parent 62e384b commit 3e0aa6f
Showing 25 changed files with 76 additions and 84 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -83,8 +83,8 @@ constrained_optimizer = cooper.ExtrapolationConstrainedOptimizer(formulation, pr
# The steps follow closely the `loss -> backward -> step` Pytorch workflow.
for iter_num in range(5000):
constrained_optimizer.zero_grad()
- lagrangian = formulation.composite_objective(cmp.closure, probs)
- formulation.custom_backward(lagrangian)
+ lagrangian = formulation.compute_lagrangian(cmp.closure, probs)
+ formulation.backward(lagrangian)
constrained_optimizer.step(cmp.closure, probs)
```

8 changes: 4 additions & 4 deletions cooper/formulation/augmented_lagrangian.py
@@ -88,7 +88,7 @@ def weighted_violation(

# This is the violation of the "actual" constraint. We use this
# to update the value of the multipliers by lazily filling the
- # multiplier gradients in `populate_gradients`
+ # multiplier gradients in `backward`

# TODO (JGP): Verify that call to backward is general enough for
# Lagrange Multiplier models
@@ -98,7 +98,7 @@ def weighted_violation(
return proxy_violation, sq_proxy_violation

@no_type_check
- def composite_objective(
+ def compute_lagrangian(
self,
aug_lag_coeff_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
closure: Callable[..., CMPState] = None,
@@ -114,13 +114,13 @@ def composite_objective(
If no explicit proxy-constraints are provided, we use the given
inequality/equality constraints to compute the Augmented Lagrangian and
to populate the primal and dual gradients. Note that gradients are _not_
- populated by this function, but rather :py:meth:`._populate_gradient`.
+ populated by this function, but rather :py:meth:`.backward`.
In case proxy constraints are provided in the CMPState, the non-proxy
constraints (potentially non-differentiable) are used for computing the
value of the Augmented Lagrangian. The accumulated proxy-constraints
are used in the backward computation triggered by
- :py:meth:`._populate_gradient` (and thus must be differentiable).
+ :py:meth:`.backward` (and thus must be differentiable).
Args:
closure: Callable returning a :py:class:`cooper.problem.CMPState`
26 changes: 9 additions & 17 deletions cooper/formulation/formulation.py
@@ -49,24 +49,16 @@ def dual_parameters(self):
pass

@abc.abstractmethod
- def composite_objective(self):
+ def compute_lagrangian(self):
pass

@abc.abstractmethod
- def _populate_gradients(self, *args, **kwargs):
- """Performs the actual backward computation and populates the gradients
- for the trainable parameters for the dual variables."""
+ def backward(self, *args, **kwargs):
+ """Performs the backward computation and populates the gradients
+ for the primal and dual variables according to the design of the
+ formulation."""
pass

- def custom_backward(self, *args, **kwargs):
- """Alias for :py:meth:`._populate_gradients` to keep the ``backward``
- naming convention used in Pytorch. For clarity, we avoid naming this
- method ``backward`` as it is a method of the ``LagrangianFormulation``
- object and not a method of a :py:class:`torch.Tensor` as is standard in
- Pytorch.
- """
- self._populate_gradients(*args, **kwargs)

def write_cmp_state(self, cmp_state: CMPState):
"""Provided that the formulation is linked to a
`ConstrainedMinimizationProblem`, writes a CMPState to the CMP."""
@@ -127,7 +119,7 @@ def load_state_dict(self, state_dict: dict):
"""
pass

- def composite_objective(
+ def compute_lagrangian(
self,
closure: Callable[..., CMPState],
*closure_args,
@@ -156,10 +148,10 @@ def composite_objective(

return cmp_state.loss

- def _populate_gradients(self, loss: torch.Tensor):
+ def backward(self, loss: torch.Tensor):
"""
- Performs the actual backward computation which populates the gradients
- for the primal variables.
+ Performs the backward computation which populates the gradients for the
+ primal variables.
Args:
loss: Loss tensor for computing gradients for primal variables.
8 changes: 4 additions & 4 deletions cooper/formulation/lagrangian.py
@@ -235,7 +235,7 @@ class LagrangianFormulation(BaseLagrangianFormulation):
"""

@no_type_check
- def composite_objective(
+ def compute_lagrangian(
self,
closure: Callable[..., CMPState] = None,
*closure_args,
@@ -250,13 +250,13 @@ def composite_objective(
If no explicit proxy-constraints are provided, we use the given
inequality/equality constraints to compute the Lagrangian and to
populate the primal and dual gradients. Note that gradients are _not_
- populated by this function, but rather :py:meth:`._populate_gradient`.
+ populated by this function, but rather :py:meth:`.backward`.
In case proxy constraints are provided in the CMPState, the non-proxy
constraints (potentially non-differentiable) are used for computing the
value of the Lagrangian. The accumulated proxy-constraints are used in
the backward computation triggered by
- :py:meth:`._populate_gradient` (and thus must be differentiable).
+ :py:meth:`.backward` (and thus must be differentiable).
Args:
closure: Callable returning a :py:class:`cooper.problem.CMPState`
@@ -361,7 +361,7 @@ def weighted_violation(
return proxy_violation

@no_type_check
- def _populate_gradients(
+ def backward(
self,
lagrangian: torch.Tensor,
ignore_primal: bool = False,
2 changes: 1 addition & 1 deletion cooper/multipliers/multipliers.py
@@ -109,7 +109,7 @@ def restart_if_feasible_(self):

assert self.positive, "Restarts is only supported for inequality multipliers"

- # Call to formulation._populate_gradients has already flipped sign
+ # Call to formulation.backward has already flipped sign
# A currently *positive* gradient means original defect is negative, so
# the constraint is being satisfied.

6 changes: 3 additions & 3 deletions cooper/optim/constrained_optimizers/alternating_optimizer.py
@@ -108,14 +108,14 @@ def populate_alternating_dual_gradient(
if isinstance(self.formulation, AugmentedLagrangianFormulation):
# Use LR of dual optimizer as penalty coefficient for the augmented
# Lagrangian
- _ = self.formulation.composite_objective(
+ _ = self.formulation.compute_lagrangian(
closure=None,
aug_lag_coeff_scheduler=self.dual_scheduler,
pre_computed_state=alternate_cmp_state,
write_state=True,
) # type: ignore
else:
- _ = self.formulation.composite_objective(
+ _ = self.formulation.compute_lagrangian(
closure=None, pre_computed_state=alternate_cmp_state, write_state=True
) # type: ignore
# Zero-out gradients for dual variables since they were already
@@ -125,7 +125,7 @@ def populate_alternating_dual_gradient(

# Not passing lagrangian since we only want to update the gradients for
# the dual variables
- self.formulation._populate_gradients(
+ self.formulation.backward(
lagrangian=None, ignore_primal=True, ignore_dual=False
)

@@ -169,12 +169,12 @@ def step(
# For extrapolation, we need closure args here as the parameter
# values will have changed in the update applied on the
# extrapolation step
- lagrangian = self.formulation.composite_objective(
+ lagrangian = self.formulation.compute_lagrangian(
closure, *closure_args, **closure_kwargs
) # type: ignore

# Populate gradients at extrapolation point
- self.formulation.custom_backward(lagrangian)
+ self.formulation.backward(lagrangian)

# After this, the calls to `step` will update the stored copies with
# the newly computed gradients
8 changes: 4 additions & 4 deletions docs/source/additional_features.rst
@@ -18,7 +18,7 @@ Alternating updates
It is possible to perform alternating updates between the primal and dual
parameters by setting the flag ``alternating=True`` in the construction of the
:py:class:`ConstrainedOptimizer`. In this case, the gradient computed by calling
- :py:meth:`~cooper.formulation.Formulation.custom_backward` is used to update the
+ :py:meth:`~cooper.formulation.Formulation.backward` is used to update the
primal parameters. Then, the gradient with respect to the dual variables (given
the new value of the primal parameters!) is computed and used to update the dual
variables. This two-stage process is handled by **Cooper** inside the
@@ -157,12 +157,12 @@ Example
for step_id in range(1000):
coop.zero_grad()
- lagrangian = formulation.composite_objective(
+ lagrangian = formulation.compute_lagrangian(
aug_lag_coeff_scheduler=coop.dual_scheduler,
closure=cmp.closure,
params=params,
)
- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)
# We need to pass the closure or defect_fn to perform the alternating updates
# required by the Augmented Lagrangian method.
@@ -241,4 +241,4 @@ treated "as if they were a single optimizer". In particular, all primal optimize
operations such as :py:meth:`optimizer.step()<torch.optim.Optimizer.step>` are
executed simultaneously (without intermediate calls to
:py:meth:`cmp.closure()<cooper.problem.ConstrainedMinimizationProblem.closure>` or
- :py:meth:`formulation.custom_backward(lagrangian)<cooper.formulation.Formulation.custom_backward>`).
+ :py:meth:`formulation.backward(lagrangian)<cooper.formulation.Formulation.backward>`).
8 changes: 4 additions & 4 deletions docs/source/constrained_optimizer.rst
@@ -110,8 +110,8 @@ will involve the following steps:

#. (Optional) Iterate over your dataset and sample of mini-batch.
#. Call :py:meth:`constrained_optimizer.zero_grad()<zero_grad>` to reset the parameters' gradients
- #. Compute the current :py:class:`CMPState` (or estimate it with the minibatch) and calculate the Lagrangian using :py:meth:`lagrangian.composite_objective(cmp.closure, ...)<cooper.formulation.LagrangianFormulation.composite_objective>`.
- #. Populate the primal and dual gradients with :py:meth:`formulation.custom_backward(lagrangian)<cooper.formulation.Formulation.custom_backward>`
+ #. Compute the current :py:class:`CMPState` (or estimate it with the minibatch) and calculate the Lagrangian using :py:meth:`formulation.compute_lagrangian(cmp.closure, ...)<cooper.formulation.LagrangianFormulation.compute_lagrangian>`.
+ #. Populate the primal and dual gradients with :py:meth:`formulation.backward(lagrangian)<cooper.formulation.Formulation.backward>`
#. Perform updates on the parameters using the primal and dual optimizers based on the recently computed gradients, via a call to :py:meth:`constrained_optimizer.step()<step>`.

Example
@@ -140,10 +140,10 @@ Example
# The closure is required to compute the Lagrangian
# The closure might in turn require the model, inputs, targets, etc.
- lagrangian = formulation.composite_objective(cmp.closure, ...)
+ lagrangian = formulation.compute_lagrangian(cmp.closure, ...)
# Populate the primal and dual gradients
- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)
# Perform primal and dual parameter updates
constrained_optimizer.step()
2 changes: 1 addition & 1 deletion docs/source/multipliers.rst
@@ -17,7 +17,7 @@ Multipliers
defects provided by the :py:class:`~cooper.problem.CMPState` of the
considered :py:class:`~cooper.problem.ConstrainedMinimizationProblem`.
- Using them for computing Lagrangians in the
- :py:meth:`~cooper.formulation.lagrangian.LagrangianFormulation.composite_objective`
+ :py:meth:`~cooper.formulation.lagrangian.LagrangianFormulation.compute_lagrangian`
method of :py:class:`~cooper.formulation.lagrangian.LagrangianFormulation`.

Constructing a DenseMultiplier
4 changes: 2 additions & 2 deletions docs/source/optim.rst
@@ -181,8 +181,8 @@ extra-gradient in the context of solving Variational Inequality Problems.
for step in range(num_steps):
const_optim.zero_grad()
- lagrangian = formulation.composite_objective(cmp.closure, model, inputs)
- formulation.custom_backward(lagrangian)
+ lagrangian = formulation.compute_lagrangian(cmp.closure, model, inputs)
+ formulation.backward(lagrangian)
# Non-extra-gradient optimizers
# Passing (cmp.closure, model, inputs) to step will simply be ignored
8 changes: 4 additions & 4 deletions tests/test_alternating_proxy.py
@@ -28,7 +28,7 @@ def test_manual_alternating_proxy(aim_device):

# ----------------------- First iteration -----------------------
coop.zero_grad()
- lagrangian = formulation.composite_objective(cmp.closure, params)
+ lagrangian = formulation.compute_lagrangian(cmp.closure, params)

# Check loss, proxy and non-proxy defects after forward pass
assert torch.allclose(lagrangian, mktensor(2.0))
@@ -44,7 +44,7 @@ def test_manual_alternating_proxy(aim_device):

# Check primal and dual gradients after backward. Dual gradient must match
# ineq_defect
- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)
assert torch.allclose(params.grad, mktensor([0.0, -4.0]))
assert torch.allclose(formulation.state()[0].grad, cmp.state.ineq_defect)

@@ -60,7 +60,7 @@ def test_manual_alternating_proxy(aim_device):

# ----------------------- Second iteration -----------------------
coop.zero_grad()
- lagrangian = formulation.composite_objective(cmp.closure, params)
+ lagrangian = formulation.compute_lagrangian(cmp.closure, params)

# Check loss, proxy and non-proxy defects after forward pass
assert torch.allclose(lagrangian, mktensor(1.3124))
@@ -70,7 +70,7 @@ def test_manual_alternating_proxy(aim_device):

# Check primal and dual gradients after backward. Dual gradient must match
# ineq_defect
- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)
assert torch.allclose(params.grad, mktensor([-0.0162, -3.218]))
assert torch.allclose(formulation.state()[0].grad, cmp.state.ineq_defect)

8 changes: 4 additions & 4 deletions tests/test_alternating_update.py
@@ -37,7 +37,7 @@ def test_manual_alternating(aim_device, alternating, use_defect_fn):
defect_fn = cmp.defect_fn if use_defect_fn else None

coop.zero_grad()
- lagrangian = formulation.composite_objective(cmp.closure, params)
+ lagrangian = formulation.compute_lagrangian(cmp.closure, params)

# Check loss, proxy and non-proxy defects after forward pass
assert torch.allclose(lagrangian, mktensor(2.0))
@@ -51,7 +51,7 @@ def test_manual_alternating(aim_device, alternating, use_defect_fn):

# Check primal and dual gradients after backward. Dual gradient must match
# ineq_defect
- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)
assert torch.allclose(params.grad, mktensor([0.0, -4.0]))
assert torch.allclose(formulation.state()[0].grad, cmp.state.ineq_defect)

@@ -96,8 +96,8 @@ def test_convergence_alternating(aim_device, alternating, use_defect_fn):
coop.zero_grad()

# When using the unconstrained formulation, lagrangian = loss
- lagrangian = formulation.composite_objective(closure=cmp.closure, params=params)
- formulation.custom_backward(lagrangian)
+ lagrangian = formulation.compute_lagrangian(closure=cmp.closure, params=params)
+ formulation.backward(lagrangian)

# Need to pass closure to step function to perform alternating updates
if use_defect_fn:
12 changes: 6 additions & 6 deletions tests/test_augmented_lagrangian.py
@@ -72,12 +72,12 @@ def test_convergence_augmented_lagrangian(aim_device):
for step_id in range(1500):
coop.zero_grad()

- lagrangian = formulation.composite_objective(
+ lagrangian = formulation.compute_lagrangian(
aug_lag_coeff_scheduler=coop.dual_scheduler,
closure=cmp.closure,
params=params,
)
- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)

coop.step(defect_fn=cmp.defect_fn, params=params)
coop.dual_scheduler.step()
@@ -127,7 +127,7 @@ def test_manual_augmented_lagrangian(aim_device):

coop.zero_grad()

- lagrangian = formulation.composite_objective(
+ lagrangian = formulation.compute_lagrangian(
aug_lag_coeff_scheduler=coop.dual_scheduler,
closure=cmp.closure,
params=params,
@@ -136,7 +136,7 @@ def test_manual_augmented_lagrangian(aim_device):
assert torch.allclose(cmp.state.loss, mktensor(2.0))
assert torch.allclose(lagrangian, mktensor(4.0))

- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)

assert torch.allclose(params.grad, mktensor([-2.0, -6.0]))

@@ -152,7 +152,7 @@ def test_manual_augmented_lagrangian(aim_device):

coop.zero_grad()

- lagrangian = formulation.composite_objective(
+ lagrangian = formulation.compute_lagrangian(
aug_lag_coeff_scheduler=coop.dual_scheduler,
closure=cmp.closure,
params=params,
@@ -161,7 +161,7 @@ def test_manual_augmented_lagrangian(aim_device):
assert torch.allclose(cmp.state.loss, mktensor(1.7676))
assert torch.allclose(lagrangian, mktensor(7.2972))

- formulation.custom_backward(lagrangian)
+ formulation.backward(lagrangian)

assert torch.allclose(params.grad, mktensor([-3.8, -7.6]))

4 changes: 2 additions & 2 deletions tests/test_checkpoint.py
@@ -21,8 +21,8 @@ def train_for_n_steps(coop, cmp, params, n_step=100):
coop.zero_grad()

# When using the unconstrained formulation, lagrangian = loss
- lagrangian = coop.formulation.composite_objective(cmp.closure, params)
- coop.formulation.custom_backward(lagrangian)
+ lagrangian = coop.formulation.compute_lagrangian(cmp.closure, params)
+ coop.formulation.backward(lagrangian)

coop.step()

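For downstream code that subclasses `cooper.formulation.Formulation` directly, a hypothetical sketch of how the renamed hooks would now be overridden (it mirrors the unconstrained formulation shown in the `cooper/formulation/formulation.py` diff above; the class and argument names are illustrative, and the remaining abstract members of the base class are elided):

```python
import torch

import cooper


class MyFormulation(cooper.formulation.Formulation):
    # Illustrative subclass only: a complete implementation must also provide
    # the other abstract members of `Formulation` (e.g. `dual_parameters`),
    # which are omitted here for brevity.

    # Previously named `composite_objective`
    def compute_lagrangian(self, closure, *closure_args, **closure_kwargs):
        cmp_state = closure(*closure_args, **closure_kwargs)
        self.write_cmp_state(cmp_state)
        return cmp_state.loss

    # Previously named `_populate_gradients` (and called via `custom_backward`)
    def backward(self, loss: torch.Tensor):
        loss.backward()
```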