From d474a4442d3d2a54a76edad5bd76c77a508540c3 Mon Sep 17 00:00:00 2001 From: Nathan Simpson Date: Fri, 28 Oct 2022 09:59:59 +0200 Subject: [PATCH] add new kwarg for CLsb --- src/relaxed/infer.py | 17 ++++++++++++----- tests/test_infer.py | 24 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/relaxed/infer.py b/src/relaxed/infer.py index 55434e3..8fc1acd 100644 --- a/src/relaxed/infer.py +++ b/src/relaxed/infer.py @@ -22,6 +22,7 @@ def hypotest( return_mle_pars: bool = False, test_stat: str = "q", expected_pars: Array | None = None, + cls_method: bool = True, ) -> tuple[Array, Array] | Array: """Calculate expected CLs/p-values via hypothesis tests. @@ -53,7 +54,9 @@ def hypotest( The MLE parameters, if `return_mle_pars` is True. """ if test_stat == "q": - return qmu_test(test_poi, data, model, lr, return_mle_pars, expected_pars) + return qmu_test( + test_poi, data, model, lr, return_mle_pars, expected_pars, cls_method + ) elif test_stat == "q0": logging.info( "test_poi automatically set to 0 for q0 test (bkg-only null hypothesis)" @@ -64,7 +67,7 @@ def hypotest( @partial( - jit, static_argnames=["model", "return_mle_pars"] + jit, static_argnames=["model", "return_mle_pars", "cls_method"] ) # can remove model eventually def qmu_test( test_poi: float, @@ -73,6 +76,7 @@ def qmu_test( lr: float, return_mle_pars: bool = False, expected_pars: Array | None = None, + cls_method: bool = True, ) -> tuple[Array, Array] | Array: # hard-code 1 as inits for now # TODO: need to parse different inits for constrained and global fits @@ -93,9 +97,12 @@ def qmu_test( qmu = jnp.where(poi_hat < test_poi, profile_likelihood, 0.0) CLsb = 1 - pyhf.tensorlib.normal_cdf(jnp.sqrt(qmu)) - altval = 0.0 - CLb = 1 - pyhf.tensorlib.normal_cdf(altval) - CLs = CLsb / CLb + if cls_method: + altval = 0.0 + CLb = 1 - pyhf.tensorlib.normal_cdf(altval) + CLs = CLsb / CLb + else: + CLs = CLsb return (CLs, mle_pars) if return_mle_pars else CLs diff --git a/tests/test_infer.py b/tests/test_infer.py index afeeda6..9052481 100644 --- a/tests/test_infer.py +++ b/tests/test_infer.py @@ -86,6 +86,30 @@ def pipeline(x): jacrev(pipeline)(jnp.asarray(0.5)) +@pytest.mark.parametrize("expected_pars", [True, False]) +def test_hypotest_grad_noCLs(expected_pars): + pars = jnp.array([0.0, 1.0]) + if expected_pars: + expars = pars + else: + expars = None + + def pipeline(x): + model = uncorrelated_background(x * 5.0, x * 20, x * 2) + expected_cls = relaxed.infer.hypotest( + 1.0, + model=model, + data=model.expected_data(pars), + lr=1e-2, + test_stat="q", + expected_pars=expars, + cls_method=False, + ) + return expected_cls + + jacrev(pipeline)(jnp.asarray(0.5)) + + def test_wrong_test_stat(): with pytest.raises(ValueError): model = example_model(0.0)