Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Catch invalid model specification with duplicate channels, samples and modifiers. #2396

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/pyhf/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import Any, Sequence

from pyhf import exceptions
from pyhf.typing import Channel

log = logging.getLogger(__name__)
Expand All @@ -21,6 +22,10 @@ class _ChannelSummaryMixin:
def __init__(self, *args: Any, **kwargs: Sequence[Channel]):
channels = kwargs.pop('channels')
super().__init__(*args, **kwargs)

# check for duplicates
self._check_for_duplicates(channels)

self._channels: list[str] = []
self._samples: list[str] = []
self._modifiers: list[tuple[str, str]] = []
Expand Down Expand Up @@ -89,3 +94,45 @@ def channel_slices(self) -> dict[str, slice]:
Dictionary mapping channel name to the bin slices in the model.
"""
return self._channel_slices

def _check_for_duplicates(self, channels: Sequence[Channel]) -> None:
"""
Check for duplicate channels.
Check for duplicate samples within each channel.
Check for duplicate modifiers within each sample.
"""
channel_names = [channel['name'] for channel in channels]
if len(channel_names) != len(set(channel_names)):
duplicates = sorted(
set([f"'{x}'" for x in channel_names if channel_names.count(x) > 1])
)
raise exceptions.InvalidModel(
"Duplicate channels "
+ ", ".join(duplicates)
+ " found in the model. Rename one of them."
)
for channel in channels:
sample_names = [samples['name'] for samples in channel['samples']]
if len(sample_names) != len(set(sample_names)):
duplicates = sorted(
set([f"'{x}'" for x in sample_names if sample_names.count(x) > 1])
)
raise exceptions.InvalidModel(
"Duplicate samples "
+ ", ".join(duplicates)
+ f" found in the channel '{channel['name']}'. Rename one of them."
)
for sample in channel['samples']:
modifiers = [
(modifier['name'], modifier['type'])
for modifier in sample['modifiers']
]
if len(modifiers) != len(set(modifiers)):
duplicates = sorted(
set([f"'{x[0]}'" for x in modifiers if modifiers.count(x) > 1])
)
raise exceptions.InvalidModel(
"Duplicate modifiers "
+ ", ".join(duplicates)
+ f" of the same type found in channel '{channel['name']}' and sample '{sample['name']}'. Rename one of them."
)
15 changes: 15 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,18 @@ def test_schema_tensor_type_disallowed(mocker, backend):
}
with pytest.raises(pyhf.exceptions.InvalidSpecification):
pyhf.schema.validate(spec, "model.json")


@pytest.mark.parametrize(
    'model_file',
    [
        'model_duplicate_channels.json',
        'model_duplicate_samples.json',
        'model_duplicate_modifiers.json',
    ],
)
def test_schema_catch_duplicates(datadir, model_file):
    """Constructing a model from a spec with duplicated entries raises InvalidModel."""
    spec_path = datadir.joinpath(model_file)
    model_spec = json.loads(spec_path.read_text(encoding="utf-8"))

    with pytest.raises(pyhf.exceptions.InvalidModel):
        pyhf.Model(model_spec)
50 changes: 50 additions & 0 deletions tests/test_schema/model_duplicate_channels.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"channels": [
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
},
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
}
]
}
28 changes: 28 additions & 0 deletions tests/test_schema/model_duplicate_modifiers.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"channels": [
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null},
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
}
]
}
34 changes: 34 additions & 0 deletions tests/test_schema/model_duplicate_samples.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"channels": [
{
"name": "singlechannel",
"samples": [
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "signal",
"data": [10],
"modifiers": [
{"name": "mu", "type": "normfactor", "data": null}
]
},
{
"name": "background",
"data": [15],
"modifiers": [
{
"name": "uncorr_bkguncrt",
"type": "shapesys",
"data": [5]
}
]
}
]
}
]
}