Typing improvement: Preserve the wrapped callable's signature in jax.jit #23719

Open
2 tasks done
lebrice opened this issue Sep 18, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

lebrice commented Sep 18, 2024

  • Check for duplicate requests.
  • Describe your goal, and if possible provide a code snippet with a motivating example.

Hello there. Simple feature request / bug report:

Currently, jax.jit drops the type signature of the wrapped callable.
For example, this code shows no warnings in a code editor, even though foo has no parameter named bob:

import jax

@jax.jit
def foo(a: jax.Array) -> jax.Array:
    return a

foo(bob=123)  # type-checker should display a warning!

The same goes for functions or methods decorated with a functools.partial of jax.jit: the signature of the wrapped callable is dropped:

import functools
from typing import Any

import jax

@functools.partial(jax.jit, static_argnames=["some_static_arg"])
def foo_with_static_arg(a: jax.Array, some_static_arg: Any) -> jax.Array:
    return a

foo_with_static_arg(bob=123)  # type-checker should also display a warning here!

Is there something I'm not aware of that might make this undesirable for some reason?

In the meantime, I made #23720 to address this.
Let me know what you think :)

@lebrice lebrice added the enhancement New feature or request label Sep 18, 2024
lebrice added a commit to lebrice/jax that referenced this issue Sep 18, 2024
- Fix for jax-ml#23719

Signed-off-by: Fabrice Normandin <[email protected]>
Collaborator

jakevdp commented Sep 18, 2024

Crossref #14688, which was a previous attempt to solve this.
