From 643cca19c3019d3b7e227d903862b1bbee2bae15 Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Tue, 30 Nov 2021 16:24:11 +0100
Subject: [PATCH] change API for fisher_info and gaussianity to use the model
 directly instead of a model.logpdf wrapper

---
 src/relaxed/ops/fisher_information.py     | 14 ++++------
 src/relaxed/ops/likelihood_gaussianity.py | 10 +++++--
 tests/test_ops.py                         | 34 ++++------------------
 3 files changed, 18 insertions(+), 40 deletions(-)

diff --git a/src/relaxed/ops/fisher_information.py b/src/relaxed/ops/fisher_information.py
index 04eae89..1ddea5d 100644
--- a/src/relaxed/ops/fisher_information.py
+++ b/src/relaxed/ops/fisher_information.py
@@ -5,21 +5,17 @@
     "cramer_rao_uncert",
 )
 
-from typing import Callable
+from typing import Any
 
 import jax
 import jax.numpy as jnp
 from chex import Array
 
 
-def fisher_info(
-    logpdf: Callable[[Array, Array], float], pars: Array, data: Array
-) -> Array:
-    return -jax.hessian(logpdf)(pars, data)
+def fisher_info(model: Any, pars: Array, data: Array) -> Array:
+    return -jax.hessian(lambda p, d: model.logpdf(p, d)[0])(pars, data)
 
 
-def cramer_rao_uncert(
-    logpdf: Callable[[Array, Array], float], pars: Array, data: Array
-) -> Array:
-    inv = jnp.linalg.inv(fisher_info(logpdf, pars, data))
+def cramer_rao_uncert(model: Any, pars: Array, data: Array) -> Array:
+    inv = jnp.linalg.inv(fisher_info(model, pars, data))
     return jnp.sqrt(jnp.diagonal(inv))
diff --git a/src/relaxed/ops/likelihood_gaussianity.py b/src/relaxed/ops/likelihood_gaussianity.py
index 865fd61..6b9a4f2 100644
--- a/src/relaxed/ops/likelihood_gaussianity.py
+++ b/src/relaxed/ops/likelihood_gaussianity.py
@@ -11,6 +11,8 @@
 from jax import jit, vmap
 from jax.random import PRNGKey, multivariate_normal
 
+from relaxed.ops import fisher_info
+
 if TYPE_CHECKING:
     import pyhf
 
@@ -29,8 +31,7 @@ def gaussian_logpdf(
 def gaussianity(
     model: pyhf.Model,
     bestfit_pars: Array,
-    cov_approx: Array,
-    observed_data: Array,
+    data: Array,
     rng_key: PRNGKey,
     n_samples: int = 1000,
 ) -> Array:
@@ -39,6 +40,9 @@
     # - do this across a number of points in parspace (sampled from the gaussian approx)
     # and take the mean squared diff
     # - centre the values wrt the best-fit vals to scale the differences
+
+    cov_approx = jnp.linalg.inv(fisher_info(model, bestfit_pars, data))
+
     gaussian_parspace_samples = multivariate_normal(
         key=rng_key,
         mean=bestfit_pars,
@@ -51,7 +55,7 @@
             model.logpdf(pars, data)[0] - model.logpdf(bestfit_pars, data)[0]
         ),  # scale origin to bestfit pars
         in_axes=(0, None),
-    )(gaussian_parspace_samples, observed_data)
+    )(gaussian_parspace_samples, data)
 
     relative_nlls_gaussian = vmap(
         lambda pars, data: -(
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 251dd74..0028d98 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -98,9 +98,6 @@ def true_grad(mu, bins):
 
 
 def test_fisher_info(example_model):
-    def model(pars, data):
-        return example_model.logpdf(pars, data)[0]
-
     pars = example_model.config.suggested_init()
     data = example_model.expected_data(pars)
 
@@ -108,7 +105,7 @@ def model(pars, data):
 
     # probably needs a more thorough analytic test
     res = np.array([[0.90909091, 9.09090909], [9.09090909, 290.90909091]])
-    assert np.allclose(relaxed.fisher_info(model, pars, data), res)
+    assert np.allclose(relaxed.fisher_info(example_model, pars, data), res)
 
 
 def test_fisher_uncerts_validity():
@@ -133,34 +130,24 @@ def test_fisher_uncerts_validity():
     mle_pars, mle_uncerts = fit_res[:, 0], fit_res[:, 1]
 
     # uncertainties from autodiff hessian
-    def lpdf(p, d):
-        return m.logpdf(p, d)[0]
-
-    relaxed_uncerts = relaxed.cramer_rao_uncert(lpdf, mle_pars, data)
+    relaxed_uncerts = relaxed.cramer_rao_uncert(m, mle_pars, data)
     assert np.allclose(mle_uncerts, relaxed_uncerts, rtol=5e-2)
 
 
 def test_fisher_info_grad(example_model):
     def pipeline(x):
         pars = example_model.config.suggested_init()
-
-        def model(pars, data):
-            return example_model.logpdf(pars, data)[0]
-
         data = example_model.expected_data(pars)
-        return relaxed.fisher_info(model, pars * x, data * x)
+        return relaxed.fisher_info(example_model, pars * x, data * x)
 
     jacrev(pipeline)(4.0)  # just check you can calc it w/o exception
 
 
 def test_fisher_uncert_grad(example_model):
     def pipeline(x):
-        def model(pars, data):
-            return example_model.logpdf(pars, data)[0]
-
         pars = example_model.config.suggested_init()
         data = example_model.expected_data(pars)
-        return relaxed.cramer_rao_uncert(model, pars * x, data * x)
+        return relaxed.cramer_rao_uncert(example_model, pars * x, data * x)
 
     jacrev(pipeline)(4.0)  # just check you can calc it w/o exception
 
@@ -170,22 +157,13 @@ def test_gaussianity():
     m = pyhf.simplemodels.uncorrelated_background([5, 5], [50, 50], [5, 5])
     pars = jnp.asarray(m.config.suggested_init())
     data = jnp.asarray(m.expected_data(pars))
-    cov_approx = jnp.linalg.inv(
-        relaxed.fisher_info(lambda d, p: m.logpdf(d, p)[0], pars, data)
-    )
-    relaxed.gaussianity(m, pars, cov_approx, data, PRNGKey(0))
+    relaxed.gaussianity(m, pars, data, PRNGKey(0))
 
 
 def test_gaussianity_grad(example_model):
     def pipeline(x):
-        def model(pars, data):
-            return example_model.logpdf(pars, data)[0]
-
         pars = example_model.config.suggested_init()
         data = example_model.expected_data(pars)
-        cov_approx = jnp.linalg.inv(relaxed.fisher_info(model, pars, data))
-        return relaxed.gaussianity(
-            example_model, pars * x, cov_approx * x, data * x, PRNGKey(0)
-        )
+        return relaxed.gaussianity(example_model, pars * x, data * x, PRNGKey(0))
 
     jacrev(pipeline)(4.0)  # just check you can calc it w/o exception
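Reviewer note, not part of the patch: a minimal sketch of the call pattern after this change, assuming pyhf's jax backend (needed so jax.hessian can differentiate through model.logpdf) and the same two-bin uncorrelated-background model the tests construct; variable names here are illustrative.

    import pyhf
    import jax.numpy as jnp
    from jax.random import PRNGKey

    import relaxed

    pyhf.set_backend("jax")  # relaxed differentiates through model.logpdf

    # two-bin counting model, as in tests/test_ops.py
    model = pyhf.simplemodels.uncorrelated_background([5, 5], [50, 50], [5, 5])
    pars = jnp.asarray(model.config.suggested_init())
    data = jnp.asarray(model.expected_data(pars))

    # before: relaxed.fisher_info(lambda p, d: model.logpdf(p, d)[0], pars, data)
    # after:  pass the model itself; the logpdf-wrapping lambda now lives inside fisher_info
    info = relaxed.fisher_info(model, pars, data)
    uncerts = relaxed.cramer_rao_uncert(model, pars, data)

    # gaussianity no longer takes cov_approx/observed_data; it rebuilds the
    # covariance approximation internally as inv(fisher_info(model, bestfit_pars, data))
    score = relaxed.gaussianity(model, pars, data, PRNGKey(0))

One consequence of hiding cov_approx: a caller that wants both cramer_rao_uncert and gaussianity now inverts the Fisher matrix once inside each call, rather than sharing a precomputed covariance.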