
Multiple constraints minimization #18

Open
dudifrid opened this issue Oct 20, 2022 · 7 comments
Labels
enhancement New feature or request

Comments

@dudifrid
Contributor

dudifrid commented Oct 20, 2022

First of all - thank you for the great repo!

Now, I have a question: is there any support for minimization under multiple constraints? This need arises often in my case, since I am minimizing a function that takes a tensor as input, so I need multiple constraints on the relations between the input tensor's elements.

So far I've tried: a list of dictionaries, a dictionary of lists, a dictionary whose fun returns a tensor with more than one element, and a dictionary whose fun returns a list of tensors. They all failed.

Here's my toy snippet of code with one constraint:

from torchmin import minimize_constr
import torch
import torch.nn as nn

eps0 = torch.rand((2,3))

res = minimize_constr(
    lambda x: nn.L1Loss()(x.sum(), x.sum()),
    eps0,
    max_iter=100,
    constr=dict(
        fun=lambda x: x.square().sum(),
        lb=1, ub=1
    ),
    disp=1
)

eps = res.x

Thanks in advance!

@rfeinman
Owner

Alas, minimize_constr() does not currently support multiple constraints. It would be possible to add this functionality, namely by allowing fun to output a vector rather than a scalar (in which case lb and ub would need to be vectors of the same size). However, this would require a change to the way that Jacobians and Hessians are computed behind the scenes: Jacobian computations would then look like what were previously Hessian computations, and I'm not sure what exactly the Hessian computations would look like (if they are possible at all).

This is a low priority right now, given that the majority of pytorch-minimize is devoted to minimize(). The minimize_constr() utility is merely a wrapper for scipy that was included as a temporary convenience. I have been meaning to write custom implementations of constrained minimization in pure PyTorch.

@dudifrid
Contributor Author

Thanks for the rapid reply!
I need only linear constraints (but for a very complicated neural-network-dependent objective function), and I'd love to hear from you what is best for my needs: some kind of workaround (I tried but haven't come up with one so far), a different suitable library, or implementing this myself (from your answer it sounds difficult, but maybe it is easier for linear constraints?).

@rfeinman rfeinman added the enhancement New feature or request label Mar 23, 2023
@ilan-gold

ilan-gold commented May 2, 2023

@rfeinman I am a bit of a newbie at optimization, but hopefully my math background will help me hit the ground running. I might be interested in taking this on if you're looking for someone to implement this in torch. I have a question though - I have seen people use parameters in a "logX" space or similar, so that the parameters the optimizer sees are unbounded but are then transformed by X**params to just the positive reals when evaluating the objective. What is the theoretical grounding here, if any, for this as a workaround? And would something similar work using a rescaled sigmoid of some sort (in order to put an upper bound on things as well)?
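The sigmoid variant of that reparametrization trick, for box constraints, can be sketched as follows (a toy example with an assumed box [0, 1] and a made-up quadratic objective; note the transform only approaches the bounds asymptotically, which is part of why this is a workaround rather than true constrained optimization):

```python
import torch

lb, ub = 0.0, 1.0  # assumed box bounds for this toy example

def to_box(z):
    # sigmoid maps all of R into (0, 1); rescale to (lb, ub).
    # The optimizer never sees the bounds - it works on the raw z.
    return lb + (ub - lb) * torch.sigmoid(z)

# Toy objective: minimize (x - 0.25)^2 over x in (0, 1),
# but optimize the unconstrained z instead of x directly.
z = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([z], lr=0.1)
for _ in range(500):
    opt.zero_grad()
    loss = (to_box(z) - 0.25).pow(2).sum()
    loss.backward()
    opt.step()
# to_box(z) converges toward 0.25, safely inside the box
```

One caveat worth knowing: gradients vanish as to_box(z) nears lb or ub (the sigmoid saturates), so solutions that lie exactly on a bound converge slowly under this scheme.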

If you feel, given your expertise, that these are both bad ideas/workarounds, I would be interested in implementing constrained optimization for multiple bounds, provided it is not a completely insane amount of work and you'd be willing to provide some guidance :) I have a fair bit of open source experience and I just finished adding some features to a GPU-based numerical integration package, so I feel comfortable offering this. I also knew nothing about numerical integration before I started that, and here we are 😄 so I hope I can catch on here too if need be :)

@ilan-gold

P.S. I'm interested only in box constraints, so from what I can see L-BFGS-B would be the route to go, I think.

@ilan-gold

P.P.S. I would also be interested in BFGS-B, if that exists. No particular attachment to limited memory.

@xavinatalia

> P.S I'm interested only in box contraints, so from I can see L-BFGS-B would be the route to go, I think.

Excuse me, I want to apply L-BFGS-B in a PyTorch setting with CUDA - are there any implementations of the L-BFGS-B algorithm? Specifically, I would prefer an algorithm equivalent to scipy.optimize.minimize(method='L-BFGS-B').
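Not an answer about native implementations, but the usual workaround is to drive SciPy's L-BFGS-B from Python while evaluating the objective and its gradient in PyTorch (on the GPU when one is available). A minimal sketch, with a made-up quadratic objective standing in for the real model:

```python
import numpy as np
import torch
from scipy.optimize import minimize

device = "cuda" if torch.cuda.is_available() else "cpu"

def torch_objective(x):
    # stand-in objective; a real use case would run a network here
    return (x - 2.0).pow(2).sum()

def fun_and_grad(x_np):
    # SciPy hands us a float64 NumPy array; move it to the torch device
    x = torch.tensor(x_np, dtype=torch.float64,
                     device=device, requires_grad=True)
    loss = torch_objective(x)
    loss.backward()
    # SciPy expects a float and a float64 NumPy gradient on the CPU
    return loss.item(), x.grad.cpu().numpy()

res = minimize(fun_and_grad, x0=np.zeros(3), jac=True,
               method="L-BFGS-B", bounds=[(0.0, 1.5)] * 3)
# the unconstrained minimum (2, 2, 2) is outside the box,
# so each coordinate lands on the upper bound 1.5
```

The CPU round-trip on every iteration adds overhead, but for objectives dominated by the forward/backward pass it is usually negligible; this is essentially what minimize_constr() does internally as well, per the comments above.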

@gnorman7

gnorman7 commented Jul 3, 2024

I've created a wrapper around this function (scipy.optimize.minimize(method='trust-constr')) for my own purposes, supporting multiple constraints. However, I don't compute Hessians; instead I use SciPy's BFGS approximation. If there's still interest in this, I'll make it open source sooner rather than later.
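For reference, the pattern described here - trust-constr with SciPy's BFGS quasi-Newton updates in place of exact Hessians - looks roughly like this (a sketch with a made-up objective and two simultaneous constraints, not the wrapper itself):

```python
import numpy as np
from scipy.optimize import BFGS, NonlinearConstraint, minimize

# Minimize (x0 - 1)^2 + (x1 - 2.5)^2 under two constraints at once:
#   x0 + x1 <= 2   and   x0 - 2*x1 >= -5
# No exact Hessian is ever formed: both the objective and the
# constraint use SciPy's BFGS quasi-Newton approximation, and the
# gradients are finite-differenced.
obj = lambda x: (x[0] - 1.0) ** 2 + (x[1] - 2.5) ** 2

cons = NonlinearConstraint(
    lambda x: np.array([x[0] + x[1], x[0] - 2.0 * x[1]]),
    lb=[-np.inf, -5.0], ub=[2.0, np.inf],
    jac="2-point", hess=BFGS(),
)

res = minimize(obj, x0=np.zeros(2), method="trust-constr",
               constraints=[cons], jac="2-point", hess=BFGS())
# optimum is (0.25, 1.75): only the first constraint is active there
```

In a PyTorch wrapper one would supply exact gradients from autograd instead of "2-point" differences and keep hess=BFGS(), which matches the approach described in this comment.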

5 participants