change api for fisher info and gaussianity to use the model directly instead of model.logpdf wrapper
phinate committed Nov 30, 2021
1 parent 947b312 commit 643cca1
Showing 3 changed files with 18 additions and 40 deletions.
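
In practice, call sites drop the hand-written logpdf wrapper and pass the model object itself; relaxed now builds the scalar logpdf internally. A minimal before/after sketch (illustrative names only; model, pars, and data stand for any pyhf-style model, parameter array, and observed data):

# before this commit: wrap the model's logpdf by hand
uncerts = relaxed.cramer_rao_uncert(lambda p, d: model.logpdf(p, d)[0], pars, data)

# after this commit: pass the model directly
uncerts = relaxed.cramer_rao_uncert(model, pars, data)
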
14 changes: 5 additions & 9 deletions src/relaxed/ops/fisher_information.py
@@ -5,21 +5,17 @@
     "cramer_rao_uncert",
 )
 
-from typing import Callable
+from typing import Any
 
 import jax
 import jax.numpy as jnp
 from chex import Array
 
 
-def fisher_info(
-    logpdf: Callable[[Array, Array], float], pars: Array, data: Array
-) -> Array:
-    return -jax.hessian(logpdf)(pars, data)
+def fisher_info(model: Any, pars: Array, data: Array) -> Array:
+    return -jax.hessian(lambda p, d: model.logpdf(p, d)[0])(pars, data)
 
 
-def cramer_rao_uncert(
-    logpdf: Callable[[Array, Array], float], pars: Array, data: Array
-) -> Array:
-    inv = jnp.linalg.inv(fisher_info(logpdf, pars, data))
+def cramer_rao_uncert(model: Any, pars: Array, data: Array) -> Array:
+    inv = jnp.linalg.inv(fisher_info(model, pars, data))
     return jnp.sqrt(jnp.diagonal(inv))
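
A short usage sketch of the new signatures, assuming pyhf's JAX backend and the same uncorrelated_background model used in the tests below (all names here are illustrative, not part of the diff):

import jax.numpy as jnp
import pyhf
import relaxed

pyhf.set_backend("jax")  # the Hessian is taken with JAX, so a differentiable backend is assumed

model = pyhf.simplemodels.uncorrelated_background(
    signal=[5.0, 5.0], bkg=[50.0, 50.0], bkg_uncertainty=[5.0, 5.0]
)
pars = jnp.asarray(model.config.suggested_init())
data = jnp.asarray(model.expected_data(pars))

info = relaxed.fisher_info(model, pars, data)  # negative Hessian of model.logpdf at pars
uncerts = relaxed.cramer_rao_uncert(model, pars, data)  # sqrt of the diagonal of the inverse Fisher matrix
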
10 changes: 7 additions & 3 deletions src/relaxed/ops/likelihood_gaussianity.py
@@ -11,6 +11,8 @@
 from jax import jit, vmap
 from jax.random import PRNGKey, multivariate_normal
 
+from relaxed.ops import fisher_info
+
 if TYPE_CHECKING:
     import pyhf
 
@@ -29,8 +31,7 @@ def gaussian_logpdf(
 def gaussianity(
     model: pyhf.Model,
     bestfit_pars: Array,
-    cov_approx: Array,
-    observed_data: Array,
+    data: Array,
     rng_key: PRNGKey,
     n_samples: int = 1000,
 ) -> Array:
@@ -39,6 +40,9 @@
     # - do this across a number of points in parspace (sampled from the gaussian approx)
     # and take the mean squared diff
     # - centre the values wrt the best-fit vals to scale the differences
+
+    cov_approx = jnp.linalg.inv(fisher_info(model, bestfit_pars, data))
+
     gaussian_parspace_samples = multivariate_normal(
         key=rng_key,
         mean=bestfit_pars,
@@ -51,7 +55,7 @@
             model.logpdf(pars, data)[0] - model.logpdf(bestfit_pars, data)[0]
         ), # scale origin to bestfit pars
         in_axes=(0, None),
-    )(gaussian_parspace_samples, observed_data)
+    )(gaussian_parspace_samples, data)
 
     relative_nlls_gaussian = vmap(
         lambda pars, data: -(
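
With this change the covariance of the Gaussian approximation is derived inside gaussianity via fisher_info instead of being passed by the caller, so the call site reduces to the model, the best-fit parameters, the observed data, and an RNG key. A hedged sketch of the new call, reusing the model, pars, and data variables from the sketch above:

from jax.random import PRNGKey

# mean squared difference between the model's NLL (centred on the best-fit point)
# and its Gaussian approximation, over parameter points sampled from that Gaussian
gauss_metric = relaxed.gaussianity(model, pars, data, PRNGKey(0))
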
34 changes: 6 additions & 28 deletions tests/test_ops.py
@@ -98,17 +98,14 @@ def true_grad(mu, bins):
 
 
 def test_fisher_info(example_model):
-    def model(pars, data):
-        return example_model.logpdf(pars, data)[0]
-
     pars = example_model.config.suggested_init()
     data = example_model.expected_data(pars)
 
     # this is just the computed output, assumed correct
     # probably needs a more thorough analytic test
     res = np.array([[0.90909091, 9.09090909], [9.09090909, 290.90909091]])
 
-    assert np.allclose(relaxed.fisher_info(model, pars, data), res)
+    assert np.allclose(relaxed.fisher_info(example_model, pars, data), res)
 
 
 def test_fisher_uncerts_validity():
@@ -133,34 +130,24 @@ def test_fisher_uncerts_validity():
     mle_pars, mle_uncerts = fit_res[:, 0], fit_res[:, 1]
 
     # uncertainties from autodiff hessian
-    def lpdf(p, d):
-        return m.logpdf(p, d)[0]
-
-    relaxed_uncerts = relaxed.cramer_rao_uncert(lpdf, mle_pars, data)
+    relaxed_uncerts = relaxed.cramer_rao_uncert(m, mle_pars, data)
     assert np.allclose(mle_uncerts, relaxed_uncerts, rtol=5e-2)
 
 
 def test_fisher_info_grad(example_model):
     def pipeline(x):
         pars = example_model.config.suggested_init()
-
-        def model(pars, data):
-            return example_model.logpdf(pars, data)[0]
-
         data = example_model.expected_data(pars)
-        return relaxed.fisher_info(model, pars * x, data * x)
+        return relaxed.fisher_info(example_model, pars * x, data * x)
 
     jacrev(pipeline)(4.0) # just check you can calc it w/o exception
 
 
 def test_fisher_uncert_grad(example_model):
     def pipeline(x):
-        def model(pars, data):
-            return example_model.logpdf(pars, data)[0]
-
         pars = example_model.config.suggested_init()
         data = example_model.expected_data(pars)
-        return relaxed.cramer_rao_uncert(model, pars * x, data * x)
+        return relaxed.cramer_rao_uncert(example_model, pars * x, data * x)
 
     jacrev(pipeline)(4.0) # just check you can calc it w/o exception
 
@@ -170,22 +157,13 @@ def test_gaussianity():
     m = pyhf.simplemodels.uncorrelated_background([5, 5], [50, 50], [5, 5])
     pars = jnp.asarray(m.config.suggested_init())
     data = jnp.asarray(m.expected_data(pars))
-    cov_approx = jnp.linalg.inv(
-        relaxed.fisher_info(lambda d, p: m.logpdf(d, p)[0], pars, data)
-    )
-    relaxed.gaussianity(m, pars, cov_approx, data, PRNGKey(0))
+    relaxed.gaussianity(m, pars, data, PRNGKey(0))
 
 
 def test_gaussianity_grad(example_model):
     def pipeline(x):
-        def model(pars, data):
-            return example_model.logpdf(pars, data)[0]
-
         pars = example_model.config.suggested_init()
         data = example_model.expected_data(pars)
-        cov_approx = jnp.linalg.inv(relaxed.fisher_info(model, pars, data))
-        return relaxed.gaussianity(
-            example_model, pars * x, cov_approx * x, data * x, PRNGKey(0)
-        )
+        return relaxed.gaussianity(example_model, pars * x, data * x, PRNGKey(0))
 
     jacrev(pipeline)(4.0) # just check you can calc it w/o exception
