
Potential extension to compute_log_probs to facilitate VI model diagnostics #1939

Closed
hessammehr opened this issue Dec 19, 2024 · 3 comments
Labels
question Further information is requested


@hessammehr
Contributor

When diagnosing a VI fit, you often want to examine how well the guide fits the prior, the data, etc. So far, I've been using a function like the following. I'd be interested to know whether there are better-established alternatives and, if not, whether it would be appropriate as a backwards-compatible extension to the newly introduced (and very useful) compute_log_probs function, or perhaps as a separate function.

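from typing import Optional

from numpyro.handlers import replay, substitute, trace
from numpyro.infer.util import compute_log_probs as clp

def compute_log_probs(
    model,
    model_args: tuple,
    model_kwargs: dict,
    model_params: dict,
    guide=None,
    guide_params: Optional[dict] = None,
    sum_log_prob: bool = True,
):
    # Without a guide, this behaves exactly like numpyro's compute_log_probs.
    if guide is not None:
        # A guide with sample sites typically needs to be seeded by the caller
        # (e.g. with numpyro.handlers.seed) before it can be traced here.
        guide_trace = trace(substitute(guide, guide_params or {})).get_trace(
            *model_args, **model_kwargs
        )
        # Replay the model so its latent sites take the values drawn by the guide.
        model = replay(model, guide_trace)
    return clp(model, model_args, model_kwargs, model_params, sum_log_prob=sum_log_prob)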
@fehiepsi added the question label (Further information is requested) on Dec 20, 2024
@fehiepsi
Member

Sorry for the late response. I think your implementation is correct. As for adding this behavior to compute_log_probs, I think it is unnecessary. Maybe @tillahoffmann has a different opinion on this.

@tillahoffmann
Collaborator

Yes, this looks like a good implementation.

Having said that, I'm not sure whether adding the two extra arguments might overload the function a little and lead to more complex signatures down the line. For example, should we also include an rng_key in the signature, or should the guide be seeded outside compute_log_probs while its parameters are substituted inside? Maybe a separate function would be better, keeping compute_log_probs focused on one thing. For instance, compute_log_probs is also relevant for MCMC sampling, but the guide isn't. I can see an argument for either approach, though. What do you think?

Do you know how often the pattern compute_log_probs(replay(model, trace(substitute(guide, guide_params)))) appears across the code base?

@fehiepsi
Member

I agree that the extension is unnecessary. Composing compute_log_probs with the existing handlers looks nicer to me. :)
