Skip to content

Commit

Permalink
add new kwarg for CLsb
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathan Simpson committed Oct 28, 2022
1 parent 4c5a148 commit d474a44
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/relaxed/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def hypotest(
return_mle_pars: bool = False,
test_stat: str = "q",
expected_pars: Array | None = None,
cls_method: bool = True,
) -> tuple[Array, Array] | Array:
"""Calculate expected CLs/p-values via hypothesis tests.
Expand Down Expand Up @@ -53,7 +54,9 @@ def hypotest(
The MLE parameters, if `return_mle_pars` is True.
"""
if test_stat == "q":
return qmu_test(test_poi, data, model, lr, return_mle_pars, expected_pars)
return qmu_test(
test_poi, data, model, lr, return_mle_pars, expected_pars, cls_method
)
elif test_stat == "q0":
logging.info(
"test_poi automatically set to 0 for q0 test (bkg-only null hypothesis)"
Expand All @@ -64,7 +67,7 @@ def hypotest(


@partial(
jit, static_argnames=["model", "return_mle_pars"]
jit, static_argnames=["model", "return_mle_pars", "cls_method"]
) # can remove model eventually
def qmu_test(
test_poi: float,
Expand All @@ -73,6 +76,7 @@ def qmu_test(
lr: float,
return_mle_pars: bool = False,
expected_pars: Array | None = None,
cls_method: bool = True,
) -> tuple[Array, Array] | Array:
# hard-code 1 as inits for now
# TODO: need to parse different inits for constrained and global fits
Expand All @@ -93,9 +97,12 @@ def qmu_test(
qmu = jnp.where(poi_hat < test_poi, profile_likelihood, 0.0)

CLsb = 1 - pyhf.tensorlib.normal_cdf(jnp.sqrt(qmu))
altval = 0.0
CLb = 1 - pyhf.tensorlib.normal_cdf(altval)
CLs = CLsb / CLb
if cls_method:
altval = 0.0
CLb = 1 - pyhf.tensorlib.normal_cdf(altval)
CLs = CLsb / CLb
else:
CLs = CLsb
return (CLs, mle_pars) if return_mle_pars else CLs


Expand Down
24 changes: 24 additions & 0 deletions tests/test_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,30 @@ def pipeline(x):
jacrev(pipeline)(jnp.asarray(0.5))


@pytest.mark.parametrize("expected_pars", [True, False])
def test_hypotest_grad_noCLs(expected_pars):
pars = jnp.array([0.0, 1.0])
if expected_pars:
expars = pars
else:
expars = None

def pipeline(x):
model = uncorrelated_background(x * 5.0, x * 20, x * 2)
expected_cls = relaxed.infer.hypotest(
1.0,
model=model,
data=model.expected_data(pars),
lr=1e-2,
test_stat="q",
expected_pars=expars,
cls_method=False,
)
return expected_cls

jacrev(pipeline)(jnp.asarray(0.5))


def test_wrong_test_stat():
with pytest.raises(ValueError):
model = example_model(0.0)
Expand Down

0 comments on commit d474a44

Please sign in to comment.