Merge pull request #61 from gradhep/dict-pars
Refactor that assumes parameters are in a key-value mapping
phinate authored Aug 4, 2023
2 parents 2bff240 + 17a4b28 commit 7c483c8
Showing 11 changed files with 419 additions and 296 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/ci.yml
@@ -41,7 +41,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.11"]
+        python-version: ["3.9", "3.11"]
         runs-on: [ubuntu-latest]

     steps:
@@ -57,10 +57,15 @@ jobs:
       - name: Install package
         run: python -m pip install .[test]

-      - name: Install jaxopt from source (temp)
+      - name: Install jaxopt from source (temp -- my lbfgs fix)
         run:
           python -m pip install --upgrade --force-reinstall git+https://github.com/google/jaxopt.git

+      - name: Install pyhf from source (temp -- jax devicearray not found fix)
+        run:
+          python -m pip install --upgrade --force-reinstall
+          git+https://github.com/scikit-hep/pyhf.git
+
       - name: Test package
         run: >-
           pytest -ra --cov --cov-report=xml --cov-report=term --durations=20
12 changes: 5 additions & 7 deletions pyproject.toml
@@ -10,7 +10,7 @@ authors = [
 ]
 description = "Differentiable versions of common HEP operations."
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 classifiers = [
   "Development Status :: 1 - Planning",
   "Intended Audience :: Science/Research",
@@ -20,7 +20,6 @@ classifiers = [
   "Programming Language :: Python",
   "Programming Language :: Python :: 3",
   "Programming Language :: Python :: 3 :: Only",
-  "Programming Language :: Python :: 3.8",
   "Programming Language :: Python :: 3.9",
   "Programming Language :: Python :: 3.10",
   "Programming Language :: Python :: 3.11",
@@ -30,17 +29,16 @@ classifiers = [
 ]
 dynamic = ["version"]
 dependencies = [
-  "equinox",
-  "jaxopt>=0.7", # LBFGSB
-  "optax>=0.1.2", # deprecated jax.tree_multimap
+  "equinox>=0.10.6", # eqx.field
+  "jaxopt>=0.7", # LBFGSB
   "typing_extensions >=4.6; python_version<'3.11'",
 ]

 [project.optional-dependencies]
 test = [
   "pytest >=6",
   "pytest-cov >=3",
-  "pyhf[jax] >=0.7.0",
+  "pyhf[jax]", # should be pyhf[jax] >=0.7.1, but not released yet
   "iminuit",
 ]
 dev = [
@@ -94,7 +92,7 @@ port.exclude_lines = [

 [tool.mypy]
 files = ["src", "tests"]
-python_version = "3.8"
+python_version = "3.9"
 warn_unused_configs = true
 strict = true
 show_error_codes = true
169 changes: 105 additions & 64 deletions src/relaxed/infer.py
@@ -4,26 +4,31 @@
 __all__ = ("hypotest",)

 import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any

 import jax.numpy as jnp
-import pyhf
+import jax.scipy as jsp
 from equinox import filter_jit
 from jax import Array

 from relaxed.mle import fit, fixed_poi_fit

-PyTree = Any
+if TYPE_CHECKING:
+    PyTree = Any
+    from jax.typing import ArrayLike


 @filter_jit
 def hypotest(
     test_poi: float,
     data: Array,
     model: PyTree,
+    init_pars: dict[str, ArrayLike],
+    bounds: dict[str, ArrayLike] | None = None,
+    poi_name: str = "mu",
     return_mle_pars: bool = False,
-    test_stat: str = "q",
-    expected_pars: Array | None = None,
+    test_stat: str = "qmu",
+    expected_pars: dict[str, ArrayLike] | None = None,
     cls_method: bool = True,
 ) -> tuple[Array, Array] | Array:
     """Calculate expected CLs/p-values via hypothesis tests.
@@ -34,106 +39,142 @@ def hypotest(
         The value of the test parameter to use for the hypothesis test.
     data : Array
         The data to use for the hypothesis test.
-    model : pyhf.Model
-        The model to use for the hypothesis test.
-    lr : float
-        Learning rate for the MLE fit, done via gradient descent.
-    return_mle_pars : bool, optional
-        Whether to return the MLE parameters calculated as a by-product.
-    test_stat : str, optional
-        The test statistic to use for the hypothesis test. One of:
-        - "qmu" (default, used for upper limits)
-        - "q0" (used for discovery of a positive signal)
-    expected_pars : Array, optional
-        Use if calculating expected significance and these are known. If not
-        provided, the MLE parameters will be fitted.
+    model : PyTree
+        The model to use for the hypothesis test. Has a `logpdf` method with signature
+        `logpdf(pars: dict[str, ArrayLike], data: Array) -> Array`.
+    init_pars : dict[str, ArrayLike]
+        The initial parameters to use for fits within the hypothesis test.
+    bounds : dict[str, ArrayLike] | None
+        (optional) The bounds to use on parameters for fits within the hypothesis test.
+    poi_name : str
+        The name of the parameter(s) of interest.
+    return_mle_pars : bool
+        Whether to return the MLE parameters.
+    test_stat : str
+        The test statistic type to use for the hypothesis test. Default is `qmu`.
+    expected_pars : dict[str, ArrayLike] | None
+        The MLE parameters from a previous fit, to use as the expected parameters.
+    cls_method : bool
+        Whether to use the CLs method for the hypothesis test. Default is True (if qmu test)
     Returns
     -------
     Array
         The expected CLs/p-value.
-    Array
-        The MLE parameters, if `return_mle_pars` is True.
+    or tuple[Array, Array]
+        The expected CLs/p-value and the MLE parameters. Only returned if `return_mle_pars` is True.
     """
-    if test_stat == "q":
+    if test_stat == "q" or test_stat == "qmu":
         return qmu_test(
-            test_poi, data, model, return_mle_pars, expected_pars, cls_method
+            test_poi,
+            data,
+            model,
+            init_pars,
+            bounds,
+            poi_name,
+            return_mle_pars,
+            expected_pars,
+            cls_method,
         )
     if test_stat == "q0":
         logging.info(
             "test_poi automatically set to 0 for q0 test (bkg-only null hypothesis)"
         )
-        return q0_test(0.0, data, model, return_mle_pars, expected_pars)
+        return q0_test(
+            0.0,
+            data,
+            model,
+            init_pars,
+            bounds,
+            poi_name,
+            return_mle_pars,
+            expected_pars,
+        )

     msg = f"Unknown test statistic: {test_stat}"
     raise ValueError(msg)


 @filter_jit
-def qmu_test(
+def _profile_likelihood_ratio(
     test_poi: float,
     data: Array,
     model: PyTree,
-    return_mle_pars: bool = False,
+    init_pars: dict[str, ArrayLike],
+    bounds: dict[str, ArrayLike] | None,
+    poi_name: str,
     expected_pars: Array | None = None,
-    cls_method: bool = True,
-) -> tuple[Array, Array] | Array:
-    # hard-code 1 as inits for now
-    # TODO: need to parse different inits for constrained and global fits
-    # because init_pars[0] is not necessarily the poi init
-    init_pars = jnp.asarray(model.config.suggested_init())
+) -> tuple[Array, Array]:
+    # remove the poi from the init_pars -- dict-based logic!
+    conditional_init = {k: v for k, v in init_pars.items() if k != poi_name}
+    if bounds is not None:
+        conditional_bounds = {k: v for k, v in bounds.items() if k != poi_name}
+    else:
+        conditional_bounds = None
     conditional_pars = fixed_poi_fit(
-        data, model, poi_condition=test_poi, init_pars=init_pars[:-1]
+        data,
+        model,
+        poi_value=test_poi,
+        poi_name=poi_name,
+        init_pars=conditional_init,
+        bounds=conditional_bounds,
     )
     if expected_pars is None:
-        mle_pars = fit(data, model, init_pars=init_pars)
+        mle_pars = fit(data, model, init_pars=init_pars, bounds=bounds)
     else:
         mle_pars = expected_pars
-    profile_likelihood = -2 * (
-        model.logpdf(conditional_pars, data)[0] - model.logpdf(mle_pars, data)[0]
+    profile_likelihood_ratio = -2 * (
+        model.logpdf(pars=conditional_pars, data=data)
+        - model.logpdf(pars=mle_pars, data=data)
     )
-
-    poi_hat = mle_pars[model.config.poi_index]
-    qmu = jnp.where(poi_hat < test_poi, profile_likelihood, 0.0)
-
-    CLsb = 1 - pyhf.tensorlib.normal_cdf(jnp.sqrt(qmu))
+    return profile_likelihood_ratio, mle_pars
+
+
+@filter_jit
+def qmu_test(
+    test_poi: float,
+    data: Array,
+    model: PyTree,
+    init_pars: dict[str, ArrayLike],
+    bounds: dict[str, ArrayLike],
+    poi_name: str,
+    return_mle_pars: bool = False,
+    expected_pars: Array | None = None,
+    cls_method: bool = True,
+) -> tuple[Array, Array] | Array:
+    """Calculate expected CLs/p-values via qmu test."""
+    profile_likelihood_ratio, mle_pars = _profile_likelihood_ratio(
+        test_poi, data, model, init_pars, bounds, poi_name, expected_pars
+    )
+    poi_hat = mle_pars[poi_name]
+    qmu = jnp.where(poi_hat < test_poi, profile_likelihood_ratio, 0.0)
+    pmu = 1 - jsp.stats.norm.cdf(jnp.sqrt(qmu), loc=0, scale=1)
     if cls_method:
-        altval = 0.0
-        CLb = 1 - pyhf.tensorlib.normal_cdf(altval)
-        CLs = CLsb / CLb
+        alternative_hypothesis = 0.0
+        power_of_test = 1 - jsp.stats.norm.cdf(alternative_hypothesis, loc=0, scale=1)
+        result = pmu / power_of_test  # same as CLs = p_sb/(1-p_b) = CLs+b/CLb
     else:
-        CLs = CLsb
-    return (CLs, mle_pars) if return_mle_pars else CLs
+        result = pmu  # this is just the unmodified p-value
+    return (result, mle_pars) if return_mle_pars else result


 @filter_jit
 def q0_test(
     test_poi: float,
     data: Array,
     model: PyTree,
+    init_pars: dict[str, ArrayLike],
+    bounds: dict[str, ArrayLike],
+    poi_name: str,
     return_mle_pars: bool = False,
     expected_pars: Array | None = None,
 ) -> tuple[Array, Array] | Array:
-    # hard-code 1 as inits for now
-    # TODO: need to parse different inits for constrained and global fits
-    # because init_pars[0] is not necessarily the poi init
-    init_pars = jnp.asarray(model.config.suggested_init())
-    conditional_pars = fixed_poi_fit(
-        data,
-        model,
-        poi_condition=test_poi,
-        init_pars=init_pars[:-1],
+    """Calculate expected p-values via q0 test."""
+    profile_likelihood_ratio, mle_pars = _profile_likelihood_ratio(
+        test_poi, data, model, init_pars, bounds, poi_name, expected_pars
     )
-    if expected_pars is None:
-        mle_pars = fit(data, model, init_pars=init_pars)
-    else:
-        mle_pars = expected_pars
-    profile_likelihood = -2 * (
-        model.logpdf(conditional_pars, data)[0] - model.logpdf(mle_pars, data)[0]
-    )
-
-    poi_hat = mle_pars[model.config.poi_index]
-    q0 = jnp.where(poi_hat >= test_poi, profile_likelihood, 0.0)
-    p0 = 1 - pyhf.tensorlib.normal_cdf(jnp.sqrt(q0))
-
+    poi_hat = mle_pars[poi_name]
+    q0 = jnp.where(poi_hat >= test_poi, profile_likelihood_ratio, 0.0)
+    p0 = 1 - jsp.stats.norm.cdf(jnp.sqrt(q0))
     return (p0, mle_pars) if return_mle_pars else p0
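
A usage illustration of the dict-based API introduced above (a minimal sketch: the toy Gaussian model, its parameter names, and the numbers are hypothetical, not part of this commit -- any PyTree whose `logpdf(pars, data)` matches the docstring should work):

import equinox as eqx
import jax.numpy as jnp
import jax.scipy as jsp
from jax import Array

from relaxed.infer import hypotest


class ToyModel(eqx.Module):
    # hypothetical two-bin model: each bin's observation is Normal-distributed
    # around mu * signal + bkg_norm * bkg, with unit width
    signal: Array
    bkg: Array

    def logpdf(self, pars: dict, data: Array) -> Array:
        expected = pars["mu"] * self.signal + pars["bkg_norm"] * self.bkg
        return jsp.stats.norm.logpdf(data, loc=expected, scale=1.0).sum()


model = ToyModel(signal=jnp.array([5.0, 2.0]), bkg=jnp.array([10.0, 12.0]))
data = jnp.array([11.0, 13.0])

cls_value, mle_pars = hypotest(
    test_poi=1.0,
    data=data,
    model=model,
    init_pars={"mu": jnp.asarray(0.5), "bkg_norm": jnp.asarray(1.0)},
    poi_name="mu",  # the default; shown for clarity
    return_mle_pars=True,
)

Since `test_stat` defaults to "qmu", this runs the CLs-style upper-limit test; `mle_pars` comes back as a dict keyed the same way as `init_pars`.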
33 changes: 13 additions & 20 deletions src/relaxed/metrics.py
@@ -2,19 +2,20 @@
 __all__ = ("asimov_sig", "gaussianity")

-from functools import partial
 from typing import TYPE_CHECKING, Any, cast

+import equinox as eqx
 import jax
 import jax.numpy as jnp
 import jax.scipy as jsp
-from jax import Array
+from jax import Array, flatten_util
 from jax.random import multivariate_normal

 from relaxed.ops import fisher_info

 if TYPE_CHECKING:
     PyTree = Any
+    from jax.typing import ArrayLike


 @jax.jit
@@ -43,22 +44,13 @@ def _gaussian_logpdf(
     data: Array,
     cov: Array,
 ) -> Array:
-    return cast(
-        Array, jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov)
-    ).reshape(
-        1,
-    )
+    return cast(Array, jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov))


-@partial(
-    jax.jit,
-    static_argnames=[
-        "n_samples",
-    ],
-)
+@eqx.filter_jit
 def gaussianity(
     model: PyTree,
-    bestfit_pars: Array,
+    bestfit_pars: dict[str, ArrayLike],
     data: Array,
     rng_key: Any,
     n_samples: int = 1000,
@@ -70,28 +62,29 @@ def gaussianity(
     # - centre the values wrt the best-fit vals to scale the differences

     cov_approx = jnp.linalg.inv(fisher_info(model, bestfit_pars, data))

+    flat_bestfit_pars, tree_structure = flatten_util.ravel_pytree(bestfit_pars)
     gaussian_parspace_samples = multivariate_normal(
         key=rng_key,
-        mean=bestfit_pars,
+        mean=flat_bestfit_pars,
         cov=cov_approx,
         shape=(n_samples,),
     )

     relative_nlls_model = jax.vmap(
         lambda pars, data: -(
-            model.logpdf(pars, data)[0] - model.logpdf(bestfit_pars, data)[0]
+            model.logpdf(pars=tree_structure(pars), data=data)
+            - model.logpdf(pars=bestfit_pars, data=data)
         ),  # scale origin to bestfit pars
         in_axes=(0, None),
     )(gaussian_parspace_samples, data)

     relative_nlls_gaussian = jax.vmap(
         lambda pars, data: -(
-            _gaussian_logpdf(pars, data, cov_approx)[0]
-            - _gaussian_logpdf(bestfit_pars, data, cov_approx)[0]
+            _gaussian_logpdf(pars, data, cov_approx)
+            - _gaussian_logpdf(flat_bestfit_pars, data, cov_approx)
         ),  # data fixes the lhood shape
         in_axes=(0, None),
-    )(gaussian_parspace_samples, bestfit_pars)
+    )(gaussian_parspace_samples, flat_bestfit_pars)

     diffs = relative_nlls_model - relative_nlls_gaussian
     return jnp.nanmean(diffs**2, axis=0)
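
A similar sketch for the reworked `gaussianity` metric (inputs again hypothetical, reusing `model` and `data` from the `hypotest` sketch above); the flattening of the parameter dict via `ravel_pytree` now happens internally, so the caller passes the best-fit dict directly:

import jax
import jax.numpy as jnp

from relaxed.metrics import gaussianity

# bestfit_pars would normally come from relaxed.mle.fit; values here are made up
bestfit_pars = {"mu": jnp.asarray(1.0), "bkg_norm": jnp.asarray(1.0)}

# mean squared difference between the model's -logL surface and its Gaussian
# (inverse-Fisher-covariance) approximation, sampled around the best fit;
# values near zero suggest the likelihood is close to Gaussian there
score = gaussianity(model, bestfit_pars, data, rng_key=jax.random.PRNGKey(0))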
[Diffs for the remaining 7 changed files are not shown.]