diff --git a/strax/chunk.py b/strax/chunk.py index 273c1ab57..d119fb8c6 100644 --- a/strax/chunk.py +++ b/strax/chunk.py @@ -39,7 +39,8 @@ def __init__(self, end, data, subruns=None, - target_size_mb=default_chunk_size_mb): + target_size_mb=default_chunk_size_mb, + strict_bounds=True): self.data_type = data_type self.data_kind = data_kind self.dtype = np.dtype(dtype) @@ -47,6 +48,7 @@ def __init__(self, self.start = start self.end = end self.subruns = subruns + self.strict_bounds = strict_bounds if data is None: data = np.empty(0, dtype) self.data = data @@ -73,7 +75,7 @@ def __init__(self, raise ValueError(f"Attempt to create chunk {self} " f"with negative length") - if len(self.data): + if len(self.data) and strict_bounds: data_starts_at = self.data[0]['time'] # Check the last 500 samples (arbitrary number) as sanity check data_ends_at = strax.endtime(self.data[-500:]).max() @@ -152,16 +154,25 @@ def _mbs(self): def split(self, t: ty.Union[int, None], - allow_early_split=False): + allow_early_split=False, + allow_overlap=False): """Return (chunk_left, chunk_right) split at time t. :param t: Time at which to split the data. - All data in the left chunk will have their (exclusive) end <= t, - all data in the right chunk will have (inclusive) start >=t. + :param allow_early_split: If False, raise CannotSplit if the requirements above cannot be met. If True, split at the closest possible time before t. + :param allow_overlap: + Whether to allow the split chunks to overlap. + if True, data will be included in a given interval (before/after t) + based on whether it overlaps the interval or not. + if False data will be included based on containment within a given interval: + All data in the left chunk will have their (exclusive) end <= t, + all data in the right chunk will have (inclusive) start >=t. + """ + t = max(min(t, self.end), self.start) if t == self.end: data1, data2 = self.data, self.data[:0] @@ -171,14 +182,16 @@ def split(self, data1, data2, t = split_array( data=self.data, t=t, - allow_early_split=allow_early_split) + allow_early_split=allow_early_split, + allow_overlap=allow_overlap) common_kwargs = dict( run_id=self.run_id, dtype=self.dtype, data_type=self.data_type, data_kind=self.data_kind, - target_size_mb=self.target_size_mb) + target_size_mb=self.target_size_mb, + strict_bounds=not allow_overlap) c1 = strax.Chunk( start=self.start, @@ -191,7 +204,7 @@ def split(self, data=data2, **common_kwargs) return c1, c2 - + @classmethod def merge(cls, chunks, data_type=''): """Create chunk by merging columns of chunks of same data kind @@ -273,6 +286,9 @@ def concatenate(cls, chunks): run_id = run_ids[0] subruns = _update_subruns_in_chunk(chunks) + for c in chunks: + if not c.strict_bounds: + c.data = c.data[c.data['time']>=c.start] prev_end = 0 for c in chunks: @@ -291,8 +307,8 @@ def concatenate(cls, chunks): run_id=run_id, subruns=subruns, data=np.concatenate([c.data for c in chunks]), - target_size_mb=max([c.target_size_mb for c in chunks])) - + target_size_mb=max([c.target_size_mb for c in chunks]), + strict_bounds=all([c.strict_bounds for c in chunks])) @export def continuity_check(chunk_iter): @@ -330,7 +346,7 @@ class CannotSplit(Exception): @export @numba.njit(cache=True, nogil=True) -def split_array(data, t, allow_early_split=False): +def split_array(data, t, allow_early_split=False, allow_overlap=False): """Return (data left of t, data right of t, t), or raise CannotSplit if that would split a data element in two. @@ -349,6 +365,12 @@ def split_array(data, t, allow_early_split=False): if data[0]['time'] >= t: return data[:0], data, t + # Overlaps allowed, split is trivial. + # All data starting before t go in first part + # all data ending after t goes in second part + if allow_overlap: + return data[data['time']<=t], data[strax.endtime(data)>=t], t + # Find: # i_first_beyond: the first element starting after t # splittable_i: nearest index left of t where we can safely split BEFORE diff --git a/tests/test_general_processing.py b/tests/test_general_processing.py index 338df38a1..f4319af55 100644 --- a/tests/test_general_processing.py +++ b/tests/test_general_processing.py @@ -171,17 +171,18 @@ def test_split(things, split_indices): @hypothesis.settings(deadline=None) @hypothesis.given(strax.testutils.several_fake_records, hypothesis.strategies.integers(0, 50), + hypothesis.strategies.booleans(), hypothesis.strategies.booleans()) -def test_split_array(data, t, allow_early_split): +def test_split_array(data, t, allow_early_split, allow_overlap): print(f"\nCalled with {np.transpose([data['time'], strax.endtime(data)]).tolist()}, " f"{t}, {allow_early_split}") try: data1, data2, tsplit = strax.split_array( - data, t, allow_early_split=allow_early_split) + data, t, allow_early_split=allow_early_split, allow_overlap=allow_overlap) except strax.CannotSplit: - assert not allow_early_split + assert not allow_early_split and not allow_overlap # There must be data straddling t for d in data: if d['time'] < t < strax.endtime(d): @@ -193,10 +194,14 @@ def test_split_array(data, t, allow_early_split): if allow_early_split: assert tsplit <= t t = tsplit - - assert len(data1) + len(data2) == len(data) - assert np.all(strax.endtime(data1) <= t) - assert np.all(data2['time'] >= t) + if allow_overlap: + assert tsplit == t + assert np.all(strax.endtime(data2) >= t) + assert np.all(data1['time'] <= t) + else: + assert len(data1) + len(data2) == len(data) + assert np.all(strax.endtime(data1) <= t) + assert np.all(data2['time'] >= t) @hypothesis.settings(deadline=None)