From 2b1e4ce4fd835e81c69c47afe9ab786e10c186bd Mon Sep 17 00:00:00 2001 From: alex-l-kong <31424707+alex-l-kong@users.noreply.github.com> Date: Mon, 2 Oct 2023 15:05:15 -0700 Subject: [PATCH] Parallelize histogram and MPH calculations (#63) * Parallelize histogram and MPH calculations * Remove extraneous variable definitions * Formatting * Check if multiple channels extract properly with new histogram function --- src/mibi_bin_tools/_extract_bin.pyx | 116 ++++++++++++++++++---------- src/mibi_bin_tools/bin_files.py | 63 +++++++++++---- tests/bin_files_test.py | 45 ++++++++++- 3 files changed, 162 insertions(+), 62 deletions(-) diff --git a/src/mibi_bin_tools/_extract_bin.pyx b/src/mibi_bin_tools/_extract_bin.pyx index e084189..e9db5d3 100644 --- a/src/mibi_bin_tools/_extract_bin.pyx +++ b/src/mibi_bin_tools/_extract_bin.pyx @@ -29,7 +29,10 @@ cdef inline int _minimum_larger_value_in_sorted(const DTYPE_t[:] low_range, DTYP while start <= end: mid = (start + end) // 2 - if low_range[mid] < val: + if low_range[mid] == val: + return mid + 1 + + elif low_range[mid] < val: start = mid + 1 else: @@ -168,25 +171,13 @@ cdef INT_t[:, :, :, :] _extract_bin(const char* filename, @boundscheck(False) # Deactivate bounds checking @wraparound(False) # Deactivate negative indexing @cdivision(True) # Ignore modulo/divide by zero warning -cdef void _extract_histograms(const char* filename, DTYPE_t low_range, - DTYPE_t high_range, MAXINDEX_t* widths, MAXINDEX_t* intensities, - MAXINDEX_t* pulse_counts): - """ Creates histogram of observed peak widths within specified integration range - - Args: - filename (const char*): - Name of bin file to extract. - low_range (uint16_t): - Low time range for integration - high_range (uint16_t): - High time range for integration - """ +cdef _extract_histograms(const char* filename, DTYPE_t[:] low_range, + DTYPE_t[:] high_range): cdef DTYPE_t num_x, num_y, num_trig, num_frames, desc_len, trig, num_pulses, pulse, time cdef DTYPE_t intensity cdef SMALL_t width - cdef MAXINDEX_t data_start, pix + cdef MAXINDEX_t data_start, pix cdef int idx - cdef DTYPE_t p_cnt = 0 # 10MB buffer cdef MAXINDEX_t BUFFER_SIZE = 10 * 1024 * 1024 @@ -210,42 +201,94 @@ cdef void _extract_histograms(const char* filename, DTYPE_t low_range, data_start = \ (num_x) * (num_y) * (num_frames) * 8 + desc_len + 0x12 + widths = \ + cvarray( + shape=(256, low_range.shape[0]), + itemsize=sizeof(MAXINDEX_t), + format='Q' + ) + intensities = \ + cvarray( + shape=(USHRT_MAX, low_range.shape[0]), + itemsize=sizeof(MAXINDEX_t), + format='Q' + ) + pulse_counts = \ + cvarray( + shape=(256, low_range.shape[0]), + itemsize=sizeof(MAXINDEX_t), + format='Q' + ) + cdef MAXINDEX_t[:, :] widths_view = widths + cdef MAXINDEX_t[:, :] intensities_view = intensities + cdef MAXINDEX_t[:, :] pulse_counts_view = pulse_counts + + widths_view[:, :] = 0 + intensities_view[:, :] = 0 + pulse_counts_view[:, :] = 0 + + p_cnt = \ + cvarray( + shape=((num_x) * (num_y), low_range.shape[0]), + itemsize=sizeof(MAXINDEX_t), + format='Q' + ) + cdef MAXINDEX_t[:, :] p_cnt_view = p_cnt + p_cnt_view[:, :] = 0 + fseek(fp, data_start, SEEK_SET) fread(file_buffer, sizeof(char), BUFFER_SIZE, fp) for pix in range((num_x) * (num_y)): - #if pix % num_x == 0: + # if pix % num_x == 0: # print('\rpix done: ' + str(100 * pix / num_x / num_y) + '%...', end='') for trig in range(num_trig): _check_buffer_refill(fp, file_buffer, &buffer_idx, 0x8 * sizeof(char), BUFFER_SIZE) memcpy(&num_pulses, file_buffer + buffer_idx + 0x6, sizeof(time)) buffer_idx += 0x8 - p_cnt = 0 + p_cnt_view[pix, :] = 0 for pulse in range(num_pulses): _check_buffer_refill(fp, file_buffer, &buffer_idx, 0x5 * sizeof(char), BUFFER_SIZE) memcpy(&time, file_buffer + buffer_idx, sizeof(time)) memcpy(&width, file_buffer + buffer_idx + 0x2, sizeof(width)) memcpy(&intensity, file_buffer + buffer_idx + 0x3, sizeof(intensity)) buffer_idx += 0x5 - if time <= high_range and time >= low_range: - widths[width] += 1 - intensities[intensity] += 1 - p_cnt += 1 - if p_cnt != 0: - pulse_counts[p_cnt] += 1 + idx = _minimum_larger_value_in_sorted(low_range, time) + + if idx > 0: + if time <= high_range[idx - 1]: + widths_view[width, idx - 1] += 1 + intensities_view[intensity, idx - 1] += 1 + p_cnt_view[pix, idx - 1] += 1 + elif idx == -1: + if time <= high_range[low_range.shape[0] - 1]: + widths_view[width, low_range.shape[0] - 1] += 1 + intensities_view[intensity, low_range.shape[0] - 1] += 1 + p_cnt_view[pix, low_range.shape[0] - 1] += 1 + + for chan in range(p_cnt_view.shape[1]): + if p_cnt_view[pix, chan] > 0: + pulse_counts_view[p_cnt_view[pix, chan], chan] += 1 fclose(fp) free(file_buffer) + return ( + np.asarray(widths, dtype=np.uint64).reshape((256, low_range.shape[0])), + np.asarray(intensities, dtype=np.uint64).reshape((USHRT_MAX, low_range.shape[0])), + np.asarray(pulse_counts, dtype=np.uint64).reshape((256, low_range.shape[0])) + ) + + cdef int _comp(const void *a, const void *b) noexcept nogil: cdef int *x = a cdef int *y = b return x[0] - y[0] + cdef void _extract_pulse_height_and_positive_pixel(const char* filename, DTYPE_t low_range, DTYPE_t high_range, MAXINDEX_t* median_pulse_height, double* mean_pp): - cdef DTYPE_t num_x, num_y, num_trig, num_frames, desc_len, trig, num_pulses, pulse, time cdef DTYPE_t intensity cdef SMALL_t pulse_count, width @@ -369,6 +412,7 @@ cdef MAXINDEX_t _extract_total_counts(const char* filename): return counts + def c_extract_bin(char* filename, DTYPE_t[:] low_range, DTYPE_t[:] high_range, SMALL_t[:] calc_intensity): return np.asarray( @@ -376,25 +420,13 @@ def c_extract_bin(char* filename, DTYPE_t[:] low_range, ) -def c_extract_histograms(char* filename, DTYPE_t low_range, - DTYPE_t high_range): - - cdef MAXINDEX_t widths[256] - cdef MAXINDEX_t intensity[USHRT_MAX] - cdef MAXINDEX_t pulse_counts[256] - - memset(widths, 0, 256 * sizeof(MAXINDEX_t)) - memset(intensity, 0, USHRT_MAX * sizeof(MAXINDEX_t)) - memset(pulse_counts, 0, 256 * sizeof(MAXINDEX_t)) - - _extract_histograms(filename, low_range, high_range, widths, intensity, pulse_counts) - - return ( - np.asarray(widths), - np.asarray(intensity), - np.asarray(pulse_counts) +def c_extract_histograms(char* filename, DTYPE_t[:] low_range, + DTYPE_t[:] high_range): + return _extract_histograms( + filename, low_range, high_range ) + def c_pulse_height_vs_positive_pixel(char* filename, DTYPE_t low_range, DTYPE_t high_range): cdef MAXINDEX_t median_pulse_height = 0 cdef double mean_pp = 0.0 diff --git a/src/mibi_bin_tools/bin_files.py b/src/mibi_bin_tools/bin_files.py index 8640407..f46be9a 100644 --- a/src/mibi_bin_tools/bin_files.py +++ b/src/mibi_bin_tools/bin_files.py @@ -404,7 +404,8 @@ def extract_bin_files(data_dir: str, out_dir: Union[str, None], return image_data -def get_histograms_per_tof(data_dir: str, fov: str, channel: str, mass_range=(-0.3, 0.0), +def get_histograms_per_tof(data_dir: str, fov: str, channels: List[str] = None, + panel: Union[Tuple[float, float], pd.DataFrame] = (-0.3, 0.0), time_res: float = 500e-6): """Generates histograms of pulse widths, pulse counts, and pulse intensities found within the given mass range @@ -414,26 +415,46 @@ def get_histograms_per_tof(data_dir: str, fov: str, channel: str, mass_range=(-0 Directory containing bin files as well as accompanying json metadata files fov (str): Fov to generate histogram for - channel (str): - Channel to check widths for - mass_range (tuple): - Integration range + channels (str): + Channels to check widths for, default checks all channels + panel (tuple | pd.DataFrame): + If a tuple, global integration range over all antibodies within json metadata. + If a pd.DataFrame, specific peaks with custom integration ranges. Column names must be + 'Mass' and 'Target' with integration ranges specified via 'Start' and 'Stop' columns. time_res (float): Time resolution for scaling parabolic transformation + + Returns: + tuple (dict): + Tuple of dicts containing widths, intensities, and pulse info per channel """ fov = _find_bin_files(data_dir, [fov])[fov] - _fill_fov_metadata(data_dir, fov, mass_range, False, time_res, [channel]) + _fill_fov_metadata(data_dir, fov, panel, False, time_res, channels) local_bin_file = os.path.join(data_dir, fov['bin']) - widths, intensities, pulses = _extract_bin.c_extract_histograms(bytes(local_bin_file, 'utf-8'), - fov['lower_tof_range'][0], - fov['upper_tof_range'][0]) + widths, intensities, pulses = _extract_bin.c_extract_histograms( + bytes(local_bin_file, 'utf-8'), + fov['lower_tof_range'], + fov['upper_tof_range'] + ) + + chan_list = fov["targets"].values + widths = { + chan: widths_col for chan, widths_col in zip(chan_list, widths.T) + } + intensities = { + chan: intensities_col for chan, intensities_col in zip(chan_list, intensities.T) + } + pulses = { + chan: pulses_col for chan, pulses_col in zip(chan_list, pulses.T) + } + return widths, intensities, pulses -def get_median_pulse_height(data_dir: str, fov: str, channel: str, +def get_median_pulse_height(data_dir: str, fov: str, channels: List[str] = None, panel: Union[Tuple[float, float], pd.DataFrame] = (-0.3, 0.0), time_res: float = 500e-6): """Retrieves median pulse intensity and mean pulse count for a given channel @@ -450,20 +471,30 @@ def get_median_pulse_height(data_dir: str, fov: str, channel: str, time_res (float): Time resolution for scaling parabolic transformation + Returns: + dict: + dictionary of median height values per channel """ fov = _find_bin_files(data_dir, [fov])[fov] - _fill_fov_metadata(data_dir, fov, panel, False, time_res, [channel]) + _fill_fov_metadata(data_dir, fov, panel, False, time_res, channels) local_bin_file = os.path.join(data_dir, fov['bin']) _, intensities, _ = \ - _extract_bin.c_extract_histograms(bytes(local_bin_file, 'utf-8'), - fov['lower_tof_range'][0], - fov['upper_tof_range'][0]) + _extract_bin.c_extract_histograms( + bytes(local_bin_file, 'utf-8'), + fov['lower_tof_range'], + fov['upper_tof_range'] + ) - int_bin = np.cumsum(intensities) / intensities.sum() - median_height = (np.abs(int_bin - 0.5)).argmin() + int_bin = np.cumsum(intensities, axis=0) / intensities.sum(axis=0) + median_height = (np.abs(int_bin - 0.5)).argmin(axis=0) + + chan_list = fov["targets"].values + median_height = { + chan: mh for chan, mh in zip(chan_list, median_height) + } return median_height diff --git a/tests/bin_files_test.py b/tests/bin_files_test.py index 8f2af5c..6de9640 100644 --- a/tests/bin_files_test.py +++ b/tests/bin_files_test.py @@ -61,6 +61,43 @@ def case_bad_specified_panel(self): }]) return bad_panel + @case(tags=['multiple']) + def case_multiple_chan_panel(self): + panel = pd.DataFrame([ + { + 'Mass': 89, + 'Target': 'SMA', + 'Start': 88.7, + 'Stop': 89.0 + }, + { + 'Mass': 152, + 'Target': 'CD38', + 'Start': 151.7, + 'Stop': 152 + } + ]) + return panel + + @case(tags=['multiple']) + @pytest.mark.xfail(raises=KeyError, strict=True) + def case_bad_specified_panel(self): + bad_panel = pd.DataFrame([ + { + 'isotope': 89, + 'antibody': 'SMA', + 'start': 88.7, + 'stop': 89.0 + }, + { + 'isotope': 152, + 'antibody': 'CD38', + 'start': 151.7, + 'stop': 152 + } + ]) + return bad_panel + class FovMetadataTestChannels: @@ -274,24 +311,24 @@ def test_extract_bin_files(test_dir, fov, panel, intensities, replace, filepath_ @parametrize_with_cases('test_dir, fov', cases=FovMetadataTestFiles) -@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='specified') +@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='multiple') def test_get_width_histogram(test_dir, fov, panel): bin_files.get_histograms_per_tof( test_dir, fov['json'].split('.')[0], - 'SMA', + ['SMA', 'CD38'], panel, time_res=500e-6 ) @parametrize_with_cases('test_dir, fov', cases=FovMetadataTestFiles) -@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='specified') +@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='multiple') def test_median_height_vs_mean_pp(test_dir, fov, panel): bin_files.get_median_pulse_height( test_dir, fov['json'].split('.')[0], - 'SMA', + ['SMA', 'CD38'], panel, 500e-6 )