Skip to content

Commit

Permalink
Make SplitBam faster with Parallelization (#67)
Browse files Browse the repository at this point in the history
* SplitBam now takes advantage of parallelization

* Will use all available threads by default
  • Loading branch information
tlangs authored May 29, 2019
1 parent 5e04815 commit d4cbc45
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 95 deletions.
287 changes: 208 additions & 79 deletions src/sctools/bam.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@
"""

import functools
from functools import partial, reduce
import math
import os
import warnings
from abc import abstractmethod
from itertools import cycle
from typing import (
Iterator,
Iterable,
Generator,
List,
Set,
Dict,
Union,
Tuple,
Expand All @@ -55,9 +56,15 @@
)

import pysam
import shutil
import multiprocessing
import uuid

from . import consts

# File descriptor to write log messages to
STDERR = 2


class SubsetAlignments:
"""Wrapper for pysam/htslib that extracts reads corresponding to requested chromosome(s)
Expand Down Expand Up @@ -226,19 +233,152 @@ def tag(self, output_bam_name: str, tag_generators) -> None:
outbam.write(sam_record)


def get_barcodes_from_bam(
in_bam: str, tags: List[str], raise_missing: bool
) -> Set[str]:
"""Get all the distinct barcodes from a bam
:param in_bam: str
Input bam file.
:param tags: List[str]
Tags in the bam that might contain barcodes.
:param raise_missing: bool
Raise an error if no barcodes can be found.
:return: set
A set of barcodes found in the bam
This set will not contain a None value
"""
barcodes = set()
# Get all the Barcodes from the BAM
with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments:
for alignment in input_alignments:
barcode = get_barcode_for_alignment(alignment, tags, raise_missing)
# If no provided tag was found on the record that had a non-null value
if barcode is not None:
barcodes.add(barcode)
return barcodes


def get_barcode_for_alignment(
alignment: pysam.AlignedSegment, tags: List[str], raise_missing: bool
) -> str:
""" Get the barcode for an Alignment
:param alignment: pysam.AlignedSegment
An Alignment from pysam.
:param tags: List[str]
Tags in the bam that might contain barcodes. If multiple Tags are passed, will
return the contents of the first tag that contains a barcode.
:param raise_missing: bool
Raise an error if no barcodes can be found.
:return: str
A barcode for the alignment, or None if one is not found and raise_missing is False.
"""
alignment_barcode = None
for tag in tags:
# The non-existent barcode should be the exceptional case, so try/except is faster than if/else
try:
alignment_barcode = alignment.get_tag(tag)
break # Got the key, don't bother getting the next tag
except KeyError:
continue # Try to get the next tag

if raise_missing and alignment_barcode is None:
raise RuntimeError(
'Alignment encountered that is missing {} tag(s).'.format(tags)
)

return alignment_barcode


def write_barcodes_to_bins(
in_bam: str, tags: List[str], barcodes_to_bins: Dict[str, int], raise_missing: bool
) -> List[str]:
""" Write barcodes to appropriate bins as defined by barcodes_to_bins
:param in_bam: str
The bam file to read.
:param tags: List[str]
Tags in the bam that might contain barcodes.
:param barcodes_to_bins: Dict[str, int]
A Dict from barcode to bin. All barcodes of the same type need to be written to the same bin.
These numbered bins are merged after parallelization so that all alignments with the same
barcode are in the same bam.
:param raise_missing: bool
Raise an error if no barcodes can be found.
:return: A list of paths to the written bins.
"""
# Create all the output files
with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments:

# We need a random int appended to the dirname to make sure input bams with the same name don't clash
dirname = (
os.path.splitext(os.path.basename(in_bam))[0] + "_" + str(uuid.uuid4())
)
os.makedirs(dirname)

files = []
bins = list(set(barcodes_to_bins.values()))
filepaths = []
# barcode_to_bins is a dict of barcodes to ints. The ints are contiguous and are used as indices
# in the files array. The files array is an array of open file handles to write to.
for i in range(len(bins)):
out_bam_name = os.path.join(f"{dirname}", f"{dirname}_{i}.bam")
filepaths.append(out_bam_name)
# For now, bam writing uses one thread for compression. Better logic could support more processes without
# starving the machine for resources
open_bam = pysam.AlignmentFile(
out_bam_name, 'wb', template=input_alignments
)
files.append(open_bam)

# Loop over input; check each tag in priority order and partition barcodes into files based
# on the highest priority tag that is identified
for alignment in input_alignments:
barcode = get_barcode_for_alignment(alignment, tags, raise_missing)
if barcode is not None:
# Find or set the file associated with the tag and write the record to the correct file
out_file = files[barcodes_to_bins[barcode]]
out_file.write(alignment)

for file in files:
file.close()

return filepaths


def merge_bams(bams: List[str]) -> str:
""" Merge input bams using samtools.
This cannot be a local function within `split` because then Python "cannot pickle a local object".
:param bams: Name of the final bam + bams to merge.
Because of how its called using multiprocessing, the bam basename is the first element of the list.
:return: The output bam name.
"""
bam_name = os.path.realpath(bams[0] + ".bam")
bams_to_merge = bams[1:]
pysam.merge('-c', '-p', bam_name, *bams_to_merge)
return bam_name


def split(
in_bam: str, out_prefix: str, *tags, approx_mb_per_split=1000, raise_missing=True
in_bams: List[str],
out_prefix: str,
tags: List[str],
approx_mb_per_split: float = 1000,
raise_missing: bool = True,
num_processes: int = None,
) -> List[str]:
"""split `in_bam` by tag into files of `approx_mb_per_split`
Parameters
----------
in_bam : str
Input bam file.
in_bams : str
Input bam files.
out_prefix : str
Prefix for all output files; output will be named as prefix_n where n is an integer equal
to the chunk number.
tags : tuple
tags : List[str]
The bam tags to split on. The tags are checked in order, and sorting is done based on the
first identified tag. Further tags are only checked if the first tag is missing. This is
useful in cases where sorting is executed over a corrected barcode, but some records only
Expand All @@ -248,6 +388,8 @@ def split(
raise_missing : bool, optional
if True, raise a RuntimeError if a record is encountered without a tag. Else silently
discard the record (default = True)
num_processes : int, optional
The number of processes to parallelize over. If not set, will use all available processes.
Returns
-------
Expand All @@ -266,34 +408,11 @@ def split(
if len(tags) == 0:
raise ValueError('At least one tag must be passed')

def _cleanup(
_files_to_counts: Dict[pysam.AlignmentFile, int],
_files_to_names: Dict[pysam.AlignmentFile, str],
rm_all: bool = False,
) -> None:
"""Closes file handles and remove any empty files.
Parameters
----------
_files_to_counts : dict
Dictionary of file objects to the number of reads each contains.
_files_to_names : dict
Dictionary of file objects to file handles.
rm_all : bool, optional
If True, indicates all files should be removed, regardless of count number, else only
empty files without counts are removed (default = False)
"""
for bamfile, count in _files_to_counts.items():
# corner case: clean up files that were created, but didn't get data because
# n_cell < n_barcode
bamfile.close()
if count == 0 or rm_all:
os.remove(_files_to_names[bamfile])
del _files_to_names[bamfile]
if num_processes is None:
num_processes = multiprocessing.cpu_count()

# find correct number of subfiles to spawn
bam_mb = os.path.getsize(in_bam) * 1e-6
bam_mb = sum(os.path.getsize(b) * 1e-6 for b in in_bams)
n_subfiles = int(math.ceil(bam_mb / approx_mb_per_split))
if n_subfiles > consts.MAX_BAM_SPLIT_SUBFILES_TO_WARN:
warnings.warn(
Expand All @@ -307,59 +426,69 @@ def _cleanup(
f'think about increasing max_mb_per_split.'
)

# create all the output files
with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments:
full_pool = multiprocessing.Pool(num_processes)

# map files to counts
files_to_counts: Dict[pysam.AlignmentFile, int] = {}
files_to_names: Dict[pysam.AlignmentFile, str] = {}
for i in range(n_subfiles):
out_bam_name = out_prefix + '_%d.bam' % i
open_bam = pysam.AlignmentFile(
out_bam_name, 'wb', template=input_alignments
)
files_to_counts[open_bam] = 0
files_to_names[open_bam] = out_bam_name
# Get all the barcodes over all the bams
os.write(STDERR, b'Retrieving barcodes from bams\n')
result = full_pool.map(
partial(get_barcodes_from_bam, tags=tags, raise_missing=raise_missing), in_bams
)

# cycler over files to assign new barcodes to next file
file_cycler = cycle(files_to_counts.keys())
barcodes_list = list(reduce(lambda set1, set2: set1.union(set2), result))
os.write(STDERR, b'Retrieved barcodes from bams\n')

# Create the barcodes to bin mapping
os.write(STDERR, b'Allocating bins\n')
barcodes_to_bins_dict = {}

# barcodes_list will always contain non-None elements from get_barcodes_from_bam
if len(barcodes_list) <= n_subfiles:
for barcode_index in range(len(barcodes_list)):
barcodes_to_bins_dict[barcodes_list[barcode_index]] = barcode_index
else:
for barcode_index in range(len(barcodes_list)):
file_index = barcode_index % n_subfiles
barcodes_to_bins_dict[barcodes_list[barcode_index]] = file_index

# Split the bams by barcode in parallel
os.write(STDERR, b'Splitting the bams by barcode\n')
# Samtools needs a thread for compression, so we leave half the given processes open.
write_pool_processes = math.ceil(num_processes / 2) if num_processes > 2 else 1
write_pool = multiprocessing.Pool(write_pool_processes)
scattered_split_result = write_pool.map(
partial(
write_barcodes_to_bins,
tags=list(tags),
raise_missing=raise_missing,
barcodes_to_bins=barcodes_to_bins_dict,
),
in_bams,
)

# create an empty map for (tag, barcode) to files
tags_to_files = {}
bin_indices = list(set(barcodes_to_bins_dict.values()))
# Create a list of lists, where the first element of every sub-list is the name of the final output bam
bins = list([f"{out_prefix}_{index}"] for index in bin_indices)

# loop over input; check each tag in priority order and partition barcodes into files based
# on the highest priority tag that is identified
for alignment in input_alignments:
# A shard is the computation of writing barcodes to bins
# Gather all the files for each bin into the same sub-list.
for shard_index in range(len(scattered_split_result)):
shard = scattered_split_result[shard_index]
for file_index in range(len(shard)):
bins[file_index].append(shard[file_index])

for tag in tags:
try:
tag_content = tag, alignment.get_tag(tag)
break
except KeyError:
tag_content = None
continue # move on to next tag

# No provided tag was found on the record that had a non-null value
if tag_content is None:
if raise_missing:
_cleanup(files_to_counts, files_to_names, rm_all=True)
raise RuntimeError(
'Alignment encountered that is missing {repr(tags)} tag(s).'
)
else:
continue # move on to next alignment

# find or set the file associated with the tag and write the record to the correct file
try:
out_file = tags_to_files[tag_content]
except KeyError:
out_file = next(file_cycler)
tags_to_files[tag_content] = out_file
out_file.write(alignment)
files_to_counts[out_file] += 1

_cleanup(files_to_counts, files_to_names)
return list(files_to_names.values())
write_pool.close()

# Recombine the binned bams
os.write(STDERR, b'Merging temporary bam files\n')
merged_bams = full_pool.map(partial(merge_bams), bins)

os.write(STDERR, b'deleting temporary files\n')
for paths in scattered_split_result:
shutil.rmtree(os.path.dirname(paths[0]))

full_pool.close()

return merged_bams


# todo change this to throw away "None" reads instead of appending them if we are filtering them
Expand Down
16 changes: 13 additions & 3 deletions src/sctools/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def split_bam(cls, args: Iterable = None) -> int:
"""
parser = argparse.ArgumentParser()
parser.add_argument('-b', '--bamfile', required=True, help='input bamfile')
parser.add_argument(
'-b', '--bamfile', nargs='+', required=True, help='input bamfile'
)
parser.add_argument(
'-p', '--output-prefix', required=True, help='prefix for output chunks'
)
Expand All @@ -181,6 +183,13 @@ def split_bam(cls, args: Iterable = None) -> int:
type=float,
help='approximate size target for each subfile (in MB)',
)
parser.add_argument(
'--num-processes',
required=False,
default=None,
type=int,
help='Number of processes to parallelize over',
)
parser.add_argument(
'-t',
'--tags',
Expand All @@ -204,9 +213,10 @@ def split_bam(cls, args: Iterable = None) -> int:
filenames = bam.split(
args.bamfile,
args.output_prefix,
*args.tags,
args.tags,
approx_mb_per_split=args.subfile_size,
raise_missing=args.raise_missing,
raise_missing=args.drop_missing,
num_processes=args.num_processes,
)

print(' '.join(filenames))
Expand Down
Loading

0 comments on commit d4cbc45

Please sign in to comment.