From 73ecc281594f177f5990dc1f0d1931b8acceaed7 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Wed, 2 Aug 2023 21:56:41 +0100 Subject: [PATCH 01/13] start refactor --- src/relaxed/infer.py | 97 ++++++++++++++++------------- src/relaxed/mle.py | 95 ++++++++++++++-------------- tests/dummy_pyhf.py | 145 ++++++++++++++++++++++++------------------- tests/test_fit.py | 14 ++--- 4 files changed, 192 insertions(+), 159 deletions(-) diff --git a/src/relaxed/infer.py b/src/relaxed/infer.py index d8b4232..dc7b1bf 100644 --- a/src/relaxed/infer.py +++ b/src/relaxed/infer.py @@ -4,16 +4,18 @@ __all__ = ("hypotest",) import logging -from typing import Any +from typing import Any, TYPE_CHECKING import jax.numpy as jnp -import pyhf from equinox import filter_jit from jax import Array +import jax.scipy as jsp from relaxed.mle import fit, fixed_poi_fit -PyTree = Any +if TYPE_CHECKING: + PyTree = Any + from jax.typing import ArrayLike @filter_jit @@ -21,9 +23,12 @@ def hypotest( test_poi: float, data: Array, model: PyTree, + init_pars: dict[str, ArrayLike], + bounds: dict[str, ArrayLike], + poi_name: str = "mu", return_mle_pars: bool = False, - test_stat: str = "q", - expected_pars: Array | None = None, + test_stat: str = "qmu", + expected_pars: dict[str, ArrayLike] | None = None, cls_method: bool = True, ) -> tuple[Array, Array] | Array: """Calculate expected CLs/p-values via hypothesis tests. @@ -34,36 +39,39 @@ def hypotest( The value of the test parameter to use for the hypothesis test. data : Array The data to use for the hypothesis test. - model : pyhf.Model + model : PyTree The model to use for the hypothesis test. - lr : float - Learning rate for the MLE fit, done via gradient descent. - return_mle_pars : bool, optional - Whether to return the MLE parameters calculated as a by-product. - test_stat : str, optional - The test statistic to use for the hypothesis test. One of: - - "qmu" (default, used for upper limits) - - "q0" (used for discovery of a positive signal) - expected_pars : Array, optional - Use if calculating expected significance and these are known. If not - provided, the MLE parameters will be fitted. + init_pars : dict[str, ArrayLike] + The initial parameters to use for fits within the hypothesis test. + bounds : dict[str, ArrayLike] + The bounds to use for fits within the hypothesis test. + poi_name : str + The name of the parameter(s) of interest. + return_mle_pars : bool + Whether to return the MLE parameters. + test_stat : str + The test statistic type to use for the hypothesis test. Default is `qmu`. + expected_pars : dict[str, ArrayLike] | None + The MLE parameters from a previous fit, to use as the expected parameters. + cls_method : bool + Whether to use the CLs method for the hypothesis test. Default is True (if qmu test) Returns ------- Array The expected CLs/p-value. - Array - The MLE parameters, if `return_mle_pars` is True. + or tuple[Array, Array] + The expected CLs/p-value and the MLE parameters. Only returned if `return_mle_pars` is True. 
""" - if test_stat == "q": + if test_stat == "q" or test_stat == "qmu": return qmu_test( - test_poi, data, model, return_mle_pars, expected_pars, cls_method + test_poi, data, model, init_pars, bounds, poi_name, return_mle_pars, expected_pars, cls_method ) if test_stat == "q0": logging.info( "test_poi automatically set to 0 for q0 test (bkg-only null hypothesis)" ) - return q0_test(0.0, data, model, return_mle_pars, expected_pars) + return q0_test(0.0, data, model, init_pars, bounds, poi_name, return_mle_pars, expected_pars) msg = f"Unknown test statistic: {test_stat}" raise ValueError(msg) @@ -74,32 +82,34 @@ def qmu_test( test_poi: float, data: Array, model: PyTree, + init_pars: dict[str, ArrayLike], + bounds: dict[str, ArrayLike], + poi_name: str, return_mle_pars: bool = False, expected_pars: Array | None = None, cls_method: bool = True, ) -> tuple[Array, Array] | Array: - # hard-code 1 as inits for now - # TODO: need to parse different inits for constrained and global fits - # because init_pars[0] is not necessarily the poi init - init_pars = jnp.asarray(model.config.suggested_init()) + # remove the poi from the init_pars + conditional_init = {k: v for k, v in init_pars.items() if k != poi_name} + conditional_bounds = {k: v for k, v in bounds.items() if k != poi_name} conditional_pars = fixed_poi_fit( - data, model, poi_condition=test_poi, init_pars=init_pars[:-1] + data, model, poi_value=test_poi, poi_name=poi_name, init_pars=conditional_init, bounds=conditional_bounds ) if expected_pars is None: - mle_pars = fit(data, model, init_pars=init_pars) + mle_pars = fit(data, model, init_pars=init_pars, bounds=bounds) else: mle_pars = expected_pars profile_likelihood = -2 * ( - model.logpdf(conditional_pars, data)[0] - model.logpdf(mle_pars, data)[0] + model.logpdf(pars=conditional_pars, data=data) - model.logpdf(pars=mle_pars, data=data) ) - poi_hat = mle_pars[model.config.poi_index] + poi_hat = mle_pars[poi_name] qmu = jnp.where(poi_hat < test_poi, profile_likelihood, 0.0) - CLsb = 1 - pyhf.tensorlib.normal_cdf(jnp.sqrt(qmu)) + CLsb = 1 - jsp.stats.norm.cdf(jnp.sqrt(qmu), loc=0, scale=1) if cls_method: altval = 0.0 - CLb = 1 - pyhf.tensorlib.normal_cdf(altval) + CLb = 1 - jsp.stats.norm.cdf(altval, loc=0, scale=1) CLs = CLsb / CLb else: CLs = CLsb @@ -111,29 +121,28 @@ def q0_test( test_poi: float, data: Array, model: PyTree, + init_pars: dict[str, ArrayLike], + bounds: dict[str, ArrayLike], + poi_name: str, return_mle_pars: bool = False, expected_pars: Array | None = None, ) -> tuple[Array, Array] | Array: - # hard-code 1 as inits for now - # TODO: need to parse different inits for constrained and global fits - # because init_pars[0] is not necessarily the poi init - init_pars = jnp.asarray(model.config.suggested_init()) + # remove the poi from the init_pars + conditional_init = {k: v for k, v in init_pars.items() if k != poi_name} + conditional_bounds = {k: v for k, v in bounds.items() if k != poi_name} conditional_pars = fixed_poi_fit( - data, - model, - poi_condition=test_poi, - init_pars=init_pars[:-1], + data, model, poi_value=test_poi, poi_name=poi_name, init_pars=conditional_init, bounds=conditional_bounds ) if expected_pars is None: - mle_pars = fit(data, model, init_pars=init_pars) + mle_pars = fit(data, model, init_pars=init_pars, bounds=bounds) else: mle_pars = expected_pars profile_likelihood = -2 * ( - model.logpdf(conditional_pars, data)[0] - model.logpdf(mle_pars, data)[0] + model.logpdf(pars=conditional_pars, data=data) - model.logpdf(pars=mle_pars, data=data) ) - poi_hat = 
mle_pars[model.config.poi_index] + poi_hat = mle_pars[poi_name] q0 = jnp.where(poi_hat >= test_poi, profile_likelihood, 0.0) - p0 = 1 - pyhf.tensorlib.normal_cdf(jnp.sqrt(q0)) + p0 = 1 - jsp.stats.norm.cdf(jnp.sqrt(q0)) return (p0, mle_pars) if return_mle_pars else p0 diff --git a/src/relaxed/mle.py b/src/relaxed/mle.py index 5d50944..94948d8 100644 --- a/src/relaxed/mle.py +++ b/src/relaxed/mle.py @@ -3,25 +3,51 @@ __all__ = ("fit", "fixed_poi_fit") import inspect -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, cast, Sequence -import jax.numpy as jnp import jaxopt from equinox import filter_jit +import jax +import numpy as np +import jax.numpy as jnp +from jax import Array -if TYPE_CHECKING: - from jax import Array +if TYPE_CHECKING: + from jax.typing import ArrayLike PyTree = Any +def _get_bounds(bounds: dict[str, ArrayLike], init_pars: dict[str, ArrayLike]) -> tuple[dict[str, ArrayLike], dict[str, ArrayLike]]: + """Convert dict of bounds to a dict of lower and a dict of upper bounds.""" + lower = {} + upper = {} + + for k, v in bounds.items(): + # Convert to array for easy manipulation + v = jnp.asarray(v) + + # Check if v is 1D or 2D + if v.ndim == 1: + if isinstance(init_pars[k], (list, jax.Array, np.ndarray)) and init_pars[k].size > 1: # If the initial parameter is a list or array + lower[k] = jnp.array([v[0]] * len(init_pars[k])) + upper[k] = jnp.array([v[1]] * len(init_pars[k])) + else: # If the initial parameter is a single value + lower[k] = v[0] + upper[k] = v[1] + else: + lower[k] = jnp.array([item[0] for item in v]) + upper[k] = jnp.array([item[1] for item in v]) + + return lower, upper + @filter_jit def _minimize( fit_objective: Callable[[Array], float], model: PyTree, data: Array, - init_pars: Array, - bounds: Array, + init_pars: dict[str, ArrayLike], + bounds: dict[str, ArrayLike], method: str = "LBFGSB", maxiter: int = 500, tol: float = 1e-6, @@ -34,7 +60,8 @@ def _minimize( fun=fit_objective, implicit_diff=True, **other_settings ) if "bounds" in inspect.signature(minimizer.init_state).parameters: - return minimizer.run(init_pars, bounds=bounds, model=model, data=data)[0] + lower, upper = _get_bounds(bounds, init_pars) + return minimizer.run(init_pars, bounds=(lower, upper), model=model, data=data)[0] return minimizer.run(init_pars, model=model, data=data)[0] @@ -42,21 +69,15 @@ def _minimize( def fit( data: Array, model: PyTree, - init_pars: Array | None = None, - bounds: tuple[Array, Array] | None = None, + init_pars: dict[str, ArrayLike], + bounds: dict[str, Array], method: str = "LBFGSB", maxiter: int = 500, tol: float = 1e-6, other_settings: dict[str, float] | None = None, -) -> Array: +) -> dict[str, Array]: def fit_objective(pars: Array, model: PyTree, data: Array) -> float: - return cast(float, -model.logpdf(pars, data)[0]) - - if bounds is None: - bounds = model.config.suggested_bounds() - - if init_pars is None: - init_pars = model.config.suggested_init() + return cast(float, -model.logpdf(data=data, pars=pars)) return _minimize( fit_objective=fit_objective, @@ -75,36 +96,22 @@ def fit_objective(pars: Array, model: PyTree, data: Array) -> float: def fixed_poi_fit( data: Array, model: PyTree, - poi_condition: float, - init_pars: Array | None = None, - bounds: Array | None = None, + poi_value: float, + poi_name: str, + init_pars: dict[str, ArrayLike], + bounds: dict[str, Array], method: str = "LBFGSB", maxiter: int = 500, tol: float = 1e-6, other_settings: dict[str, float] | None = None, -) -> Array: - 
poi_idx = model.config.poi_index +) -> dict[str, Array]: - def fit_objective(pars: Array, model: PyTree, data: Array) -> float: # NLL + def fit_objective(pars: dict[str, Array], model: PyTree, data: Array) -> float: # NLL """lhood_pars_to_optimize: either all pars, or just nuisance pars""" - # pyhf.Model.logpdf returns list[float] - blank = jnp.zeros_like(jnp.asarray(model.config.suggested_init())) - blank += pars - return cast(float, -model.logpdf(blank.at[poi_idx].set(poi_condition), data)[0]) - - if bounds is None: - lower, upper = model.config.suggested_bounds() - # ignore poi bounds - upper = jnp.delete(upper, poi_idx) - lower = jnp.delete(lower, poi_idx) - bounds = jnp.array([lower, upper]) - - if init_pars is None: - init_pars = model.config.suggested_init() - # ignore poi init - init_pars = jnp.delete(init_pars, poi_idx) - - fit_res = _minimize( + pars[poi_name] = poi_value + return cast(float, -model.logpdf(data=data, pars=pars)) + + res = _minimize( fit_objective=fit_objective, model=model, data=data, @@ -115,7 +122,5 @@ def fit_objective(pars: Array, model: PyTree, data: Array) -> float: # NLL tol=tol, other_settings=other_settings, ) - blank = jnp.zeros_like(jnp.asarray(model.config.suggested_init())) - blank += fit_res - poi_idx = model.config.poi_index - return blank.at[poi_idx].set(poi_condition) + res[poi_name] = poi_value + return res \ No newline at end of file diff --git a/tests/dummy_pyhf.py b/tests/dummy_pyhf.py index a553c49..32c9152 100644 --- a/tests/dummy_pyhf.py +++ b/tests/dummy_pyhf.py @@ -1,81 +1,100 @@ -from __future__ import annotations - -__all__ = ("example_model", "uncorrelated_background") - -from typing import Any, Iterable - -import jax +import jax.scipy as jsp +import equinox as eqx import jax.numpy as jnp -import pyhf -from equinox import Module as PyTree - - -class _Config(PyTree): - poi_index: int - npars: int - auxdata: jax.Array - - def __init__(self, aux) -> None: - self.poi_index = 0 - self.npars = 2 - self.auxdata = aux - - def suggested_init(self) -> jax.Array: - return jnp.asarray([1.0, 1.0]) - - def suggested_bounds(self) -> tuple[jax.Array, jax.Array]: - return jnp.asarray([[0.0, 0.0], [10.0, 10.0]]) - - -class Model(PyTree): - """Dummy class to mimic the functionality of `pyhf.Model`.""" - - sig: jax.Array - nominal: jax.Array - uncert: jax.Array - factor: jax.Array - config: _Config - - def __init__(self, spec: Iterable[Any]) -> None: - self.sig, self.nominal, self.uncert = spec - self.factor = (self.nominal / self.uncert) ** 2 - self.config = _Config(1.0 * self.factor) - - def expected_data(self, pars: jax.Array) -> jax.Array: - mu, gamma = pars - expected_main = jnp.asarray([gamma * self.nominal + mu * self.sig]) - return jnp.concatenate([expected_main, jnp.array([self.config.auxdata])]) +from jax import Array +import jax - # logpdf as the call method - def logpdf(self, pars: jax.Array, data: jax.Array) -> jax.Array: +jax.config.update("jax_enable_x64", True) + + +@jax.jit +def poisson_logpdf(n, lam): + return n * jnp.log(lam) - lam - jsp.special.gammaln(n + 1) + + +class Model(eqx.Module): + def logpdf(self, data: Array, pars: dict[str, Array] | Array) -> Array: + raise NotImplementedError + + def expected_data(self, pars: dict[str, Array] | Array) -> Array: + raise NotImplementedError + + +class Systematic(eqx.Module): + name: str = eqx.field(static=True) + constraint: Model + + +class PoissonConstraint(Model): + scaled_binwise_uncerts: Array + + def __init__(self, nominal_bkg: Array, binwise_uncerts: Array) -> None: + eqx.error_if( + 
nominal_bkg, + nominal_bkg.shape != binwise_uncerts.shape, + f"Nominal bkg shape {nominal_bkg.shape} does not match binwise uncertainty shape {binwise_uncerts.shape}" + ) + self.scaled_binwise_uncerts = binwise_uncerts / nominal_bkg + + def expected_data(self, gamma: Array) -> Array: + return gamma*self.scaled_binwise_uncerts**-2 + + def logpdf(self, auxdata, gamma): + eqx.error_if( + gamma, + gamma.shape != self.scaled_binwise_uncerts.shape, + f"Constrained param shape {gamma.shape} does not match number of bins {self.scaled_binwise_uncerts.shape}" + ) + return jnp.sum( + poisson_logpdf(auxdata, (gamma*self.scaled_binwise_uncerts**-2)), + axis=None + ) + +class UncorrelatedShape(Systematic): + def __init__(self, name: str, nominal_bkg: Array, binwise_uncerts: Array) -> None: + self.name = name + self.constraint = PoissonConstraint(nominal_bkg, binwise_uncerts) + + +class HEPDataLike(Model): + sig: Array + bkg: Array + db: Array + systematic: UncorrelatedShape + + def __init__(self, sig: Array, bkg: Array, db: Array) -> None: + self.sig = sig + self.bkg = bkg + self.db = db + self.systematic = UncorrelatedShape("shapesys", bkg, db) + + def expected_data(self, pars: dict[str, Array]) -> Array: + mu, gamma = pars["mu"], pars["shapesys"] + return mu * self.sig + gamma * self.bkg, self.systematic.constraint.expected_data(gamma) + + def logpdf(self, data: Array, pars: dict[str, Array]) -> Array: maindata, auxdata = data main, _ = self.expected_data(pars) - _, gamma = pars - main = pyhf.probability.Poisson(main).log_prob(maindata) - constraint = pyhf.probability.Poisson(gamma * self.factor).log_prob(auxdata) - # sum log probs over bins - return [jnp.sum(jnp.asarray([main + constraint]), axis=None)] - + main = jnp.sum(poisson_logpdf(maindata, main), axis=None) + constraint = self.systematic.constraint.logpdf(auxdata, pars["shapesys"]) + return main + constraint -def uncorrelated_background(s: jax.Array, b: jax.Array, db: jax.Array) -> Model: - """Dummy class to mimic the functionality of `pyhf.simplemodels.hepdata_like`.""" - return Model([s, b, db]) -def _calc_yields(x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def _calc_yields(x: jnp.ndarray, n_bins: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: s = 15 + x b = 45 - 2 * x db = 1 + 0.2 * x**2 - return [s], [b], [db] + return [s]*n_bins, [b]*n_bins, [db]*n_bins def example_model( - phi: jnp.ndarray, return_yields: bool = False + phi: jnp.ndarray, n_bins:int, return_yields: bool = False ) -> Model | tuple[Model, jnp.ndarray]: - s, b, db = yields = _calc_yields(phi) + s, b, db = yields = _calc_yields(phi, n_bins) - model = uncorrelated_background( - jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db]) + model = HEPDataLike( + jnp.asarray(s), jnp.asarray(b), jnp.asarray(db) ) if return_yields: diff --git a/tests/test_fit.py b/tests/test_fit.py index c95874d..3efb47e 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -72,15 +72,15 @@ def test_fixed_poi_fit_grad(): pyhf.set_backend("jax") def pipeline(x): - model = uncorrelated_background(x * 5.0, x * 20, x * 2) - lower, upper = model.config.suggested_bounds() + model = example_model(x) + pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} - return relaxed.mle.fixed_poi_fit( + relaxed.infer.hypotest( + 1., + data=model.expected_data(pars), model=model, - data=model.expected_data(jnp.array([0.0, 1.0])), - init_pars=model.config.suggested_init()[1:], - poi_condition=1.0, - bounds=(lower[1:], upper[1:]), + init_pars={"mu": jnp.array(1.0), "shapesys": 
jnp.array([1.0, 1.0])}, + bounds={"mu": (0, 10), "shapesys": (0, 10)}, ) jacrev(pipeline)(jnp.asarray(0.5)) From 8bbb3232ca9135f3a54da1da65f450833f81d1ad Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Thu, 3 Aug 2023 10:05:30 +0100 Subject: [PATCH 02/13] add some more --- src/relaxed/infer.py | 79 ++++++++++++++++++++++++-------------------- src/relaxed/mle.py | 4 +-- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/src/relaxed/infer.py b/src/relaxed/infer.py index dc7b1bf..d3d555f 100644 --- a/src/relaxed/infer.py +++ b/src/relaxed/infer.py @@ -24,7 +24,7 @@ def hypotest( data: Array, model: PyTree, init_pars: dict[str, ArrayLike], - bounds: dict[str, ArrayLike], + bounds: dict[str, ArrayLike] | None = None, poi_name: str = "mu", return_mle_pars: bool = False, test_stat: str = "qmu", @@ -40,11 +40,12 @@ def hypotest( data : Array The data to use for the hypothesis test. model : PyTree - The model to use for the hypothesis test. + The model to use for the hypothesis test. Has a `logpdf` method with signature + `logpdf(pars: dict[str, ArrayLike], data: Array) -> Array`. init_pars : dict[str, ArrayLike] The initial parameters to use for fits within the hypothesis test. - bounds : dict[str, ArrayLike] - The bounds to use for fits within the hypothesis test. + bounds : dict[str, ArrayLike] | None + (optional) The bounds to use on parameters for fits within the hypothesis test. poi_name : str The name of the parameter(s) of interest. return_mle_pars : bool @@ -78,20 +79,21 @@ def hypotest( @filter_jit -def qmu_test( +def _profile_likelihood_ratio( test_poi: float, data: Array, model: PyTree, init_pars: dict[str, ArrayLike], - bounds: dict[str, ArrayLike], + bounds: dict[str, ArrayLike] | None, poi_name: str, - return_mle_pars: bool = False, expected_pars: Array | None = None, - cls_method: bool = True, -) -> tuple[Array, Array] | Array: - # remove the poi from the init_pars +) -> tuple[Array, Array]: + # remove the poi from the init_pars -- dict-based logic! 
conditional_init = {k: v for k, v in init_pars.items() if k != poi_name} - conditional_bounds = {k: v for k, v in bounds.items() if k != poi_name} + if bounds is not None: + conditional_bounds = {k: v for k, v in bounds.items() if k != poi_name} + else: + conditional_bounds = None conditional_pars = fixed_poi_fit( data, model, poi_value=test_poi, poi_name=poi_name, init_pars=conditional_init, bounds=conditional_bounds ) @@ -99,21 +101,39 @@ def qmu_test( mle_pars = fit(data, model, init_pars=init_pars, bounds=bounds) else: mle_pars = expected_pars - profile_likelihood = -2 * ( + profile_likelihood_ratio = -2 * ( model.logpdf(pars=conditional_pars, data=data) - model.logpdf(pars=mle_pars, data=data) ) - poi_hat = mle_pars[poi_name] - qmu = jnp.where(poi_hat < test_poi, profile_likelihood, 0.0) + return profile_likelihood_ratio, mle_pars + - CLsb = 1 - jsp.stats.norm.cdf(jnp.sqrt(qmu), loc=0, scale=1) +@filter_jit +def qmu_test( + test_poi: float, + data: Array, + model: PyTree, + init_pars: dict[str, ArrayLike], + bounds: dict[str, ArrayLike], + poi_name: str, + return_mle_pars: bool = False, + expected_pars: Array | None = None, + cls_method: bool = True, +) -> tuple[Array, Array] | Array: + """Calculate expected CLs/p-values via qmu test.""" + profile_likelihood_ratio, mle_pars = _profile_likelihood_ratio( + test_poi, data, model, init_pars, bounds, poi_name, expected_pars + ) + poi_hat = mle_pars[poi_name] + qmu = jnp.where(poi_hat < test_poi, profile_likelihood_ratio, 0.0) + pmu = 1 - jsp.stats.norm.cdf(jnp.sqrt(qmu), loc=0, scale=1) if cls_method: - altval = 0.0 - CLb = 1 - jsp.stats.norm.cdf(altval, loc=0, scale=1) - CLs = CLsb / CLb + alternative_hypothesis = 0.0 # point alternative is bkg-only + power_of_test = 1 - jsp.stats.norm.cdf(alternative_hypothesis, loc=0, scale=1) + result = pmu / power_of_test # same as CLs = p_sb/(1-p_b) = CLs+b/CLb else: - CLs = CLsb - return (CLs, mle_pars) if return_mle_pars else CLs + result = pmu # this is just the unmodified p-value + return (result, mle_pars) if return_mle_pars else result @filter_jit @@ -127,22 +147,11 @@ def q0_test( return_mle_pars: bool = False, expected_pars: Array | None = None, ) -> tuple[Array, Array] | Array: - # remove the poi from the init_pars - conditional_init = {k: v for k, v in init_pars.items() if k != poi_name} - conditional_bounds = {k: v for k, v in bounds.items() if k != poi_name} - conditional_pars = fixed_poi_fit( - data, model, poi_value=test_poi, poi_name=poi_name, init_pars=conditional_init, bounds=conditional_bounds + """Calculate expected p-values via q0 test.""" + profile_likelihood_ratio, mle_pars = _profile_likelihood_ratio( + test_poi, data, model, init_pars, bounds, poi_name, expected_pars ) - if expected_pars is None: - mle_pars = fit(data, model, init_pars=init_pars, bounds=bounds) - else: - mle_pars = expected_pars - profile_likelihood = -2 * ( - model.logpdf(pars=conditional_pars, data=data) - model.logpdf(pars=mle_pars, data=data) - ) - poi_hat = mle_pars[poi_name] - q0 = jnp.where(poi_hat >= test_poi, profile_likelihood, 0.0) + q0 = jnp.where(poi_hat >= test_poi, profile_likelihood_ratio, 0.0) p0 = 1 - jsp.stats.norm.cdf(jnp.sqrt(q0)) - return (p0, mle_pars) if return_mle_pars else p0 diff --git a/src/relaxed/mle.py b/src/relaxed/mle.py index 94948d8..2f6173f 100644 --- a/src/relaxed/mle.py +++ b/src/relaxed/mle.py @@ -70,7 +70,7 @@ def fit( data: Array, model: PyTree, init_pars: dict[str, ArrayLike], - bounds: dict[str, Array], + bounds: dict[str, Array] | None = None, method: str = "LBFGSB", 
maxiter: int = 500, tol: float = 1e-6, @@ -99,7 +99,7 @@ def fixed_poi_fit( poi_value: float, poi_name: str, init_pars: dict[str, ArrayLike], - bounds: dict[str, Array], + bounds: dict[str, Array] | None = None, method: str = "LBFGSB", maxiter: int = 500, tol: float = 1e-6, From d06ae26d750e0c0a1135c3a8127422cd3c7a30eb Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Thu, 3 Aug 2023 13:19:56 +0100 Subject: [PATCH 03/13] make some changes to ops that support dicts --- src/relaxed/infer.py | 37 ++++++++++++++++++++----- src/relaxed/metrics.py | 10 ++----- src/relaxed/mle.py | 48 ++++++++++++++++++-------------- src/relaxed/ops.py | 20 ++++++++++---- tests/dummy_pyhf.py | 63 +++++++++++++++++++++++------------------- tests/test_fit.py | 63 +++++++++++++++++++++--------------------- tests/test_infer.py | 50 +++++++++++++++++++++------------ tests/test_metrics.py | 14 ++++++---- tests/test_ops.py | 38 ++++++++++++------------- 9 files changed, 200 insertions(+), 143 deletions(-) diff --git a/src/relaxed/infer.py b/src/relaxed/infer.py index d3d555f..4679aa7 100644 --- a/src/relaxed/infer.py +++ b/src/relaxed/infer.py @@ -4,12 +4,12 @@ __all__ = ("hypotest",) import logging -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any import jax.numpy as jnp +import jax.scipy as jsp from equinox import filter_jit from jax import Array -import jax.scipy as jsp from relaxed.mle import fit, fixed_poi_fit @@ -45,7 +45,7 @@ def hypotest( init_pars : dict[str, ArrayLike] The initial parameters to use for fits within the hypothesis test. bounds : dict[str, ArrayLike] | None - (optional) The bounds to use on parameters for fits within the hypothesis test. + (optional) The bounds to use on parameters for fits within the hypothesis test. poi_name : str The name of the parameter(s) of interest. 
return_mle_pars : bool @@ -66,13 +66,30 @@ def hypotest( """ if test_stat == "q" or test_stat == "qmu": return qmu_test( - test_poi, data, model, init_pars, bounds, poi_name, return_mle_pars, expected_pars, cls_method + test_poi, + data, + model, + init_pars, + bounds, + poi_name, + return_mle_pars, + expected_pars, + cls_method, ) if test_stat == "q0": logging.info( "test_poi automatically set to 0 for q0 test (bkg-only null hypothesis)" ) - return q0_test(0.0, data, model, init_pars, bounds, poi_name, return_mle_pars, expected_pars) + return q0_test( + 0.0, + data, + model, + init_pars, + bounds, + poi_name, + return_mle_pars, + expected_pars, + ) msg = f"Unknown test statistic: {test_stat}" raise ValueError(msg) @@ -95,14 +112,20 @@ def _profile_likelihood_ratio( else: conditional_bounds = None conditional_pars = fixed_poi_fit( - data, model, poi_value=test_poi, poi_name=poi_name, init_pars=conditional_init, bounds=conditional_bounds + data, + model, + poi_value=test_poi, + poi_name=poi_name, + init_pars=conditional_init, + bounds=conditional_bounds, ) if expected_pars is None: mle_pars = fit(data, model, init_pars=init_pars, bounds=bounds) else: mle_pars = expected_pars profile_likelihood_ratio = -2 * ( - model.logpdf(pars=conditional_pars, data=data) - model.logpdf(pars=mle_pars, data=data) + model.logpdf(pars=conditional_pars, data=data) + - model.logpdf(pars=mle_pars, data=data) ) return profile_likelihood_ratio, mle_pars diff --git a/src/relaxed/metrics.py b/src/relaxed/metrics.py index 02521f5..4b20c44 100644 --- a/src/relaxed/metrics.py +++ b/src/relaxed/metrics.py @@ -43,11 +43,7 @@ def _gaussian_logpdf( data: Array, cov: Array, ) -> Array: - return cast( - Array, jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov) - ).reshape( - 1, - ) + return cast(Array, jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov)) @partial( @@ -58,7 +54,7 @@ def _gaussian_logpdf( ) def gaussianity( model: PyTree, - bestfit_pars: Array, + bestfit_pars: dict[str, Array], data: Array, rng_key: Any, n_samples: int = 1000, @@ -80,7 +76,7 @@ def gaussianity( relative_nlls_model = jax.vmap( lambda pars, data: -( - model.logpdf(pars, data)[0] - model.logpdf(bestfit_pars, data)[0] + model.logpdf(pars, data) - model.logpdf(bestfit_pars, data) ), # scale origin to bestfit pars in_axes=(0, None), )(gaussian_parspace_samples, data) diff --git a/src/relaxed/mle.py b/src/relaxed/mle.py index 2f6173f..b6a08f7 100644 --- a/src/relaxed/mle.py +++ b/src/relaxed/mle.py @@ -3,44 +3,50 @@ __all__ = ("fit", "fixed_poi_fit") import inspect -from typing import TYPE_CHECKING, Any, Callable, cast, Sequence +from typing import TYPE_CHECKING, Any, Callable, cast -import jaxopt -from equinox import filter_jit import jax -import numpy as np import jax.numpy as jnp +import jaxopt +import numpy as np +from equinox import filter_jit from jax import Array - if TYPE_CHECKING: from jax.typing import ArrayLike + PyTree = Any -def _get_bounds(bounds: dict[str, ArrayLike], init_pars: dict[str, ArrayLike]) -> tuple[dict[str, ArrayLike], dict[str, ArrayLike]]: +def _parse_bounds( + bounds: dict[str, ArrayLike], init_pars: dict[str, ArrayLike] +) -> tuple[dict[str, ArrayLike], dict[str, ArrayLike]]: """Convert dict of bounds to a dict of lower and a dict of upper bounds.""" lower = {} upper = {} for k, v in bounds.items(): # Convert to array for easy manipulation - v = jnp.asarray(v) + array_v = jnp.asarray(v) # Check if v is 1D or 2D - if v.ndim == 1: - if isinstance(init_pars[k], (list, jax.Array, np.ndarray)) and 
init_pars[k].size > 1: # If the initial parameter is a list or array - lower[k] = jnp.array([v[0]] * len(init_pars[k])) - upper[k] = jnp.array([v[1]] * len(init_pars[k])) + if array_v.ndim == 1: + if ( + isinstance(init_pars[k], (list, jax.Array, np.ndarray)) + and init_pars[k].size > 1 + ): # If the initial parameter is a list or array + lower[k] = jnp.array([array_v[0]] * len(init_pars[k])) + upper[k] = jnp.array([array_v[1]] * len(init_pars[k])) else: # If the initial parameter is a single value - lower[k] = v[0] - upper[k] = v[1] + lower[k] = array_v[0] + upper[k] = array_v[1] else: - lower[k] = jnp.array([item[0] for item in v]) - upper[k] = jnp.array([item[1] for item in v]) + lower[k] = jnp.array([item[0] for item in array_v]) + upper[k] = jnp.array([item[1] for item in array_v]) return lower, upper + @filter_jit def _minimize( fit_objective: Callable[[Array], float], @@ -60,8 +66,9 @@ def _minimize( fun=fit_objective, implicit_diff=True, **other_settings ) if "bounds" in inspect.signature(minimizer.init_state).parameters: - lower, upper = _get_bounds(bounds, init_pars) - return minimizer.run(init_pars, bounds=(lower, upper), model=model, data=data)[0] + if bounds is not None: + bounds = _parse_bounds(bounds, init_pars) + return minimizer.run(init_pars, bounds=bounds, model=model, data=data)[0] return minimizer.run(init_pars, model=model, data=data)[0] @@ -105,8 +112,9 @@ def fixed_poi_fit( tol: float = 1e-6, other_settings: dict[str, float] | None = None, ) -> dict[str, Array]: - - def fit_objective(pars: dict[str, Array], model: PyTree, data: Array) -> float: # NLL + def fit_objective( + pars: dict[str, Array], model: PyTree, data: Array + ) -> float: # NLL """lhood_pars_to_optimize: either all pars, or just nuisance pars""" pars[poi_name] = poi_value return cast(float, -model.logpdf(data=data, pars=pars)) @@ -123,4 +131,4 @@ def fit_objective(pars: dict[str, Array], model: PyTree, data: Array) -> float: other_settings=other_settings, ) res[poi_name] = poi_value - return res \ No newline at end of file + return res diff --git a/src/relaxed/ops.py b/src/relaxed/ops.py index 477db27..32be3e5 100644 --- a/src/relaxed/ops.py +++ b/src/relaxed/ops.py @@ -94,7 +94,7 @@ def hist( @jax.jit -def fisher_info(model: Any, pars: Array, data: Array) -> Array: +def fisher_info(model: Any, pars: dict[str, Array], data: Array) -> Array: """Fisher information matrix for a model with a logpdf method. Parameters @@ -102,17 +102,27 @@ def fisher_info(model: Any, pars: Array, data: Array) -> Array: model : Any The model to compute the Fisher information matrix for. Needs to have a logpdf method (that returns list[float] for now). - pars : Array - The (MLE) parameters of the model. + pars : dict[str, Array] + The (MLE) parameters of the model, as a dict of arrays/floats. data : Array The data to compute the Fisher information matrix for. Returns ------- Array - Fisher information matrix. + Fisher information matrix of shape (num_pars, num_pars). + Order of columns is the same as the order of the parameters in pars. + Parameters with multiple dimensions are flattened into their own columns. 
""" - return jnp.linalg.inv(-jax.hessian(lambda p, d: model.logpdf(p, d)[0])(pars, data)) + + def lpdf(pars, data): # handle keyword arguments + return model.logpdf(pars=pars, data=data) + + num_pars = len(jax.tree_util.tree_flatten(pars)[0]) + hessian = jnp.array( + jax.tree_util.tree_flatten(jax.hessian(lpdf)(pars, data))[0] + ).reshape(num_pars, num_pars) + return jnp.linalg.inv(-hessian) @jax.jit diff --git a/tests/dummy_pyhf.py b/tests/dummy_pyhf.py index 32c9152..5e37c45 100644 --- a/tests/dummy_pyhf.py +++ b/tests/dummy_pyhf.py @@ -1,8 +1,10 @@ -import jax.scipy as jsp +from __future__ import annotations + import equinox as eqx +import jax import jax.numpy as jnp +import jax.scipy as jsp from jax import Array -import jax jax.config.update("jax_enable_x64", True) @@ -15,10 +17,10 @@ def poisson_logpdf(n, lam): class Model(eqx.Module): def logpdf(self, data: Array, pars: dict[str, Array] | Array) -> Array: raise NotImplementedError - + def expected_data(self, pars: dict[str, Array] | Array) -> Array: raise NotImplementedError - + class Systematic(eqx.Module): name: str = eqx.field(static=True) @@ -29,27 +31,28 @@ class PoissonConstraint(Model): scaled_binwise_uncerts: Array def __init__(self, nominal_bkg: Array, binwise_uncerts: Array) -> None: - eqx.error_if( - nominal_bkg, - nominal_bkg.shape != binwise_uncerts.shape, - f"Nominal bkg shape {nominal_bkg.shape} does not match binwise uncertainty shape {binwise_uncerts.shape}" - ) + if nominal_bkg.shape != binwise_uncerts.shape: + msg = f"Nominal bkg shape {nominal_bkg.shape} does not match binwise uncertainty shape {binwise_uncerts.shape}" + raise ValueError(msg) self.scaled_binwise_uncerts = binwise_uncerts / nominal_bkg - + def expected_data(self, gamma: Array) -> Array: - return gamma*self.scaled_binwise_uncerts**-2 - + return gamma * self.scaled_binwise_uncerts**-2 + def logpdf(self, auxdata, gamma): - eqx.error_if( - gamma, - gamma.shape != self.scaled_binwise_uncerts.shape, - f"Constrained param shape {gamma.shape} does not match number of bins {self.scaled_binwise_uncerts.shape}" - ) + if not isinstance(gamma, Array): + gamma = jnp.array(gamma) + if gamma.shape != self.scaled_binwise_uncerts.shape and not ( + gamma.shape == () and self.scaled_binwise_uncerts.shape == (1,) + ): + msg = f"Constrained param shape {gamma.shape} does not match number of bins {self.scaled_binwise_uncerts.shape}" + raise ValueError(msg) return jnp.sum( - poisson_logpdf(auxdata, (gamma*self.scaled_binwise_uncerts**-2)), - axis=None + poisson_logpdf(auxdata, (gamma * self.scaled_binwise_uncerts**-2)), + axis=None, ) - + + class UncorrelatedShape(Systematic): def __init__(self, name: str, nominal_bkg: Array, binwise_uncerts: Array) -> None: self.name = name @@ -70,8 +73,11 @@ def __init__(self, sig: Array, bkg: Array, db: Array) -> None: def expected_data(self, pars: dict[str, Array]) -> Array: mu, gamma = pars["mu"], pars["shapesys"] - return mu * self.sig + gamma * self.bkg, self.systematic.constraint.expected_data(gamma) - + return ( + mu * self.sig + gamma * self.bkg, + self.systematic.constraint.expected_data(gamma), + ) + def logpdf(self, data: Array, pars: dict[str, Array]) -> Array: maindata, auxdata = data main, _ = self.expected_data(pars) @@ -80,22 +86,21 @@ def logpdf(self, data: Array, pars: dict[str, Array]) -> Array: return main + constraint - -def _calc_yields(x: jnp.ndarray, n_bins: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: +def _calc_yields( + x: jnp.ndarray, n_bins: int +) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: s 
= 15 + x b = 45 - 2 * x db = 1 + 0.2 * x**2 - return [s]*n_bins, [b]*n_bins, [db]*n_bins + return [s] * n_bins, [b] * n_bins, [db] * n_bins def example_model( - phi: jnp.ndarray, n_bins:int, return_yields: bool = False + phi: jnp.ndarray, n_bins: int, return_yields: bool = False ) -> Model | tuple[Model, jnp.ndarray]: s, b, db = yields = _calc_yields(phi, n_bins) - model = HEPDataLike( - jnp.asarray(s), jnp.asarray(b), jnp.asarray(db) - ) + model = HEPDataLike(jnp.asarray(s), jnp.asarray(b), jnp.asarray(db)) if return_yields: return model, yields diff --git a/tests/test_fit.py b/tests/test_fit.py index 3efb47e..34cee50 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -4,37 +4,40 @@ import numpy as np import pyhf import pytest -from dummy_pyhf import example_model, uncorrelated_background -from jax import jacrev +from dummy_pyhf import example_model +from jax import config, jacrev, tree_util import relaxed +config.update("jax_enable_x64", True) + @pytest.mark.parametrize("phi", np.linspace(0.0, 10.0, 5)) def test_fit(phi): - pyhf.set_backend("jax") - analytic_pars = jnp.array([0.0, 1.0]) - model = example_model(phi) + analytic_pars = {"mu": 0.0, "shapesys": 1.0} + model = example_model(phi, n_bins=1) mle_pars = relaxed.mle.fit( model=model, data=model.expected_data(analytic_pars), - init_pars=model.config.suggested_init(), - bounds=model.config.suggested_bounds(), + init_pars={"mu": 1.0, "shapesys": 1.0}, + bounds={"mu": (-1, 10), "shapesys": (-1, 10)}, + ) + assert np.allclose( + tree_util.tree_flatten(mle_pars)[0], tree_util.tree_flatten(analytic_pars)[0] ) - assert np.allclose(mle_pars, analytic_pars, atol=0.05) def test_fit_grad(): - pyhf.set_backend("jax") - def pipeline(x): - analytic_pars = jnp.array([0.0, 1.0]) - model = example_model(x) - return relaxed.mle.fit( + model = example_model(x, n_bins=2) + pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} + + return relaxed.infer.hypotest( + 1.0, + data=model.expected_data(pars), model=model, - data=model.expected_data(analytic_pars), - init_pars=model.config.suggested_init(), - bounds=model.config.suggested_bounds(), + init_pars={"mu": jnp.array(1.0), "shapesys": jnp.array([1.0, 1.0])}, + bounds={"mu": (0, 10), "shapesys": (0, 10)}, ) jacrev(pipeline)(jnp.asarray(0.5)) @@ -42,22 +45,20 @@ def pipeline(x): @pytest.mark.parametrize("phi", np.linspace(0.0, 10.0, 5)) def test_fixed_poi_fit(phi): - pyhf.set_backend("jax") - analytic_pars = jnp.array([0.0, 1.0]) + pars = {"mu": 0.0, "shapesys": 1.0} + model, yields = example_model(phi, return_yields=True, n_bins=1) + init = {"shapesys": 1.0} - model, yields = example_model(phi, return_yields=True) - init = np.asarray(model.config.suggested_init()) - init = jnp.asarray(np.delete(init, model.config.poi_index)) - lower, upper = model.config.suggested_bounds() relaxed_mle = relaxed.mle.fixed_poi_fit( model=model, - data=model.expected_data(analytic_pars), + data=model.expected_data(pars), init_pars=init, - poi_condition=1.0, - bounds=(lower[1:], upper[1:]), + poi_value=1.0, + poi_name="mu", ) - + pyhf.set_backend("jax") m = pyhf.simplemodels.uncorrelated_background(*yields) + analytic_pars = jnp.array(m.config.suggested_init()).at[m.config.poi_index].set(0.0) pyhf_mle = pyhf.infer.mle.fixed_poi_fit( 1.0, @@ -65,18 +66,16 @@ def test_fixed_poi_fit(phi): m, ) - assert np.allclose(relaxed_mle, pyhf_mle, atol=0.05) + assert np.allclose(tree_util.tree_flatten(relaxed_mle)[0], pyhf_mle, atol=1e-4) def test_fixed_poi_fit_grad(): - pyhf.set_backend("jax") - def pipeline(x): - 
model = example_model(x) + model = example_model(x, n_bins=2) pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} - relaxed.infer.hypotest( - 1., + return relaxed.infer.hypotest( + 1.0, data=model.expected_data(pars), model=model, init_pars={"mu": jnp.array(1.0), "shapesys": jnp.array([1.0, 1.0])}, diff --git a/tests/test_infer.py b/tests/test_infer.py index 6c23f42..fb31cfd 100644 --- a/tests/test_infer.py +++ b/tests/test_infer.py @@ -5,7 +5,7 @@ import numpy as np import pyhf import pytest -from dummy_pyhf import example_model, uncorrelated_background +from dummy_pyhf import example_model from jax import jacrev import relaxed @@ -19,14 +19,21 @@ def test_hypotest_validity(phi, test_stat): pyhf.set_backend("jax") if test_stat == "q": analytic_pars = jnp.array([0.0, 1.0]) # bkg-only hypothesis + analytic_pars_dict = {"mu": 0.0, "shapesys": 1.0} elif test_stat == "q0": analytic_pars = jnp.array([1.0, 1.0]) # nominal sig+bkg hypothesis + analytic_pars_dict = {"mu": 1.0, "shapesys": 1.0} else: msg = f"Unknown test statistic: {test_stat}" raise ValueError(msg) - model, yields = example_model(phi, return_yields=True) + model, yields = example_model(phi, return_yields=True, n_bins=1) relaxed_cls = relaxed.infer.hypotest( - 1, model.expected_data(analytic_pars), model, test_stat=test_stat + test_poi=1, + poi_name="mu", + data=model.expected_data(analytic_pars_dict), + model=model, + init_pars={"mu": 1.0, "shapesys": 1.0}, + test_stat=test_stat, ) m = pyhf.simplemodels.uncorrelated_background(*yields) pyhf_cls = pyhf.infer.hypotest( @@ -35,7 +42,7 @@ def test_hypotest_validity(phi, test_stat): assert np.allclose( relaxed_cls, pyhf_cls, - ) # tested working without dummy_pyhf on a pyhf fork, but not main yet + ) @pytest.mark.parametrize("test_stat", ["q", "q0"]) @@ -43,18 +50,22 @@ def test_hypotest_expected(test_stat): pyhf.set_backend("jax") if test_stat == "q": analytic_pars = jnp.array([0.0, 1.0]) # bkg-only hypothesis + analytic_pars_dict = {"mu": 0.0, "shapesys": 1.0} elif test_stat == "q0": analytic_pars = jnp.array([1.0, 1.0]) # nominal sig+bkg hypothesis + analytic_pars_dict = {"mu": 1.0, "shapesys": 1.0} else: msg = f"Unknown test statistic: {test_stat}" raise ValueError(msg) - model, yields = example_model(5.0, return_yields=True) + model, yields = example_model(5.0, return_yields=True, n_bins=1) relaxed_cls = relaxed.infer.hypotest( - 1, - model.expected_data(analytic_pars), - model, + test_poi=1, + poi_name="mu", + data=model.expected_data(analytic_pars_dict), + model=model, + init_pars={"mu": 1.0, "shapesys": 1.0}, test_stat=test_stat, - expected_pars=analytic_pars, + expected_pars=analytic_pars_dict, ) m = pyhf.simplemodels.uncorrelated_background(*yields) pyhf_cls = pyhf.infer.hypotest( @@ -63,19 +74,20 @@ def test_hypotest_expected(test_stat): assert np.allclose( relaxed_cls, pyhf_cls, - ) # tested working without dummy_pyhf on a pyhf fork, but not main yet + ) @pytest.mark.parametrize("test_stat", ["q", "q0"]) @pytest.mark.parametrize("expected_pars", [True, False]) def test_hypotest_grad(test_stat, expected_pars): - pars = jnp.array([0.0, 1.0]) + pars = {"mu": 0.0, "shapesys": 1.0} expars = pars if expected_pars else None def pipeline(x): - model = uncorrelated_background(x * 5.0, x * 20, x * 2) + model = example_model(x * 5.0, n_bins=1) return relaxed.infer.hypotest( 1.0, + init_pars={"mu": 1.0, "shapesys": 1.0}, model=model, data=model.expected_data(pars), test_stat=test_stat, @@ -87,16 +99,17 @@ def pipeline(x): @pytest.mark.parametrize("expected_pars", [True, 
False]) def test_hypotest_grad_noCLs(expected_pars): - pars = jnp.array([0.0, 1.0]) + pars = {"mu": 0.0, "shapesys": 1.0} expars = pars if expected_pars else None def pipeline(x): - model = uncorrelated_background(x * 5.0, x * 20, x * 2) + model = example_model(x * 5.0, n_bins=1) return relaxed.infer.hypotest( 1.0, + init_pars={"mu": 1.0, "shapesys": 1.0}, model=model, data=model.expected_data(pars), - test_stat="q", + test_stat="qmu", expected_pars=expars, cls_method=False, ) @@ -105,11 +118,12 @@ def pipeline(x): def test_wrong_test_stat(): - model = example_model(0.0) + model = example_model(0.0, n_bins=1) with pytest.raises(ValueError, match="Unknown test statistic: q1"): relaxed.infer.hypotest( 1, - model.expected_data(jnp.array([0.0, 1.0])), - model, + data=model.expected_data({"mu": 0.0, "shapesys": 1.0}), + model=model, test_stat="q1", + init_pars={"mu": 1.0, "shapesys": 1.0}, ) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 23b3af8..55f32b8 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -4,8 +4,8 @@ import numpy as np import pyhf import pytest -from dummy_pyhf import example_model, uncorrelated_background -from jax import jacrev +from dummy_pyhf import example_model +from jax import jacrev, tree_map from jax.random import PRNGKey import relaxed @@ -23,7 +23,7 @@ def data(): def test_gaussianity(): pyhf.set_backend("jax") - m = uncorrelated_background( + m = pyhf.simplemodels.uncorrelated_background( jnp.array([5.0, 5.0]), jnp.array([10.0, 10.0]), jnp.array([0.1, 0.1]), @@ -35,10 +35,12 @@ def test_gaussianity(): def test_gaussianity_grad(): def pipeline(x): - model = example_model(5.0) - pars = model.config.suggested_init() + model = example_model(5.0, n_bins=2) + pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} data = model.expected_data(pars) - return relaxed.metrics.gaussianity(model, pars * x, data * x, PRNGKey(0)) + return relaxed.metrics.gaussianity( + model, tree_map(lambda a: a * x, pars), data * x, PRNGKey(0) + ) jacrev(pipeline)(4.0) # just check you can calc it w/o exception diff --git a/tests/test_ops.py b/tests/test_ops.py index 7b6c787..218be53 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -7,8 +7,8 @@ import numpy as np import pyhf import pytest -from dummy_pyhf import example_model, uncorrelated_background -from jax import jacrev, vmap +from dummy_pyhf import HEPDataLike, example_model +from jax import jacrev, tree_map, vmap from jax.random import PRNGKey, normal import relaxed @@ -104,8 +104,8 @@ def true_grad(mu, bins): def test_fisher_info(): pyhf.set_backend("jax") - model = example_model(5.0) - pars = model.config.suggested_init() + model = example_model(1.0, n_bins=1) + pars = {"mu": 1.0, "shapesys": 1.0} data = model.expected_data(pars) # just check that it doesn't crash relaxed.fisher_info(model, pars, data) @@ -128,37 +128,39 @@ def test_fisher_uncerts_validity(): # minuit fit uncerts mle_pars, mle_uncerts = fit_res[:, 0], fit_res[:, 1] - + mle_pars_dict = {"mu": mle_pars[0], "shapesys": mle_pars[1]} # uncertainties from autodiff hessian - dummy_m = uncorrelated_background( + dummy_m = HEPDataLike( jnp.array([5]), jnp.array([50]), jnp.array([5]), ) - relaxed_uncerts = relaxed.cramer_rao_uncert(dummy_m, mle_pars, data) + relaxed_uncerts = relaxed.cramer_rao_uncert(dummy_m, mle_pars_dict, data) assert np.allclose(mle_uncerts, relaxed_uncerts, rtol=0.05) def test_fisher_info_grad(): - pyhf.set_backend("jax") - def pipeline(x): - model = example_model(5.0) - pars = model.config.suggested_init() + 
model = example_model(5.0, n_bins=2) + pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} data = model.expected_data(pars) - return relaxed.fisher_info(model, pars * x, data * x) + return relaxed.metrics.gaussianity( + model, tree_map(lambda a: a * x, pars), data * x, PRNGKey(0) + ) - jacrev(pipeline)(4.0) # just check you can calc it w/o exception + jacrev(pipeline)(4.0) def test_fisher_uncert_grad(): pyhf.set_backend("jax") def pipeline(x): - model = example_model(5.0) - pars = model.config.suggested_init() + model = example_model(5.0, n_bins=2) + pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} data = model.expected_data(pars) - return relaxed.cramer_rao_uncert(model, pars * x, data * x) + return relaxed.cramer_rao_uncert( + model, tree_map(lambda a: a * x, pars), data * x + ) jacrev(pipeline)(4.0) # just check you can calc it w/o exception @@ -178,9 +180,7 @@ def test_cut_validity(big_sample, keep): @pytest.mark.parametrize("keep", ["above", "below"]) def test_cut_grad(keep): def pipeline(x): - model = example_model(5.0) - pars = model.config.suggested_init() - data = model.expected_data(pars) + data = jnp.array([1.0, 2.0, 3.0]) return relaxed.cut(data, x, keep=keep) jacrev(pipeline)(4.0) # just check you can calc it w/o exception From 20166a90de034c7a0aeaa66e07b78c9a5b1f983f Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Thu, 3 Aug 2023 15:04:21 +0100 Subject: [PATCH 04/13] add stuff --- src/relaxed/metrics.py | 23 ++++++++++------------- src/relaxed/ops.py | 13 +++++++------ tests/test_metrics.py | 19 ++++++++----------- tests/test_ops.py | 27 +++++++++++++++------------ 4 files changed, 40 insertions(+), 42 deletions(-) diff --git a/src/relaxed/metrics.py b/src/relaxed/metrics.py index 4b20c44..78dd639 100644 --- a/src/relaxed/metrics.py +++ b/src/relaxed/metrics.py @@ -2,13 +2,13 @@ __all__ = ("asimov_sig", "gaussianity") -from functools import partial from typing import TYPE_CHECKING, Any, cast +import equinox as eqx import jax import jax.numpy as jnp import jax.scipy as jsp -from jax import Array +from jax import Array, flatten_util from jax.random import multivariate_normal from relaxed.ops import fisher_info @@ -46,12 +46,7 @@ def _gaussian_logpdf( return cast(Array, jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov)) -@partial( - jax.jit, - static_argnames=[ - "n_samples", - ], -) +@eqx.filter_jit def gaussianity( model: PyTree, bestfit_pars: dict[str, Array], @@ -66,25 +61,27 @@ def gaussianity( # - centre the values wrt the best-fit vals to scale the differences cov_approx = jnp.linalg.inv(fisher_info(model, bestfit_pars, data)) - + flat_bestfit_pars, tree_structure = flatten_util.ravel_pytree(bestfit_pars) gaussian_parspace_samples = multivariate_normal( key=rng_key, - mean=bestfit_pars, + mean=flat_bestfit_pars, cov=cov_approx, shape=(n_samples,), ) + gaussian_parspace_samples = tree_structure(gaussian_parspace_samples) relative_nlls_model = jax.vmap( lambda pars, data: -( - model.logpdf(pars, data) - model.logpdf(bestfit_pars, data) + model.logpdf(pars=pars, data=data) + - model.logpdf(pars=bestfit_pars, data=data) ), # scale origin to bestfit pars in_axes=(0, None), )(gaussian_parspace_samples, data) relative_nlls_gaussian = jax.vmap( lambda pars, data: -( - _gaussian_logpdf(pars, data, cov_approx)[0] - - _gaussian_logpdf(bestfit_pars, data, cov_approx)[0] + _gaussian_logpdf(pars, data, cov_approx) + - _gaussian_logpdf(bestfit_pars, data, cov_approx) ), # data fixes the lhood shape in_axes=(0, None), 
)(gaussian_parspace_samples, bestfit_pars) diff --git a/src/relaxed/ops.py b/src/relaxed/ops.py index 32be3e5..8e059af 100644 --- a/src/relaxed/ops.py +++ b/src/relaxed/ops.py @@ -8,7 +8,7 @@ import jax import jax.numpy as jnp import jax.scipy as jsp -from jax import Array +from jax import Array, flatten_util @partial(jax.jit, static_argnames=["keep"]) @@ -118,10 +118,10 @@ def fisher_info(model: Any, pars: dict[str, Array], data: Array) -> Array: def lpdf(pars, data): # handle keyword arguments return model.logpdf(pars=pars, data=data) - num_pars = len(jax.tree_util.tree_flatten(pars)[0]) - hessian = jnp.array( - jax.tree_util.tree_flatten(jax.hessian(lpdf)(pars, data))[0] - ).reshape(num_pars, num_pars) + num_pars = len(flatten_util.ravel_pytree(pars)[0]) + hessian = flatten_util.ravel_pytree(jax.hessian(lpdf)(pars, data))[0].reshape( + num_pars, num_pars + ) return jnp.linalg.inv(-hessian) @@ -146,4 +146,5 @@ def cramer_rao_uncert(model: Any, pars: Array, data: Array) -> Array: Array Cramer-Rao uncertainty on the MLE parameters. """ - return jnp.sqrt(jnp.diag(fisher_info(model, pars, data))) + fisher = fisher_info(model, pars, data) + return jnp.sqrt(jnp.diag(fisher)) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 55f32b8..65ba638 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -2,7 +2,6 @@ import jax.numpy as jnp import numpy as np -import pyhf import pytest from dummy_pyhf import example_model from jax import jacrev, tree_map @@ -22,15 +21,10 @@ def data(): def test_gaussianity(): - pyhf.set_backend("jax") - m = pyhf.simplemodels.uncorrelated_background( - jnp.array([5.0, 5.0]), - jnp.array([10.0, 10.0]), - jnp.array([0.1, 0.1]), - ) - pars = jnp.asarray(m.config.suggested_init()) - data = jnp.asarray(m.expected_data(pars)) - relaxed.metrics.gaussianity(m, pars, data, PRNGKey(0)) + model = example_model(5.0, n_bins=2) + pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} + data = model.expected_data(pars) + relaxed.metrics.gaussianity(model, pars, data, PRNGKey(0)) def test_gaussianity_grad(): @@ -39,7 +33,10 @@ def pipeline(x): pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} data = model.expected_data(pars) return relaxed.metrics.gaussianity( - model, tree_map(lambda a: a * x, pars), data * x, PRNGKey(0) + model, + tree_map(lambda a: a * x, pars), + (data[0] * x, data[1] * x), + PRNGKey(0), ) jacrev(pipeline)(4.0) # just check you can calc it w/o exception diff --git a/tests/test_ops.py b/tests/test_ops.py index 218be53..37de2eb 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -111,9 +111,12 @@ def test_fisher_info(): relaxed.fisher_info(model, pars, data) -def test_fisher_uncerts_validity(): +@pytest.mark.parametrize("n_bins", [1, 1]) +def test_fisher_uncerts_validity(n_bins): pyhf.set_backend("jax", pyhf.optimize.minuit_optimizer(verbose=1)) - m = pyhf.simplemodels.uncorrelated_background([5], [50], [5]) + m = pyhf.simplemodels.uncorrelated_background( + [5] * n_bins, [50] * n_bins, [5] * n_bins + ) data = jnp.array([50.0, *m.config.auxdata]) fit_res = pyhf.infer.mle.fit( @@ -122,18 +125,18 @@ def test_fisher_uncerts_validity(): return_uncertainties=True, par_bounds=[ [-1, 10], - [-1, 10], + *[[-1, 10]] * n_bins, ], # fit @ boundary produces unstable uncerts ) # minuit fit uncerts mle_pars, mle_uncerts = fit_res[:, 0], fit_res[:, 1] - mle_pars_dict = {"mu": mle_pars[0], "shapesys": mle_pars[1]} + mle_pars_dict = {"mu": mle_pars[0], "shapesys": mle_pars[1:]} # uncertainties from autodiff hessian dummy_m = 
HEPDataLike( - jnp.array([5]), - jnp.array([50]), - jnp.array([5]), + jnp.array([5] * n_bins), + jnp.array([50] * n_bins), + jnp.array([5] * n_bins), ) relaxed_uncerts = relaxed.cramer_rao_uncert(dummy_m, mle_pars_dict, data) assert np.allclose(mle_uncerts, relaxed_uncerts, rtol=0.05) @@ -144,24 +147,24 @@ def pipeline(x): model = example_model(5.0, n_bins=2) pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} data = model.expected_data(pars) - return relaxed.metrics.gaussianity( - model, tree_map(lambda a: a * x, pars), data * x, PRNGKey(0) + return relaxed.fisher_info( + model, tree_map(lambda a: a * x, pars), tree_map(lambda a: a * x, data) ) + pipeline(4.0) # just check you can calc it w/o exception jacrev(pipeline)(4.0) def test_fisher_uncert_grad(): - pyhf.set_backend("jax") - def pipeline(x): model = example_model(5.0, n_bins=2) pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])} data = model.expected_data(pars) return relaxed.cramer_rao_uncert( - model, tree_map(lambda a: a * x, pars), data * x + model, tree_map(lambda a: a * x, pars), (data[0] * x, data[1] * x) ) + pipeline(4.0) # just check you can calc it w/o exceptio jacrev(pipeline)(4.0) # just check you can calc it w/o exception From b9aaed965aa0e7d832572591c8046b665333a1a8 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Thu, 3 Aug 2023 19:13:07 +0100 Subject: [PATCH 05/13] impose tree structure in right places --- src/relaxed/metrics.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/relaxed/metrics.py b/src/relaxed/metrics.py index 78dd639..2376146 100644 --- a/src/relaxed/metrics.py +++ b/src/relaxed/metrics.py @@ -43,7 +43,7 @@ def _gaussian_logpdf( data: Array, cov: Array, ) -> Array: - return cast(Array, jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov)) + return jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov) @eqx.filter_jit @@ -68,11 +68,10 @@ def gaussianity( cov=cov_approx, shape=(n_samples,), ) - gaussian_parspace_samples = tree_structure(gaussian_parspace_samples) relative_nlls_model = jax.vmap( lambda pars, data: -( - model.logpdf(pars=pars, data=data) + model.logpdf(pars=tree_structure(pars), data=data) - model.logpdf(pars=bestfit_pars, data=data) ), # scale origin to bestfit pars in_axes=(0, None), @@ -81,10 +80,10 @@ def gaussianity( relative_nlls_gaussian = jax.vmap( lambda pars, data: -( _gaussian_logpdf(pars, data, cov_approx) - - _gaussian_logpdf(bestfit_pars, data, cov_approx) + - _gaussian_logpdf(flat_bestfit_pars, data, cov_approx) ), # data fixes the lhood shape in_axes=(0, None), - )(gaussian_parspace_samples, bestfit_pars) + )(gaussian_parspace_samples, flat_bestfit_pars) diffs = relative_nlls_model - relative_nlls_gaussian return jnp.nanmean(diffs**2, axis=0) From 1bf2e0179fefd59ef2efcd9b11e584dca0c4143e Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Thu, 3 Aug 2023 19:38:01 +0100 Subject: [PATCH 06/13] sort typing --- src/relaxed/metrics.py | 5 +++-- src/relaxed/mle.py | 23 +++++++++++++---------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/relaxed/metrics.py b/src/relaxed/metrics.py index 2376146..ad87862 100644 --- a/src/relaxed/metrics.py +++ b/src/relaxed/metrics.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: PyTree = Any + from jax.typing import ArrayLike @jax.jit @@ -43,13 +44,13 @@ def _gaussian_logpdf( data: Array, cov: Array, ) -> Array: - return jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, cov) + return cast(Array, jsp.stats.multivariate_normal.logpdf(data, bestfit_pars, 
cov)) @eqx.filter_jit def gaussianity( model: PyTree, - bestfit_pars: dict[str, Array], + bestfit_pars: dict[str, ArrayLike], data: Array, rng_key: Any, n_samples: int = 1000, diff --git a/src/relaxed/mle.py b/src/relaxed/mle.py index b6a08f7..41427c7 100644 --- a/src/relaxed/mle.py +++ b/src/relaxed/mle.py @@ -3,7 +3,7 @@ __all__ = ("fit", "fixed_poi_fit") import inspect -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, Sized, cast import jax import jax.numpy as jnp @@ -20,7 +20,7 @@ def _parse_bounds( bounds: dict[str, ArrayLike], init_pars: dict[str, ArrayLike] -) -> tuple[dict[str, ArrayLike], dict[str, ArrayLike]]: +) -> tuple[dict[str, Array], dict[str, Array]]: """Convert dict of bounds to a dict of lower and a dict of upper bounds.""" lower = {} upper = {} @@ -33,10 +33,10 @@ def _parse_bounds( if array_v.ndim == 1: if ( isinstance(init_pars[k], (list, jax.Array, np.ndarray)) - and init_pars[k].size > 1 + and jnp.array(init_pars[k]).size > 1 ): # If the initial parameter is a list or array - lower[k] = jnp.array([array_v[0]] * len(init_pars[k])) - upper[k] = jnp.array([array_v[1]] * len(init_pars[k])) + lower[k] = jnp.array([array_v[0]] * len(cast(Sized, init_pars[k]))) + upper[k] = jnp.array([array_v[1]] * len(cast(Sized, init_pars[k]))) else: # If the initial parameter is a single value lower[k] = array_v[0] upper[k] = array_v[1] @@ -53,7 +53,7 @@ def _minimize( model: PyTree, data: Array, init_pars: dict[str, ArrayLike], - bounds: dict[str, ArrayLike], + bounds: dict[str, ArrayLike] | None, method: str = "LBFGSB", maxiter: int = 500, tol: float = 1e-6, @@ -67,8 +67,11 @@ def _minimize( ) if "bounds" in inspect.signature(minimizer.init_state).parameters: if bounds is not None: - bounds = _parse_bounds(bounds, init_pars) - return minimizer.run(init_pars, bounds=bounds, model=model, data=data)[0] + lower, upper = _parse_bounds(bounds, init_pars) + return minimizer.run( + init_pars, bounds=(lower, upper), model=model, data=data + )[0] + return minimizer.run(init_pars, bounds=None, model=model, data=data)[0] return minimizer.run(init_pars, model=model, data=data)[0] @@ -103,7 +106,7 @@ def fit_objective(pars: Array, model: PyTree, data: Array) -> float: def fixed_poi_fit( data: Array, model: PyTree, - poi_value: float, + poi_value: ArrayLike, poi_name: str, init_pars: dict[str, ArrayLike], bounds: dict[str, Array] | None = None, @@ -113,7 +116,7 @@ def fixed_poi_fit( other_settings: dict[str, float] | None = None, ) -> dict[str, Array]: def fit_objective( - pars: dict[str, Array], model: PyTree, data: Array + pars: dict[str, ArrayLike], model: PyTree, data: Array ) -> float: # NLL """lhood_pars_to_optimize: either all pars, or just nuisance pars""" pars[poi_name] = poi_value From d73e8ce8f2c3b53d6c783388c1bc8cc12ca50846 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Fri, 4 Aug 2023 08:45:45 +0100 Subject: [PATCH 07/13] update comments --- src/relaxed/infer.py | 2 +- src/relaxed/mle.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/relaxed/infer.py b/src/relaxed/infer.py index 4679aa7..577238c 100644 --- a/src/relaxed/infer.py +++ b/src/relaxed/infer.py @@ -151,7 +151,7 @@ def qmu_test( qmu = jnp.where(poi_hat < test_poi, profile_likelihood_ratio, 0.0) pmu = 1 - jsp.stats.norm.cdf(jnp.sqrt(qmu), loc=0, scale=1) if cls_method: - alternative_hypothesis = 0.0 # point alternative is bkg-only + alternative_hypothesis = 0.0 power_of_test = 1 - jsp.stats.norm.cdf(alternative_hypothesis, loc=0, scale=1) 
From d73e8ce8f2c3b53d6c783388c1bc8cc12ca50846 Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Fri, 4 Aug 2023 08:45:45 +0100
Subject: [PATCH 07/13] update comments

---
 src/relaxed/infer.py | 2 +-
 src/relaxed/mle.py   | 7 ++-----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/relaxed/infer.py b/src/relaxed/infer.py
index 4679aa7..577238c 100644
--- a/src/relaxed/infer.py
+++ b/src/relaxed/infer.py
@@ -151,7 +151,7 @@ def qmu_test(
     qmu = jnp.where(poi_hat < test_poi, profile_likelihood_ratio, 0.0)
     pmu = 1 - jsp.stats.norm.cdf(jnp.sqrt(qmu), loc=0, scale=1)
     if cls_method:
-        alternative_hypothesis = 0.0  # point alternative is bkg-only
+        alternative_hypothesis = 0.0
         power_of_test = 1 - jsp.stats.norm.cdf(alternative_hypothesis, loc=0, scale=1)
         result = pmu / power_of_test  # same as CLs = p_sb/(1-p_b) = CLs+b/CLb
     else:
diff --git a/src/relaxed/mle.py b/src/relaxed/mle.py
index 41427c7..2dac622 100644
--- a/src/relaxed/mle.py
+++ b/src/relaxed/mle.py
@@ -21,7 +21,7 @@ def _parse_bounds(
     bounds: dict[str, ArrayLike], init_pars: dict[str, ArrayLike]
 ) -> tuple[dict[str, Array], dict[str, Array]]:
-    """Convert dict of bounds to a dict of lower and a dict of upper bounds."""
+    """Convert dict of bounds to a dict of lower bounds and a dict of upper bounds."""
     lower = {}
     upper = {}
 
@@ -115,10 +115,7 @@ def fixed_poi_fit(
     tol: float = 1e-6,
     other_settings: dict[str, float] | None = None,
 ) -> dict[str, Array]:
-    def fit_objective(
-        pars: dict[str, ArrayLike], model: PyTree, data: Array
-    ) -> float:  # NLL
-        """lhood_pars_to_optimize: either all pars, or just nuisance pars"""
+    def fit_objective(pars: dict[str, ArrayLike], model: PyTree, data: Array) -> float:
         pars[poi_name] = poi_value
         return cast(float, -model.logpdf(data=data, pars=pars))
 

From 765048b1187f46dfbaa71a3dffec6f21d71f6d25 Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Fri, 4 Aug 2023 09:01:21 +0100
Subject: [PATCH 08/13] update deps

---
 pyproject.toml | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index f503297..a91359f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,9 +30,8 @@ classifiers = [
 ]
 dynamic = ["version"]
 dependencies = [
-    "equinox",
-    "jaxopt>=0.7",  # LBFGSB
-    "optax>=0.1.2",  # deprecated jax.tree_multimap
+    "equinox>=0.10.6",  # eqx.field
+    "jaxopt>=0.7",  # LBFGSB
     "typing_extensions >=4.6; python_version<'3.11'",
 ]
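A note on the infer.py hunk in PATCH 07 above: it only drops a stale comment, but the CLs arithmetic it touches is worth spelling out once. A standalone sketch with a made-up test statistic value (illustrative only, not part of the patch):

    import jax.numpy as jnp
    import jax.scipy as jsp

    qmu = jnp.array(2.0)  # hypothetical observed value of the test statistic
    pmu = 1 - jsp.stats.norm.cdf(jnp.sqrt(qmu))  # p_sb, i.e. CLs+b
    clb = 1 - jsp.stats.norm.cdf(0.0)            # CLb = 0.5 at the point alternative
    cls = pmu / clb                              # CLs = CLs+b / CLb, as in qmu_test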
From bc59b0ad29506790f6c52086266e2517bee1a43e Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Fri, 4 Aug 2023 10:49:35 +0100
Subject: [PATCH 09/13] fix problem with uncerts

---
 pyproject.toml     |  2 +-
 src/relaxed/ops.py | 23 ++++++++++++++---------
 tests/test_ops.py  | 33 +++++++++++++++++++++++--------
 3 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index a91359f..c78cbef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,7 @@
 test = [
     "pytest >=6",
     "pytest-cov >=3",
-    "pyhf[jax] >=0.7.0",
+    "pyhf[jax]",  # should be pyhf[jax] >=0.7.1, but not released yet
     "iminuit",
 ]
 dev = [
diff --git a/src/relaxed/ops.py b/src/relaxed/ops.py
index 8e059af..9dec23a 100644
--- a/src/relaxed/ops.py
+++ b/src/relaxed/ops.py
@@ -5,6 +5,7 @@
 from functools import partial
 from typing import Any
 
+import equinox as eqx
 import jax
 import jax.numpy as jnp
 import jax.scipy as jsp
@@ -115,18 +116,18 @@ def fisher_info(model: Any, pars: dict[str, Array], data: Array) -> Array:
     Parameters with multiple dimensions are flattened into their own columns.
     """
+    flat_pars, tree_structure = flatten_util.ravel_pytree(pars)
+
     def lpdf(pars, data):  # handle keyword arguments
-        return model.logpdf(pars=pars, data=data)
+        return model.logpdf(pars=tree_structure(pars), data=data)
 
-    num_pars = len(flatten_util.ravel_pytree(pars)[0])
-    hessian = flatten_util.ravel_pytree(jax.hessian(lpdf)(pars, data))[0].reshape(
-        num_pars, num_pars
-    )
-    return jnp.linalg.inv(-hessian)
+    return jnp.linalg.inv(-jax.hessian(lpdf)(flat_pars, data))
 
 
-@jax.jit
-def cramer_rao_uncert(model: Any, pars: Array, data: Array) -> Array:
+@eqx.filter_jit
+def cramer_rao_uncert(
+    model: Any, pars: dict[str, Array], data: Array, return_tree=True
+) -> Array:
     """Approximate uncertainties on MLE parameters for a model with a logpdf method.
 
     Defined as the square root of the diagonal of the Fisher information matrix,
     valid via the Cramer-Rao lower bound.
@@ -147,4 +148,8 @@ def cramer_rao_uncert(
     Cramer-Rao uncertainty on the MLE parameters.
     """
     fisher = fisher_info(model, pars, data)
-    return jnp.sqrt(jnp.diag(fisher))
+    uncert = jnp.sqrt(jnp.diag(fisher))
+    if return_tree:
+        _, tree_structure = flatten_util.ravel_pytree(pars)
+        return tree_structure(uncert)
+    return uncert
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 37de2eb..957161a 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -111,13 +111,13 @@ def test_fisher_info():
     relaxed.fisher_info(model, pars, data)
 
 
-@pytest.mark.parametrize("n_bins", [1, 1])
+@pytest.mark.parametrize("n_bins", [1, 2, 10])
 def test_fisher_uncerts_validity(n_bins):
     pyhf.set_backend("jax", pyhf.optimize.minuit_optimizer(verbose=1))
     m = pyhf.simplemodels.uncorrelated_background(
         [5] * n_bins, [50] * n_bins, [5] * n_bins
     )
-    data = jnp.array([50.0, *m.config.auxdata])
+    data = jnp.array([*[50.0] * n_bins, *m.config.auxdata])
 
     fit_res = pyhf.infer.mle.fit(
         data,
@@ -138,8 +138,21 @@ def test_fisher_uncerts_validity(n_bins):
         jnp.array([50] * n_bins),
         jnp.array([5] * n_bins),
     )
-    relaxed_uncerts = relaxed.cramer_rao_uncert(dummy_m, mle_pars_dict, data)
-    assert np.allclose(mle_uncerts, relaxed_uncerts, rtol=0.05)
+    data_rlx = jnp.array([50.0] * n_bins), jnp.array(m.config.auxdata)
+
+    fisher_rlx = relaxed.fisher_info(dummy_m, mle_pars_dict, data_rlx)
+    fisher_pyhf = jnp.linalg.inv(
+        -jax.hessian(lambda p, d: m.logpdf(p, d)[0])(mle_pars, data)
+    )
+
+    # compare
+    assert np.allclose(fisher_rlx, fisher_pyhf)  # exact match for fisher info
+    relaxed_uncerts = relaxed.cramer_rao_uncert(
+        dummy_m, mle_pars_dict, data_rlx, return_tree=False
+    )
+    assert np.allclose(
+        mle_uncerts, relaxed_uncerts, rtol=5e-2
+    )  # within 5%, don't expect exact match
 
 
 def test_fisher_info_grad():
@@ -152,19 +165,23 @@ def pipeline(x):
         )
 
     pipeline(4.0)  # just check you can calc it w/o exception
-    jacrev(pipeline)(4.0)
+    jacrev(pipeline)(4.0)  # just check you can calc it w/o exception
 
 
-def test_fisher_uncert_grad():
+@pytest.mark.parametrize("return_tree", [True, False])
+def test_fisher_uncert_grad(return_tree):
     def pipeline(x):
         model = example_model(5.0, n_bins=2)
         pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])}
         data = model.expected_data(pars)
         return relaxed.cramer_rao_uncert(
-            model, tree_map(lambda a: a * x, pars), (data[0] * x, data[1] * x)
+            model,
+            tree_map(lambda a: a * x, pars),
+            (data[0] * x, data[1] * x),
+            return_tree=return_tree,
         )
 
-    pipeline(4.0)  # just check you can calc it w/o exceptio
+    pipeline(4.0)  # just check you can calc it w/o exception
     jacrev(pipeline)(4.0)  # just check you can calc it w/o exception

From 3314b9acafb316358afcbfe1e0f52a1c0b92a5e3 Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Fri, 4 Aug 2023 10:50:04 +0100
Subject: [PATCH 10/13] verbosity

---
 tests/test_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_ops.py b/tests/test_ops.py
index 957161a..3f29a51 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -152,7 +152,7 @@ def test_fisher_uncerts_validity(n_bins):
     )
     assert np.allclose(
         mle_uncerts, relaxed_uncerts, rtol=5e-2
-    )  # within 5%, don't expect exact match
+    )  # within 5%, don't expect exact match with minuit
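For context on PATCH 09: fisher_info now flattens the parameter pytree with ravel_pytree before taking the Hessian, and cramer_rao_uncert can hand the uncertainties back in the original tree structure. A hedged usage sketch, assuming `model` and `data` are set up as in the tests above (they are not defined here):

    import jax.numpy as jnp
    import relaxed

    pars = {"mu": jnp.array(0.0), "shapesys": jnp.array([1.0, 1.0])}
    # square matrix with one row/column per flattened parameter:
    fisher = relaxed.fisher_info(model, pars, data)
    # uncertainties shaped like `pars` (return_tree=True is the new default):
    uncerts = relaxed.cramer_rao_uncert(model, pars, data)
    # ...or as a flat vector matching the Fisher matrix ordering:
    flat = relaxed.cramer_rao_uncert(model, pars, data, return_tree=False)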
From 96fc7fcb05ec0af26841486e3cf0c963a4c7f03f Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Fri, 4 Aug 2023 10:52:01 +0100
Subject: [PATCH 11/13] fix ci

---
 .github/workflows/ci.yml | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 45bbf85..6d042ff 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -41,7 +41,7 @@ jobs:
       strategy:
         fail-fast: false
         matrix:
-          python-version: ["3.8", "3.11"]
+          python-version: ["3.9", "3.11"]
           runs-on: [ubuntu-latest]
 
       steps:
@@ -57,10 +57,15 @@ jobs:
       - name: Install package
         run: python -m pip install .[test]
 
-      - name: Install jaxopt from source (temp)
+      - name: Install jaxopt from source (temp -- my lbfgs fix)
         run:
           python -m pip install --upgrade --force-reinstall
           git+https://github.com/google/jaxopt.git
 
+      - name: Install pyhf from source (temp -- jax devicearray not found fix)
+        run:
+          python -m pip install --upgrade --force-reinstall
+          git+https://github.com/scikit-hep/pyhf.git
+
       - name: Test package
         run: >-
           pytest -ra --cov --cov-report=xml --cov-report=term --durations=20

From bf788eec8638439e603dd78084b2c90a0bc3c5d4 Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Fri, 4 Aug 2023 10:57:18 +0100
Subject: [PATCH 12/13] update python version to drop 3.8

---
 pyproject.toml | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c78cbef..a5f6fcd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ authors = [
 ]
 description = "Differentiable versions of common HEP operations."
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 classifiers = [
     "Development Status :: 1 - Planning",
     "Intended Audience :: Science/Research",
@@ -20,7 +20,6 @@ classifiers = [
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3 :: Only",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
@@ -93,7 +92,7 @@ port.exclude_lines = [
 
 [tool.mypy]
 files = ["src", "tests"]
-python_version = "3.8"
+python_version = "3.9"
 warn_unused_configs = true
 strict = true
 show_error_codes = true

From 17a4b289d452287a99848bfa433c4f46c100482b Mon Sep 17 00:00:00 2001
From: Nathan Simpson
Date: Fri, 4 Aug 2023 11:14:01 +0100
Subject: [PATCH 13/13] lint

---
 src/relaxed/mle.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/relaxed/mle.py b/src/relaxed/mle.py
index 2dac622..68d8a72 100644
--- a/src/relaxed/mle.py
+++ b/src/relaxed/mle.py
@@ -3,7 +3,8 @@
 __all__ = ("fit", "fixed_poi_fit")
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Sized, cast
+from collections.abc import Sized
+from typing import TYPE_CHECKING, Any, Callable, cast
 
 import jax
 import jax.numpy as jnp
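A closing note on PATCH 13: under Python 3.9+ (the new floor after PATCH 12), typing.Sized is a deprecated alias of collections.abc.Sized, which is why the lint pass moves the import. A minimal sketch of the pattern the cast supports (variable names are illustrative):

    from collections.abc import Sized
    from typing import cast

    bound_init = [0.0, 10.0]
    # cast only tells the type checker that len() is valid here;
    # it has no runtime effect.
    n = len(cast(Sized, bound_init))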