diff --git a/docs/source/advanced/plugin_dev.rst b/docs/source/advanced/plugin_dev.rst
index 3fec7efab..665a6c0c4 100644
--- a/docs/source/advanced/plugin_dev.rst
+++ b/docs/source/advanced/plugin_dev.rst
@@ -35,6 +35,7 @@ There are several plugin types:
 * ``CutPlugin``: Plugin type where using ``def cut_by(self, )`` inside the plugin a user can return a boolean array that can be used to select data.
 * ``MergeOnlyPlugin``: This is for internal use and only merges two plugins into a new one. See as an example in straxen the ``EventInfo`` plugin where the following datatypes are merged ``'events', 'event_basics', 'event_positions', 'corrected_areas', 'energy_estimates'``.
 * ``ParallelSourcePlugin``: For internal use only to parallelize the processing of low level plugins. This can be activated using stating ``parallel = 'process'`` in a plugin.
+* ``TimeDelayPlugin``: For plugins that add variable time delays to output data, causing output timestamps to potentially exceed input chunk boundaries. Useful for simulation plugins (e.g., adding electron drift time). The user must define ``compute_with_delay(self, )`` returning the delayed output arrays.

 Minimal examples
@@ -178,6 +179,32 @@ ________

     st.get_array(run_id, 'merged_data')

+strax.TimeDelayPlugin
+_________________________
+.. code-block:: python
+
+    import numpy as np
+
+    class VariableDelayPlugin(strax.TimeDelayPlugin):
+        """
+        Plugin that adds random delays, simulating e.g. drift time.
+        Output timestamps may exceed input chunk boundaries.
+        """
+        depends_on = 'records'
+        provides = 'delayed_records'
+        data_kind = 'delayed_records'
+        dtype = strax.record_dtype()
+        max_delay = 100  # maximum delay in ns, used in compute_with_delay
+
+        def compute_with_delay(self, records):
+            result = records.copy()
+            # Simulate a variable drift time
+            delays = np.random.randint(0, self.max_delay, size=len(result))
+            result['time'] = result['time'] + delays
+            return result
+
+    st.register(VariableDelayPlugin)
+    st.get_array(run_id, 'delayed_records')
+
+
 Plugin inheritance
 ----------------------
 It is possible to inherit the ``compute()`` method of an already existing plugin with another plugin. We call these types of plugins child plugins. Child plugins are recognized by strax when the ``child_plugin`` attribute of the plugin is set to ``True``. Below you can find a simple example of a child plugin with its parent plugin:
diff --git a/strax/plugins/__init__.py b/strax/plugins/__init__.py
index ff9c0ce34..52e289d81 100644
--- a/strax/plugins/__init__.py
+++ b/strax/plugins/__init__.py
@@ -6,3 +6,4 @@
 from .parrallel_source_plugin import *
 from .down_chunking_plugin import *
 from .exhaust_plugin import *
+from .time_delay_plugin import *
diff --git a/strax/plugins/time_delay_plugin.py b/strax/plugins/time_delay_plugin.py
new file mode 100644
index 000000000..e980048d9
--- /dev/null
+++ b/strax/plugins/time_delay_plugin.py
@@ -0,0 +1,242 @@
+"""Plugin base class for algorithms that add time delays to output."""
+
+import numpy as np
+import strax
+from .plugin import Plugin
+
+export, __all__ = strax.exporter()
+
+
+@export
+class TimeDelayPlugin(Plugin):
+    """Plugin base class for algorithms that add time delays to output.
+
+    Use this when your algorithm shifts output timestamps forward in time,
+    potentially beyond input chunk boundaries. Handles variable delays with
+    a known maximum, re-sorting, buffering across chunk boundaries, and
+    multi-output plugins.
+
+    Subclasses must implement:
+        compute_with_delay(**kwargs): Return delayed output data (arrays, not Chunks)
+
+    For multi-output plugins, compute_with_delay should return a dict
+    mapping data_type names to numpy arrays.
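+
+    Example (a minimal sketch; the ``records`` dependency and the fixed
+    50 ns shift below are illustrative assumptions, not part of this API)::
+
+        class FixedShiftPlugin(strax.TimeDelayPlugin):
+            depends_on = 'records'
+            provides = 'shifted_records'
+            data_kind = 'shifted_records'
+            dtype = strax.record_dtype()
+
+            def compute_with_delay(self, records):
+                result = records.copy()
+                # Shift every record forward by a constant 50 ns
+                result['time'] = result['time'] + 50
+                return result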
+
+    """
+
+    parallel = False
+
+    def __init__(self):
+        super().__init__()
+        self.output_buffer = {}
+        self.last_output_end = 0
+        self.first_output = True
+        self._cached_superrun = None
+        self._cached_subruns = None
+        self._min_buffered_time = float("inf")
+
+    def compute_with_delay(self, **kwargs):
+        """Compute output data with time delays already applied.
+
+        Input arrays are numpy arrays (not Chunks). Output arrays do NOT need
+        to be sorted. For multi-output, return a dict mapping data_type to
+        arrays.
+
+        """
+        raise NotImplementedError("Subclasses must implement compute_with_delay()")
+
+    def iter(self, iters, executor=None):
+        """Override iter to flush buffer at end of processing."""
+        yield from super().iter(iters, executor=executor)
+        final_result = self._flush_buffers()
+        if final_result is not None:
+            yield final_result
+
+    def _flush_buffers(self):
+        """Flush all remaining data from buffers."""
+        has_data = any(len(self.output_buffer.get(dt, [])) > 0 for dt in self.provides)
+        if not has_data:
+            return None
+
+        # Sort buffers and compute chunk_end
+        chunk_end = self.last_output_end
+        for data_type in self.provides:
+            buf = self.output_buffer.get(data_type)
+            if buf is not None and len(buf) > 0:
+                buf.sort(order="time")
+                data_end = int(strax.endtime(buf).max())
+                chunk_end = max(chunk_end, data_end)
+
+        # Build result dict
+        result = {}
+        for data_type in self.provides:
+            buf = self.output_buffer.get(data_type, np.empty(0, self.dtype_for(data_type)))
+            result[data_type] = self._make_chunk(
+                data=buf,
+                data_type=data_type,
+                start=self.last_output_end,
+                end=chunk_end,
+            )
+
+        result = self.superrun_transformation(result, self._cached_superrun, self._cached_subruns)
+        self.output_buffer = {}
+
+        return self._unwrap_result(result)
+
+    def do_compute(self, chunk_i=None, **kwargs):
+        """Process input, buffer output, return safe portion."""
+        if not kwargs:
+            raise RuntimeError("TimeDelayPlugin must have dependencies")
+        first_chunk = next(iter(kwargs.values()))
+        input_end = first_chunk.end
+
+        self._cached_superrun = self._check_subruns_uniqueness(
+            kwargs, {k: v.superrun for k, v in kwargs.items()}
+        )
+        self._cached_subruns = self._check_subruns_uniqueness(
+            kwargs, {k: v.subruns for k, v in kwargs.items()}
+        )
+
+        input_data = {k: v.data for k, v in kwargs.items()}
+        new_output = self.compute_with_delay(**input_data)
+
+        self._add_to_buffers(new_output)
+
+        return self._process_output(safe_boundary=input_end)
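+
+    # Worked example of the safe boundary (illustrative numbers): if the
+    # current input chunk ends at t=100, later input chunks start at or after
+    # t=100, and since compute_with_delay only shifts timestamps forward, they
+    # cannot produce output earlier than t=100. Buffered intervals with
+    # endtime <= 100 are therefore safe to emit now; an interval ending at
+    # t=105 stays buffered until a later chunk (or the final flush).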
+
+    def _unwrap_result(self, result):
+        """Unwrap result dict to single Chunk for single-output plugins."""
+        if self.multi_output:
+            return result
+        return result[self.provides[0]]
+
+    def _add_to_buffers(self, new_output):
+        """Add new output to buffers."""
+        # Normalize output to dict format
+        if self.multi_output:
+            if not isinstance(new_output, dict):
+                raise ValueError(
+                    f"{self.__class__.__name__} is multi-output, "
+                    "compute_with_delay must return a dict"
+                )
+            output_dict = new_output
+        else:
+            if isinstance(new_output, dict):
+                raise ValueError(
+                    f"{self.__class__.__name__} is single-output, "
+                    "compute_with_delay should not return a dict"
+                )
+            output_dict = {self.provides[0]: new_output}
+
+        for data_type in self.provides:
+            arr = output_dict.get(data_type, np.empty(0, self.dtype_for(data_type)))
+            if not isinstance(arr, np.ndarray):
+                arr = strax.dict_to_rec(arr, dtype=self.dtype_for(data_type))
+
+            if data_type not in self.output_buffer:
+                self.output_buffer[data_type] = arr
+            elif len(arr) > 0:
+                self.output_buffer[data_type] = np.concatenate([self.output_buffer[data_type], arr])
+
+    def _process_output(self, safe_boundary):
+        """Process buffers and return safe portion."""
+        # Sort all buffers
+        for data_type in self.provides:
+            buf = self.output_buffer.get(data_type)
+            if buf is not None and len(buf) > 0:
+                buf.sort(order="time")
+
+        # Split buffers into safe and remaining portions
+        safe_data_dict = {}
+        for data_type in self.provides:
+            buf = self.output_buffer.get(data_type, np.empty(0, self.dtype_for(data_type)))
+            safe_data, remaining = self._split_buffer(buf, safe_boundary)
+            self.output_buffer[data_type] = remaining
+            safe_data_dict[data_type] = safe_data
+
+        # Update minimum buffered time
+        min_time = float("inf")
+        for buf in self.output_buffer.values():
+            if buf is not None and len(buf) > 0:
+                min_time = min(min_time, buf["time"].min())
+        self._min_buffered_time = min_time
+
+        # Compute unified chunk boundaries across all data types
+        chunk_start = None
+        chunk_end = None
+        for data_type in self.provides:
+            safe_data = safe_data_dict[data_type]
+            dt_start, dt_end = self._get_chunk_boundaries(safe_data, safe_boundary)
+            if chunk_start is None:
+                chunk_start = dt_start
+                chunk_end = dt_end
+            else:
+                chunk_start = min(chunk_start, dt_start)
+                chunk_end = max(chunk_end, dt_end)
+
+        # Build result dict
+        result = {}
+        for data_type in self.provides:
+            result[data_type] = self._make_chunk(
+                data=safe_data_dict[data_type],
+                data_type=data_type,
+                start=chunk_start,
+                end=chunk_end,
+            )
+
+        self.last_output_end = chunk_end
+        self.first_output = False
+
+        result = self.superrun_transformation(result, self._cached_superrun, self._cached_subruns)
+
+        return self._unwrap_result(result)
+
+    def _split_buffer(self, buf, safe_boundary):
+        """Split buffer into safe portion (endtime <= boundary) and remainder."""
+        if len(buf) == 0:
+            empty = np.empty(0, buf.dtype)
+            return empty, empty
+
+        endtimes = strax.endtime(buf)
+        safe_mask = endtimes <= safe_boundary
+
+        safe_data = buf[safe_mask].copy()
+        remaining = buf[~safe_mask].copy()
+
+        return safe_data, remaining
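+
+    # Boundary bookkeeping example (illustrative numbers): with
+    # safe_boundary=100, safe data ending at t=95, and a buffered interval
+    # starting at t=98, chunk_end is capped at 98 so that the buffered
+    # interval still fits entirely inside a later output chunk.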
+
+    def _get_chunk_boundaries(self, safe_data, safe_boundary):
+        """Determine chunk start/end ensuring buffered data fits in next chunk."""
+        if self.first_output:
+            if len(safe_data) > 0:
+                chunk_start = int(safe_data[0]["time"])
+            else:
+                chunk_start = 0
+        else:
+            chunk_start = self.last_output_end
+
+        if len(safe_data) > 0:
+            data_end = int(strax.endtime(safe_data).max())
+            chunk_end = max(data_end, safe_boundary)
+        else:
+            chunk_end = safe_boundary
+
+        # Don't advance chunk_end past minimum buffered time
+        if self._min_buffered_time < float("inf"):
+            chunk_end = min(chunk_end, int(self._min_buffered_time))
+
+        chunk_end = max(chunk_start, chunk_end)
+
+        return chunk_start, chunk_end
+
+    def _make_chunk(self, data, data_type, start, end):
+        """Create a strax Chunk with proper metadata."""
+        return strax.Chunk(
+            start=start,
+            end=end,
+            data=data,
+            data_type=data_type,
+            data_kind=self.data_kind_for(data_type),
+            dtype=self.dtype_for(data_type),
+            run_id=self._run_id,
+            target_size_mb=self.chunk_target_size_mb,
+        )
diff --git a/tests/test_time_delay_plugin.py b/tests/test_time_delay_plugin.py
new file mode 100644
index 000000000..ffefd6d8d
--- /dev/null
+++ b/tests/test_time_delay_plugin.py
@@ -0,0 +1,263 @@
+"""Tests for TimeDelayPlugin."""
+
+import numpy as np
+import strax
+import pytest
+
+
+def simple_interval_dtype():
+    return [
+        ("time", np.int64),
+        ("length", np.int32),
+        ("dt", np.int16),
+        ("value", np.int32),
+    ]
+
+
+class ChunkedSource(strax.Plugin):
+    """Source plugin that yields pre-defined chunks."""
+
+    depends_on = tuple()
+    provides = "source_data"
+    dtype = simple_interval_dtype()
+    rechunk_on_save = False
+    chunks_data: list = []
+
+    def is_ready(self, chunk_i):
+        return chunk_i < len(self.chunks_data)
+
+    def source_finished(self):
+        return True
+
+    def compute(self, chunk_i):
+        start, end, data = self.chunks_data[chunk_i]
+        return self.chunk(start=start, end=end, data=data)
+
+
+class ConstantDelayPlugin(strax.TimeDelayPlugin):
+    """Test plugin that adds a constant delay to all records."""
+
+    depends_on = ("source_data",)
+    provides = "delayed_data"
+    dtype = simple_interval_dtype()
+    data_kind = "delayed_data"
+    delay = 0
+
+    def compute_with_delay(self, source_data):
+        result = source_data.copy()
+        result["time"] = result["time"] + self.delay
+        return result
+
+
+class VariableDelayPlugin(strax.TimeDelayPlugin):
+    """Test plugin that adds variable delays based on a pattern."""
+
+    depends_on = ("source_data",)
+    provides = "variable_delayed_data"
+    dtype = simple_interval_dtype()
+    data_kind = "variable_delayed_data"
+    delay_pattern = [0]
+
+    def compute_with_delay(self, source_data):
+        result = source_data.copy()
+        delays = np.array(
+            [self.delay_pattern[i % len(self.delay_pattern)] for i in range(len(result))]
+        )
+        result["time"] = result["time"] + delays
+        return result
+
+
+class MultiOutputDelayPlugin(strax.TimeDelayPlugin):
+    """Test plugin with multiple outputs."""
+
+    depends_on = ("source_data",)
+    provides = ("delayed_output_a", "delayed_output_b")
+    data_kind = {
+        "delayed_output_a": "delayed_output_a",
+        "delayed_output_b": "delayed_output_b",
+    }
+    delay_a = 0
+    delay_b = 0
+
+    def infer_dtype(self):
+        return {
+            "delayed_output_a": simple_interval_dtype(),
+            "delayed_output_b": simple_interval_dtype(),
+        }
+
+    def compute_with_delay(self, source_data):
+        result_a = source_data.copy()
+        result_a["time"] = result_a["time"] + self.delay_a
+        result_b = source_data.copy()
+        result_b["time"] = result_b["time"] + self.delay_b
+        return {"delayed_output_a": result_a, "delayed_output_b": result_b}
+
+
+def make_test_data(times, length=1, dt=1, values=None):
+    """Create test data array with given times."""
+    n = len(times)
+    data = np.zeros(n, dtype=simple_interval_dtype())
+    data["time"] = times
+    data["length"] = length
+    data["dt"] = dt
+    data["value"] = values if values is not None else np.arange(n)
+    return data
+
+
+def create_context_with_source(chunks_data):
+    """Create a strax context with ChunkedSource configured."""
+
+    class TestSource(ChunkedSource):
+        pass
+
+    TestSource.chunks_data = chunks_data
+    st = strax.Context(storage=[])
+    st.register(TestSource)
+    return st
+
+
+def test_constant_delay_across_chunks():
+    """Test constant delay with buffering across chunk boundaries."""
+    delay = 30
+
+    data1 = make_test_data(np.array([10, 40]), values=np.array([0, 1]))
+    data2 = make_test_data(np.array([60, 90]), values=np.array([2, 3]))
+
+    chunks_data = [
+        (0, 50, data1),
+        (50, 100, data2),
+    ]
+    st = create_context_with_source(chunks_data)
+
+    class TestDelayPlugin(ConstantDelayPlugin):
+        delay = 30
+
+    st.register(TestDelayPlugin)
+    result = st.get_array(run_id="test", targets="delayed_data")
+
+    expected_times = np.array([10, 40, 60, 90]) + delay
+    np.testing.assert_array_equal(sorted(result["time"]), sorted(expected_times))
+    assert len(result) == 4
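+
+
+# Worked timeline for the test above (illustrative): with delay=30 and input
+# chunks [0, 50) and [50, 100), the record at t=40 is delayed to t=70, past
+# the end of the first input chunk, so TimeDelayPlugin buffers it and emits
+# it only once a later chunk (or the final flush) makes it safe.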
+
+
+def test_variable_delay_reorders_and_buffers():
+    """Test variable delays with reordering and buffering."""
+    data1 = make_test_data(np.array([0, 10, 20]), values=np.array([0, 1, 2]))
+    data2 = make_test_data(np.array([50, 60, 70]), values=np.array([3, 4, 5]))
+
+    chunks_data = [
+        (0, 50, data1),
+        (50, 100, data2),
+    ]
+    st = create_context_with_source(chunks_data)
+
+    class TestVariableDelay(VariableDelayPlugin):
+        delay_pattern = [0, 80, 20]
+
+    st.register(TestVariableDelay)
+    result = st.get_array(run_id="test", targets="variable_delayed_data")
+
+    assert len(result) == 6
+    assert np.all(np.diff(result["time"]) >= 0), "Output must be sorted"
+
+
+def test_empty_input_chunk():
+    """Test handling of empty input chunks."""
+    data1 = make_test_data(np.array([10, 20]), values=np.array([0, 1]))
+    empty_data = make_test_data(np.array([], dtype=np.int64))
+    data3 = make_test_data(np.array([110, 120]), values=np.array([2, 3]))
+
+    chunks_data = [
+        (0, 50, data1),
+        (50, 100, empty_data),
+        (100, 150, data3),
+    ]
+    st = create_context_with_source(chunks_data)
+
+    class TestDelayPlugin(ConstantDelayPlugin):
+        delay = 20
+
+    st.register(TestDelayPlugin)
+    result = st.get_array(run_id="test", targets="delayed_data")
+
+    assert len(result) == 4
+
+
+def test_multi_output_different_delays():
+    """Test multi-output plugin with different delays per output."""
+    data = make_test_data(np.array([10, 50]), values=np.array([0, 1]))
+    chunks_data = [(0, 100, data)]
+    st = create_context_with_source(chunks_data)
+
+    class TestMultiOutput(MultiOutputDelayPlugin):
+        delay_a = 20
+        delay_b = 60
+
+    st.register(TestMultiOutput)
+
+    result_a = st.get_array(run_id="test", targets="delayed_output_a")
+    result_b = st.get_array(run_id="test", targets="delayed_output_b")
+
+    np.testing.assert_array_equal(result_a["time"], [30, 70])
+    np.testing.assert_array_equal(result_b["time"], [70, 110])
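+
+
+# In the multi-output test above, _process_output gives both outputs unified
+# chunk boundaries (the minimum start and maximum end across data types), so
+# delayed_output_a and delayed_output_b stay aligned even though their delays
+# differ.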
+
+
+def test_chunk_continuity():
+    """Test that output chunks maintain proper continuity."""
+    data1 = make_test_data(np.array([10, 20]), values=np.array([0, 1]))
+    data2 = make_test_data(np.array([60, 70]), values=np.array([2, 3]))
+    data3 = make_test_data(np.array([110, 120]), values=np.array([4, 5]))
+
+    chunks_data = [
+        (0, 50, data1),
+        (50, 100, data2),
+        (100, 150, data3),
+    ]
+    st = create_context_with_source(chunks_data)
+
+    class TestDelayPlugin(ConstantDelayPlugin):
+        delay = 30
+
+    st.register(TestDelayPlugin)
+    chunks = list(st.get_iter(run_id="test", targets="delayed_data"))
+
+    for i in range(1, len(chunks)):
+        assert (
+            chunks[i].start == chunks[i - 1].end
+        ), f"Chunk {i} start ({chunks[i].start}) != chunk {i - 1} end ({chunks[i - 1].end})"
+
+
+def test_straddling_data_across_boundary():
+    """Test data that straddles a chunk boundary (time < boundary < endtime)."""
+    # After delay: time=95, endtime=105 (straddles the boundary at 100)
+    data1 = np.zeros(1, dtype=simple_interval_dtype())
+    data1["time"] = 85
+    data1["length"] = 10
+    data1["dt"] = 1
+    data1["value"] = 1
+
+    data2 = np.zeros(1, dtype=simple_interval_dtype())
+    data2["time"] = 120
+    data2["length"] = 10
+    data2["dt"] = 1
+    data2["value"] = 2
+
+    chunks_data = [
+        (0, 100, data1),
+        (100, 200, data2),
+    ]
+    st = create_context_with_source(chunks_data)
+
+    class TestDelayPlugin(ConstantDelayPlugin):
+        delay = 10
+
+    st.register(TestDelayPlugin)
+    result = st.get_array(run_id="test", targets="delayed_data")
+
+    assert len(result) == 2
+    np.testing.assert_array_equal(result["time"], [95, 130])
+    np.testing.assert_array_equal(strax.endtime(result), [105, 140])
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])