
Flax BNN is several times slower in JAX 0.4.33 compared to JAX 0.4.31 #1867

Open
ziatdinovmax opened this issue Sep 25, 2024 · 4 comments
Labels: jax (This issue is specific to JAX), performance


@ziatdinovmax

JAX 0.4.31: Runtime: 27.06 seconds
https://colab.research.google.com/drive/1EsFY1St8Y2ZNBZ9UXTa9FDWrjPDdTU4U?usp=sharing

JAX 0.4.33: Runtime: 84.91 seconds
https://colab.research.google.com/drive/1g7GkuK4-GloO6cywvDUf5BVU9qO2jf1W?usp=sharing

I’m not sure whether this issue is specific to random_flax_module or a broader problem, but I’ve primarily been using NumPyro for HMC BNNs, and the slowdown with the latest JAX release is quite dramatic.

Code:

import time
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from numpyro.contrib.module import random_flax_module
import flax.linen as nn


# Set a random seed for reproducibility
rng_key = jax.random.PRNGKey(0)

# Generate some dummy data
def generate_data(n=100, noise_std=0.1):
    X = jnp.linspace(-1, 1, n)
    y = 3 * X + 2 + np.random.default_rng(0).normal(0, noise_std, size=X.shape)
    return X[:, None], y

# Define a simple neural network
class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x.squeeze()

# Define the model
def model(X, y):
    module = SimpleNN()
    # Lift the Flax module into a Bayesian module with Normal(0, 1) priors on its parameters
    net = random_flax_module("nn", module, input_shape=(1, X.shape[-1]), prior=dist.Normal(0, 1))

    with numpyro.plate("data", X.shape[0]):
        mean = net(X)
        numpyro.sample("obs", dist.Normal(mean, 0.1), obs=y)

# Generate data
X, y = generate_data()

# Initialize the NUTS sampler
nuts_kernel = NUTS(model)

# Run inference
num_warmup, num_samples = 500, 1000

start_time = time.time()

mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, X, y)

end_time = time.time()

# Print runtime
print(f"Runtime: {end_time - start_time:.2f} seconds")

# Print summary statistics
mcmc.print_summary()
@tillahoffmann (Contributor)

Hey, this may be the same issue as jax-ml/jax#23822.

@fehiepsi added the discussion and jax (This issue is specific to JAX) labels on Sep 28, 2024
@pacargile

I had similar issues with SVI inference when I switched from JAX 0.4.30 to 0.4.33. The workaround suggested in jax-ml/jax#23822 (setting the environment variable XLA_FLAGS=--xla_cpu_use_thunk_runtime=false) seemed to bring my runtimes back to what they were with JAX 0.4.30.
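One rough way to reproduce this kind of before/after comparison is to run the same script twice in fresh processes, once with default XLA flags and once with the workaround flag. A fresh process is needed because XLA reads XLA_FLAGS when its backend starts, so the flag cannot be toggled inside an already-running JAX program. The helpers run_with_env and compare_thunk_runtime below are hypothetical, and the timings are whole-process wall-clock times, so they include Python startup, imports, and JIT compilation, not just sampling.

```python
import os
import subprocess
import sys
import time

# The workaround flag discussed in jax-ml/jax#23822.
THUNK_FLAG = "--xla_cpu_use_thunk_runtime=false"


def run_with_env(code: str, extra_env: dict) -> tuple:
    """Run `code` in a new Python process with extra environment variables.

    Returns (wall-clock seconds, captured stdout).
    """
    env = {**os.environ, **extra_env}
    start = time.perf_counter()
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child script fails
    )
    return time.perf_counter() - start, result.stdout


def compare_thunk_runtime(code: str) -> dict:
    """Time `code` with default XLA flags and with the thunk runtime disabled."""
    timings = {}
    # An empty XLA_FLAGS means XLA falls back to its built-in defaults.
    for label, flags in [("default", ""), ("thunk_runtime=false", THUNK_FLAG)]:
        elapsed, _ = run_with_env(code, {"XLA_FLAGS": flags})
        timings[label] = elapsed
    return timings
```

For example, with the reproduction script from the first comment saved as bnn_repro.py (a hypothetical file name), compare_thunk_runtime(open("bnn_repro.py").read()) returns a dict with one wall-clock timing per configuration.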

@tillahoffmann (Contributor)

Do you think we can close this issue, given that it's very likely an upstream issue in JAX?

@fehiepsi (Member) commented Dec 4, 2024

Let's keep this open for visibility.

As a temporary fix, we can add this to the top of the program, before jax is imported:

import os

os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'
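The one-liner above replaces whatever XLA_FLAGS was already set, which silently drops other flags such as --xla_force_host_platform_device_count. A minimal sketch that keeps existing flags is below. It assumes, as in the comment above, that it runs at the top of the program before jax is imported; disable_cpu_thunk_runtime is a hypothetical helper, not part of JAX or NumPyro.

```python
import os

THUNK_FLAG = "--xla_cpu_use_thunk_runtime=false"


def disable_cpu_thunk_runtime() -> str:
    """Set the thunk-runtime workaround in XLA_FLAGS, keeping any other flags.

    Any existing --xla_cpu_use_thunk_runtime setting is replaced rather than
    duplicated. Returns the resulting XLA_FLAGS string.
    """
    flags = [
        f
        for f in os.environ.get("XLA_FLAGS", "").split()
        if not f.startswith("--xla_cpu_use_thunk_runtime")
    ]
    flags.append(THUNK_FLAG)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


# Call this at the very top of the program, then import JAX afterwards:
disable_cpu_thunk_runtime()
# import jax
```

Dropping any earlier --xla_cpu_use_thunk_runtime entry avoids ending up with two conflicting settings for the same flag.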
