
OSQP crashing on unexpected params #570

Open
Illviljan opened this issue Jan 16, 2024 · 3 comments

@Illviljan

I'm trying to move some QP code over to jaxopt, but I'm struggling to understand cryptic errors that only seem to happen in the jaxopt implementation. I've tried these same parameters with other packages, and they work there.

Here's a minimal example:

```python
import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP

from qpsolvers import solve_qp


def to_numpy(*args):
    return tuple(np.asarray(v) for v in args)


P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)

x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp")  # works

qp = OSQP()
deltas = qp.run(
    params_obj=(P_, q_),
    params_eq=(A_, b_),
    params_ineq=(G_, h_),
).params.primal  # Crashes with cryptic error.
# TypeError: dot_general requires contracting dimensions to have the same shape, got (1,) and (2,).

# jax-0.4.23 jaxlib-0.4.23 jaxopt-0.8.3 ml-dtypes-0.3.2 opt-einsum-3.3.0
```
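As a sanity check on what a working solver should return here, a small NumPy sketch (not part of the original report): for a convex QP with positive-definite P, if the unconstrained minimizer already satisfies the inequality constraints, it is also the constrained solution. For this data that gives x = 0.375, and the constraint -x ≤ 2 is inactive.

```python
import numpy as np

# Same data as the report, as plain NumPy arrays (no equality constraints).
P = np.array([[576.0]])
q = np.array([-216.0])
G = np.array([[-1.0]])
h = np.array([2.0])

# The unconstrained minimizer of 0.5 x^T P x + q^T x solves P x = -q.
x_unc = np.linalg.solve(P, -q)

# If it already satisfies G x <= h, the inequality is inactive and x_unc
# is also the solution of the constrained QP (P is positive definite).
feasible = bool(np.all(G @ x_unc <= h))
print(x_unc, feasible)
```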
@Algue-Rythme
Collaborator

Algue-Rythme commented Jan 17, 2024

Hi Illviljan,

Sorry for the cryptic error message. The error comes from the fact that the matrix A_ = jnp.array([[]], dtype=float).T is not a valid linear operator. If you don't need equality constraints, just pass None to params_eq:

```python
P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[-1.0]])
h_ = jnp.array([2.0])
# A_ = jnp.array([[]], dtype=float).T
# b_ = jnp.array([], dtype=float)

qp = OSQP()
deltas = qp.run(
    params_obj=(P_, q_),
    params_eq=None,   # CHANGE HERE.
    params_ineq=(G_, h_),
).params.primal
```

Similarly, if you don't need inequality constraints, just pass None to params_ineq. Thank you for your message; I just realized that I forgot to document this functionality.
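If you want to keep building constraint blocks the qpsolvers way (possibly with zero rows), one option is to normalize them before calling jaxopt. The helper below, `constraint_block_or_none`, is hypothetical and not part of jaxopt: it returns None for a block with zero rows and the original (matrix, vector) pair otherwise. It only inspects shapes, which are static in JAX, so it works on NumPy and JAX arrays alike.

```python
import numpy as np


def constraint_block_or_none(M, v):
    """Hypothetical helper: map a constraint block (M, v) with zero rows to None.

    Only shapes are inspected, so NumPy and JAX arrays are both accepted
    and returned unchanged when the block is non-empty.
    """
    n_rows = np.shape(M)[0]
    if n_rows != np.shape(v)[0]:
        raise ValueError(f"row mismatch: M has {n_rows} rows, v has {np.shape(v)[0]}")
    return None if n_rows == 0 else (M, v)
```

Usage, assuming the arrays from the first example: `qp.run(params_obj=(P_, q_), params_eq=constraint_block_or_none(A_, b_), params_ineq=constraint_block_or_none(G_, h_))`. This does not cover the case where both blocks are empty.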

@Algue-Rythme Algue-Rythme self-assigned this Jan 17, 2024
@Algue-Rythme Algue-Rythme added the documentation Improvements or additions to documentation label Jan 17, 2024
@Illviljan
Author

Illviljan commented Jan 17, 2024

Thank you, that's quite a simple fix. Maybe I just need to keep all constraints active in my larger project.

I was surprised because jaxopt seems to be the odd one out: this A_ is accepted by other QP packages.

Using None is fine, I guess; the annoying part is that jaxopt doesn't allow both constraints to be None. Other packages allow that, and I think it fits better with how I build a new solution: start simple without any constraints, make sure it works, then slowly add constraints until the solution makes sense.

```python
import numpy as np
import jax.numpy as jnp
from jaxopt import OSQP

from qpsolvers import solve_qp


def to_numpy(*args):
    return tuple(np.asarray(v) for v in args)


P_ = jnp.array([[576.0]])
q_ = jnp.array([-216.0])
G_ = jnp.array([[]], dtype=float).T
h_ = jnp.array([], dtype=float)
A_ = jnp.array([[]], dtype=float).T
b_ = jnp.array([], dtype=float)

x = solve_qp(*to_numpy(P_, q_, G_, h_, A_, b_), solver="osqp")  # works
print(x)

qp = OSQP()
x = qp.run(
    params_obj=(P_, q_),
    params_eq=None,
    params_ineq=None,
).params.primal  # Unnecessarily strict crash
```

@Algue-Rythme
Collaborator

Algue-Rythme commented Jan 18, 2024

That's true, but using OSQP when you don't have constraints is overkill. In that case the OSQP algorithm degenerates into an inefficient way of solving a linear system.
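To spell out why: without constraints, optimality reduces to stationarity of the objective, which for positive-definite P is a single linear system. With the data from this thread, P = [576] and q = [-216], so x* = 216/576 = 0.375.

$$
\min_x \; \tfrac{1}{2} x^\top P x + q^\top x
\quad\Longrightarrow\quad
\nabla f(x) = P x + q = 0
\quad\Longrightarrow\quad
P x^\star = -q
$$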

As argued in the documentation, you should fall back to conjugate gradient in this case.
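For illustration, a minimal NumPy sketch of conjugate gradient applied to P x = -q. The function name `conjugate_gradient` is hypothetical and this is not jaxopt code. In JAX, `jax.scipy.sparse.linalg.cg(P, -q)` plays the same role and returns a `(solution, info)` pair. The sketch assumes P is symmetric positive definite, which is what guarantees CG converges.

```python
import numpy as np


def conjugate_gradient(P, b, tol=1e-10, maxiter=None):
    """Solve P x = b for symmetric positive-definite P (plain CG, no preconditioner)."""
    P = np.asarray(P, dtype=float)
    b = np.asarray(b, dtype=float)
    n = b.shape[0]
    maxiter = 10 * n if maxiter is None else maxiter
    x = np.zeros(n)
    r = b - P @ x          # residual
    p = r.copy()           # first search direction is the residual
    rs_old = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs_old) < tol:
            break          # residual small enough: converged
        Pp = P @ p
        alpha = rs_old / (p @ Pp)        # exact line search along p
        x = x + alpha * p
        r = r - alpha * Pp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p    # next P-conjugate direction
        rs_old = rs_new
    return x


# Unconstrained version of the thread's problem: minimize 0.5*576*x^2 - 216*x.
P = np.array([[576.0]])
q = np.array([-216.0])
x_star = conjugate_gradient(P, -q)
print(x_star)
```

In exact arithmetic CG converges in at most n iterations for an n×n system, so for the 1×1 example above it finishes after a single step.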
