Skip to content

Commit

Permalink
Parallelize histogram and MPH calculations (#63)
Browse files Browse the repository at this point in the history
* Parallelize histogram and MPH calculations

* Remove extraneous variable definitions

* Formatting

* Check if multiple channels extract properly with new histogram function
  • Loading branch information
alex-l-kong authored Oct 2, 2023
1 parent 066f953 commit 2b1e4ce
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 62 deletions.
116 changes: 74 additions & 42 deletions src/mibi_bin_tools/_extract_bin.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ cdef inline int _minimum_larger_value_in_sorted(const DTYPE_t[:] low_range, DTYP
while start <= end:
mid = (start + end) // 2

if low_range[mid] < val:
if low_range[mid] == val:
return mid + 1

elif low_range[mid] < val:
start = mid + 1

else:
Expand Down Expand Up @@ -168,25 +171,13 @@ cdef INT_t[:, :, :, :] _extract_bin(const char* filename,
@boundscheck(False) # Deactivate bounds checking
@wraparound(False) # Deactivate negative indexing
@cdivision(True) # Ignore modulo/divide by zero warning
cdef void _extract_histograms(const char* filename, DTYPE_t low_range,
DTYPE_t high_range, MAXINDEX_t* widths, MAXINDEX_t* intensities,
MAXINDEX_t* pulse_counts):
""" Creates histogram of observed peak widths within specified integration range
Args:
filename (const char*):
Name of bin file to extract.
low_range (uint16_t):
Low time range for integration
high_range (uint16_t):
High time range for integration
"""
cdef _extract_histograms(const char* filename, DTYPE_t[:] low_range,
DTYPE_t[:] high_range):
cdef DTYPE_t num_x, num_y, num_trig, num_frames, desc_len, trig, num_pulses, pulse, time
cdef DTYPE_t intensity
cdef SMALL_t width
cdef MAXINDEX_t data_start, pix
cdef MAXINDEX_t data_start, pix
cdef int idx
cdef DTYPE_t p_cnt = 0

# 10MB buffer
cdef MAXINDEX_t BUFFER_SIZE = 10 * 1024 * 1024
Expand All @@ -210,42 +201,94 @@ cdef void _extract_histograms(const char* filename, DTYPE_t low_range,
data_start = \
<MAXINDEX_t>(num_x) * <MAXINDEX_t>(num_y) * <MAXINDEX_t>(num_frames) * 8 + desc_len + 0x12

widths = \
cvarray(
shape=(<MAXINDEX_t>256, <MAXINDEX_t>low_range.shape[0]),
itemsize=sizeof(MAXINDEX_t),
format='Q'
)
intensities = \
cvarray(
shape=(<MAXINDEX_t>USHRT_MAX, <MAXINDEX_t>low_range.shape[0]),
itemsize=sizeof(MAXINDEX_t),
format='Q'
)
pulse_counts = \
cvarray(
shape=(<MAXINDEX_t>256, <MAXINDEX_t>low_range.shape[0]),
itemsize=sizeof(MAXINDEX_t),
format='Q'
)
cdef MAXINDEX_t[:, :] widths_view = widths
cdef MAXINDEX_t[:, :] intensities_view = intensities
cdef MAXINDEX_t[:, :] pulse_counts_view = pulse_counts

widths_view[:, :] = 0
intensities_view[:, :] = 0
pulse_counts_view[:, :] = 0

p_cnt = \
cvarray(
shape=(<MAXINDEX_t>(num_x) * <MAXINDEX_t>(num_y), <MAXINDEX_t>low_range.shape[0]),
itemsize=sizeof(MAXINDEX_t),
format='Q'
)
cdef MAXINDEX_t[:, :] p_cnt_view = p_cnt
p_cnt_view[:, :] = 0

fseek(fp, data_start, SEEK_SET)
fread(file_buffer, sizeof(char), BUFFER_SIZE, fp)
for pix in range(<MAXINDEX_t>(num_x) * <MAXINDEX_t>(num_y)):
#if pix % num_x == 0:
# if pix % num_x == 0:
# print('\rpix done: ' + str(100 * pix / num_x / num_y) + '%...', end='')
for trig in range(num_trig):
_check_buffer_refill(fp, file_buffer, &buffer_idx, 0x8 * sizeof(char), BUFFER_SIZE)
memcpy(&num_pulses, file_buffer + buffer_idx + 0x6, sizeof(time))
buffer_idx += 0x8
p_cnt = 0
p_cnt_view[pix, :] = 0
for pulse in range(num_pulses):
_check_buffer_refill(fp, file_buffer, &buffer_idx, 0x5 * sizeof(char), BUFFER_SIZE)
memcpy(&time, file_buffer + buffer_idx, sizeof(time))
memcpy(&width, file_buffer + buffer_idx + 0x2, sizeof(width))
memcpy(&intensity, file_buffer + buffer_idx + 0x3, sizeof(intensity))
buffer_idx += 0x5
if time <= high_range and time >= low_range:
widths[width] += 1
intensities[intensity] += 1
p_cnt += 1
if p_cnt != 0:
pulse_counts[p_cnt] += 1
idx = _minimum_larger_value_in_sorted(low_range, time)

if idx > 0:
if time <= high_range[idx - 1]:
widths_view[width, idx - 1] += 1
intensities_view[intensity, idx - 1] += 1
p_cnt_view[pix, idx - 1] += 1
elif idx == -1:
if time <= high_range[low_range.shape[0] - 1]:
widths_view[width, low_range.shape[0] - 1] += 1
intensities_view[intensity, low_range.shape[0] - 1] += 1
p_cnt_view[pix, low_range.shape[0] - 1] += 1

for chan in range(p_cnt_view.shape[1]):
if p_cnt_view[pix, chan] > 0:
pulse_counts_view[p_cnt_view[pix, chan], chan] += 1

fclose(fp)
free(file_buffer)

return (
np.asarray(widths, dtype=np.uint64).reshape((256, low_range.shape[0])),
np.asarray(intensities, dtype=np.uint64).reshape((USHRT_MAX, low_range.shape[0])),
np.asarray(pulse_counts, dtype=np.uint64).reshape((256, low_range.shape[0]))
)


cdef int _comp(const void *a, const void *b) noexcept nogil:
    # Three-way comparator over two pointers to C ints.
    #
    # Args:
    #     a (const void*): pointer to the first int to compare
    #     b (const void*): pointer to the second int to compare
    # Returns:
    #     int: -1 if a[0] < b[0], 1 if a[0] > b[0], 0 if they are equal
    #
    # NOTE(review): the signature matches the comparator expected by C
    # `qsort`; the call site is not visible here -- confirm before relying on it.
    cdef int *x = <int *>a
    cdef int *y = <int *>b

    # Compare explicitly rather than returning `x[0] - y[0]`: the difference
    # of two ints can overflow (undefined behaviour in C) and flip the sign,
    # which would report the wrong ordering.
    if x[0] < y[0]:
        return -1
    if x[0] > y[0]:
        return 1
    return 0


cdef void _extract_pulse_height_and_positive_pixel(const char* filename, DTYPE_t low_range,
DTYPE_t high_range,
MAXINDEX_t* median_pulse_height,
double* mean_pp):

cdef DTYPE_t num_x, num_y, num_trig, num_frames, desc_len, trig, num_pulses, pulse, time
cdef DTYPE_t intensity
cdef SMALL_t pulse_count, width
Expand Down Expand Up @@ -369,32 +412,21 @@ cdef MAXINDEX_t _extract_total_counts(const char* filename):

return counts


def c_extract_bin(char* filename, DTYPE_t[:] low_range,
                  DTYPE_t[:] high_range, SMALL_t[:] calc_intensity):
    """Python-callable wrapper that runs ``_extract_bin`` and returns a numpy array

    Args:
        filename (char*):
            Forwarded unchanged to ``_extract_bin``
        low_range (DTYPE_t[:]):
            Forwarded unchanged to ``_extract_bin``; presumably the lower bound of each
            integration window (confirm against ``_extract_bin``)
        high_range (DTYPE_t[:]):
            Forwarded unchanged to ``_extract_bin``; presumably the upper bound of each
            integration window (confirm against ``_extract_bin``)
        calc_intensity (SMALL_t[:]):
            Forwarded unchanged to ``_extract_bin``
    Returns:
        np.ndarray:
            Result of ``_extract_bin`` passed through ``np.asarray``
    """
    extracted = _extract_bin(filename, low_range, high_range, calc_intensity)
    return np.asarray(extracted)


def c_extract_histograms(char* filename, DTYPE_t low_range,
DTYPE_t high_range):

cdef MAXINDEX_t widths[256]
cdef MAXINDEX_t intensity[USHRT_MAX]
cdef MAXINDEX_t pulse_counts[256]

memset(widths, 0, 256 * sizeof(MAXINDEX_t))
memset(intensity, 0, USHRT_MAX * sizeof(MAXINDEX_t))
memset(pulse_counts, 0, 256 * sizeof(MAXINDEX_t))

_extract_histograms(filename, low_range, high_range, widths, intensity, pulse_counts)

return (
np.asarray(widths),
np.asarray(intensity),
np.asarray(pulse_counts)
def c_extract_histograms(char* filename, DTYPE_t[:] low_range,
DTYPE_t[:] high_range):
return _extract_histograms(
filename, low_range, high_range
)


def c_pulse_height_vs_positive_pixel(char* filename, DTYPE_t low_range, DTYPE_t high_range):
cdef MAXINDEX_t median_pulse_height = 0
cdef double mean_pp = 0.0
Expand Down
63 changes: 47 additions & 16 deletions src/mibi_bin_tools/bin_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ def extract_bin_files(data_dir: str, out_dir: Union[str, None],
return image_data


def get_histograms_per_tof(data_dir: str, fov: str, channel: str, mass_range=(-0.3, 0.0),
def get_histograms_per_tof(data_dir: str, fov: str, channels: List[str] = None,
panel: Union[Tuple[float, float], pd.DataFrame] = (-0.3, 0.0),
time_res: float = 500e-6):
"""Generates histograms of pulse widths, pulse counts, and pulse intensities found within the
given mass range
Expand All @@ -414,26 +415,46 @@ def get_histograms_per_tof(data_dir: str, fov: str, channel: str, mass_range=(-0
Directory containing bin files as well as accompanying json metadata files
fov (str):
Fov to generate histogram for
channel (str):
Channel to check widths for
mass_range (tuple):
Integration range
channels (List[str]):
Channels to check widths for, default checks all channels
panel (tuple | pd.DataFrame):
If a tuple, global integration range over all antibodies within json metadata.
If a pd.DataFrame, specific peaks with custom integration ranges. Column names must be
'Mass' and 'Target' with integration ranges specified via 'Start' and 'Stop' columns.
time_res (float):
Time resolution for scaling parabolic transformation
Returns:
tuple (dict):
Tuple of dicts containing widths, intensities, and pulse info per channel
"""
fov = _find_bin_files(data_dir, [fov])[fov]

_fill_fov_metadata(data_dir, fov, mass_range, False, time_res, [channel])
_fill_fov_metadata(data_dir, fov, panel, False, time_res, channels)

local_bin_file = os.path.join(data_dir, fov['bin'])

widths, intensities, pulses = _extract_bin.c_extract_histograms(bytes(local_bin_file, 'utf-8'),
fov['lower_tof_range'][0],
fov['upper_tof_range'][0])
widths, intensities, pulses = _extract_bin.c_extract_histograms(
bytes(local_bin_file, 'utf-8'),
fov['lower_tof_range'],
fov['upper_tof_range']
)

chan_list = fov["targets"].values
widths = {
chan: widths_col for chan, widths_col in zip(chan_list, widths.T)
}
intensities = {
chan: intensities_col for chan, intensities_col in zip(chan_list, intensities.T)
}
pulses = {
chan: pulses_col for chan, pulses_col in zip(chan_list, pulses.T)
}

return widths, intensities, pulses


def get_median_pulse_height(data_dir: str, fov: str, channel: str,
def get_median_pulse_height(data_dir: str, fov: str, channels: List[str] = None,
panel: Union[Tuple[float, float], pd.DataFrame] = (-0.3, 0.0),
time_res: float = 500e-6):
"""Retrieves the median pulse intensity for each of the given channels
Expand All @@ -450,20 +471,30 @@ def get_median_pulse_height(data_dir: str, fov: str, channel: str,
time_res (float):
Time resolution for scaling parabolic transformation
Returns:
dict:
dictionary of median height values per channel
"""

fov = _find_bin_files(data_dir, [fov])[fov]
_fill_fov_metadata(data_dir, fov, panel, False, time_res, [channel])
_fill_fov_metadata(data_dir, fov, panel, False, time_res, channels)

local_bin_file = os.path.join(data_dir, fov['bin'])

_, intensities, _ = \
_extract_bin.c_extract_histograms(bytes(local_bin_file, 'utf-8'),
fov['lower_tof_range'][0],
fov['upper_tof_range'][0])
_extract_bin.c_extract_histograms(
bytes(local_bin_file, 'utf-8'),
fov['lower_tof_range'],
fov['upper_tof_range']
)

int_bin = np.cumsum(intensities) / intensities.sum()
median_height = (np.abs(int_bin - 0.5)).argmin()
int_bin = np.cumsum(intensities, axis=0) / intensities.sum(axis=0)
median_height = (np.abs(int_bin - 0.5)).argmin(axis=0)

chan_list = fov["targets"].values
median_height = {
chan: mh for chan, mh in zip(chan_list, median_height)
}

return median_height

Expand Down
45 changes: 41 additions & 4 deletions tests/bin_files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,43 @@ def case_bad_specified_panel(self):
}])
return bad_panel

@case(tags=['multiple'])
def case_multiple_chan_panel(self):
    """Panel with two targets (SMA and CD38), each with its own Start/Stop range"""
    # Column-wise construction; columns are listed in the same order as the
    # keys of the equivalent row-wise records.
    columns = {
        'Mass': [89, 152],
        'Target': ['SMA', 'CD38'],
        'Start': [88.7, 151.7],
        'Stop': [89.0, 152],
    }
    return pd.DataFrame(columns)

@case(tags=['multiple'])
@pytest.mark.xfail(raises=KeyError, strict=True)
def case_bad_multiple_chan_panel(self):
    """Two-target panel with unexpected column names; tests using it must raise KeyError

    Deliberately NOT named ``case_bad_specified_panel``: a case with that name is
    already defined earlier in the same scope, and reusing the name would replace
    that earlier case instead of adding a second one.
    """
    # Column names are intentionally different from the 'Mass' / 'Target' /
    # 'Start' / 'Stop' names used by the valid multi-channel panel.
    bad_panel = pd.DataFrame([
        {
            'isotope': 89,
            'antibody': 'SMA',
            'start': 88.7,
            'stop': 89.0
        },
        {
            'isotope': 152,
            'antibody': 'CD38',
            'start': 151.7,
            'stop': 152
        }
    ])
    return bad_panel


class FovMetadataTestChannels:

Expand Down Expand Up @@ -274,24 +311,24 @@ def test_extract_bin_files(test_dir, fov, panel, intensities, replace, filepath_


@parametrize_with_cases('test_dir, fov', cases=FovMetadataTestFiles)
@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='specified')
@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='multiple')
def test_get_width_histogram(test_dir, fov, panel):
bin_files.get_histograms_per_tof(
test_dir,
fov['json'].split('.')[0],
'SMA',
['SMA', 'CD38'],
panel,
time_res=500e-6
)


@parametrize_with_cases('test_dir, fov', cases=FovMetadataTestFiles)
@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='specified')
@parametrize_with_cases('panel', cases=FovMetadataTestPanels, has_tag='multiple')
def test_median_height_vs_mean_pp(test_dir, fov, panel):
bin_files.get_median_pulse_height(
test_dir,
fov['json'].split('.')[0],
'SMA',
['SMA', 'CD38'],
panel,
500e-6
)
Expand Down

0 comments on commit 2b1e4ce

Please sign in to comment.