Potential random key reuse #1949

Open
miguelbiron opened this issue Jan 14, 2025 · 0 comments

Hi -- I think the MCMC module might be reusing a random key somewhere. I found this while checking my own code with the experimental JAX flag jax_debug_key_reuse enabled. At least the NUTS and BarkerMH kernels are affected. Example:

import jax

# activate experimental check for key reuse
jax.config.update('jax_debug_key_reuse', True)  

from jax import random
from jax import numpy as jnp
from numpyro.infer import MCMC, BarkerMH

def gaussian_potential(x):
    return ((x - 2) ** 2).sum()

kernel = BarkerMH(potential_fn=gaussian_potential)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=False)
mcmc.run(random.key(0), init_params=jnp.array([1., 2.]))

This raises a KeyReuseError:

Traceback (most recent call last):
  File "<python-input-0>", line 13, in <module>
    mcmc.run(random.key(0), init_params=jnp.array([1., 2.]))
    ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mbiron/projects/layermodels/.venv/lib/python3.13/site-packages/numpyro/infer/mcmc.py", line 702, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ~~~~~~~~~~~~~~^^^^^^^^^^
  File "/home/mbiron/projects/layermodels/.venv/lib/python3.13/site-packages/numpyro/infer/mcmc.py", line 489, in _single_chain_mcmc
    collect_vals = fori_collect(
        lower_idx,
    ...<14 lines>...
        else 1,
    )
  File "/home/mbiron/projects/layermodels/.venv/lib/python3.13/site-packages/numpyro/util.py", line 369, in fori_collect
    last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
                                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
jax.errors.KeyReuseError: Previously-consumed key passed to jit-compiled function at index 18
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.KeyReuseError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
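
For reference, here is a minimal sketch of what the checker flags, independent of NumPyro. It does not show where the reuse happens inside the MCMC module (I have not located that); the two helper functions are hypothetical names made up for illustration. The first passes the same key to two sampling calls, which the checker should reject, and the second splits the key first, which is the usual way to avoid reuse.

```python
import jax
from jax import random

# Same experimental flag as in the report above.
jax.config.update("jax_debug_key_reuse", True)


def draw_twice_reusing(key):
    """Hypothetical helper: consumes the same key in two sampling calls."""
    a = random.normal(key)
    b = random.normal(key)  # key was already consumed by the call above
    return a, b


def draw_twice_splitting(key):
    """Hypothetical helper: derives one fresh subkey per sampling call."""
    key_a, key_b = random.split(key)
    return random.normal(key_a), random.normal(key_b)


# The reusing version is expected to trip the checker.
try:
    draw_twice_reusing(random.key(0))
    reuse_detected = False
except jax.errors.KeyReuseError:
    reuse_detected = True

# The splitting version uses each key exactly once.
a, b = draw_twice_splitting(random.key(0))
```

Note that the checker only tracks new-style typed keys created with random.key, which is what the report above passes to mcmc.run.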