diff --git a/src/sctools/bam.py b/src/sctools/bam.py
index 9e0411f..43f1034 100644
--- a/src/sctools/bam.py
+++ b/src/sctools/bam.py
@@ -36,16 +36,17 @@
 """

 import functools
+from functools import partial, reduce
 import math
 import os
 import warnings
 from abc import abstractmethod
-from itertools import cycle
 from typing import (
     Iterator,
     Iterable,
     Generator,
     List,
+    Set,
     Dict,
     Union,
     Tuple,
@@ -55,9 +56,15 @@
 )

 import pysam
+import shutil
+import multiprocessing
+import uuid

 from . import consts

+# File descriptor to write log messages to
+STDERR = 2
+

 class SubsetAlignments:
     """Wrapper for pysam/htslib that extracts reads corresponding to requested chromosome(s)
@@ -226,19 +233,152 @@ def tag(self, output_bam_name: str, tag_generators) -> None:
                 outbam.write(sam_record)


+def get_barcodes_from_bam(
+    in_bam: str, tags: List[str], raise_missing: bool
+) -> Set[str]:
+    """Get all the distinct barcodes from a bam
+
+    :param in_bam: str
+        Input bam file.
+    :param tags: List[str]
+        Tags in the bam that might contain barcodes.
+    :param raise_missing: bool
+        Raise an error if no barcodes can be found.
+    :return: set
+        A set of the barcodes found in the bam.
+        This set will not contain a None value.
+    """
+    barcodes = set()
+    # Get all the barcodes from the BAM
+    with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments:
+        for alignment in input_alignments:
+            barcode = get_barcode_for_alignment(alignment, tags, raise_missing)
+            # barcode is None if no provided tag with a non-null value was found on the record
+            if barcode is not None:
+                barcodes.add(barcode)
+    return barcodes
+
+
+def get_barcode_for_alignment(
+    alignment: pysam.AlignedSegment, tags: List[str], raise_missing: bool
+) -> str:
+    """Get the barcode for an Alignment
+
+    :param alignment: pysam.AlignedSegment
+        An Alignment from pysam.
+    :param tags: List[str]
+        Tags in the bam that might contain barcodes. If multiple tags are passed, returns
+        the contents of the first tag that contains a barcode.
+    :param raise_missing: bool
+        Raise an error if no barcodes can be found.
+    :return: str
+        A barcode for the alignment, or None if one is not found and raise_missing is False.
+    """
+    alignment_barcode = None
+    for tag in tags:
+        # A missing barcode should be the exceptional case, so try/except is faster than if/else
+        try:
+            alignment_barcode = alignment.get_tag(tag)
+            break  # Got the key, don't bother getting the next tag
+        except KeyError:
+            continue  # Try to get the next tag
+
+    if raise_missing and alignment_barcode is None:
+        raise RuntimeError(
+            'Alignment encountered that is missing {} tag(s).'.format(tags)
+        )
+
+    return alignment_barcode
+
+
+def write_barcodes_to_bins(
+    in_bam: str, tags: List[str], barcodes_to_bins: Dict[str, int], raise_missing: bool
+) -> List[str]:
+    """Write alignments to numbered bin files according to barcodes_to_bins
+
+    :param in_bam: str
+        The bam file to read.
+    :param tags: List[str]
+        Tags in the bam that might contain barcodes.
+    :param barcodes_to_bins: Dict[str, int]
+        A Dict from barcode to bin. All alignments with the same barcode need to be written to the
+        same bin. These numbered bins are merged after parallelization so that all alignments with
+        the same barcode end up in the same bam.
+    :param raise_missing: bool
+        Raise an error if no barcodes can be found.
+    :return: A list of paths to the written bins.
+    """
+    # Create all the output files
+    with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments:
+
+        # We need a random int appended to the dirname to make sure input bams with the same name don't clash
+        dirname = (
+            os.path.splitext(os.path.basename(in_bam))[0] + "_" + str(uuid.uuid4())
+        )
+        os.makedirs(dirname)
+
+        files = []
+        bins = list(set(barcodes_to_bins.values()))
+        filepaths = []
+        # barcodes_to_bins is a dict of barcodes to ints. The ints are contiguous and are used as indices
+        # in the files array. The files array is an array of open file handles to write to.
+        for i in range(len(bins)):
+            out_bam_name = os.path.join(f"{dirname}", f"{dirname}_{i}.bam")
+            filepaths.append(out_bam_name)
+            # For now, bam writing uses one thread for compression. Better logic could support more
+            # processes without starving the machine for resources.
+            open_bam = pysam.AlignmentFile(
+                out_bam_name, 'wb', template=input_alignments
+            )
+            files.append(open_bam)
+
+        # Loop over input; check each tag in priority order and partition barcodes into files based
+        # on the highest priority tag that is identified
+        for alignment in input_alignments:
+            barcode = get_barcode_for_alignment(alignment, tags, raise_missing)
+            if barcode is not None:
+                # Find or set the file associated with the tag and write the record to the correct file
+                out_file = files[barcodes_to_bins[barcode]]
+                out_file.write(alignment)
+
+    for file in files:
+        file.close()
+
+    return filepaths
+
+
+def merge_bams(bams: List[str]) -> str:
+    """Merge input bams using samtools.
+
+    This cannot be a local function within `split` because then Python "cannot pickle a local object".
+    :param bams: The bams to merge. Because of how it's called using multiprocessing,
+        the basename of the final output bam is the first element of the list and the
+        bams to merge are the remaining elements.
+    :return: The output bam name.
+    """
+    bam_name = os.path.realpath(bams[0] + ".bam")
+    bams_to_merge = bams[1:]
+    pysam.merge('-c', '-p', bam_name, *bams_to_merge)
+    return bam_name
+
+
 def split(
-    in_bam: str, out_prefix: str, *tags, approx_mb_per_split=1000, raise_missing=True
+    in_bams: List[str],
+    out_prefix: str,
+    tags: List[str],
+    approx_mb_per_split: float = 1000,
+    raise_missing: bool = True,
+    num_processes: int = None,
 ) -> List[str]:
     """split `in_bam` by tag into files of `approx_mb_per_split`

     Parameters
     ----------
-    in_bam : str
-        Input bam file.
+    in_bams : List[str]
+        Input bam files.
     out_prefix : str
         Prefix for all output files; output will be named as prefix_n where n is an integer equal
         to the chunk number.
-    tags : tuple
+    tags : List[str]
         The bam tags to split on. The tags are checked in order, and sorting is done based on the
         first identified tag. Further tags are only checked if the first tag is missing. This is
         useful in cases where sorting is executed over a corrected barcode, but some records only
@@ -248,6 +388,8 @@ def split(
     raise_missing : bool, optional
         if True, raise a RuntimeError if a record is encountered without a tag. Else silently
         discard the record (default = True)
+    num_processes : int, optional
+        The number of processes to parallelize over. If not set, will use all available processes.

     Returns
     -------
@@ -266,34 +408,11 @@ def split(
     if len(tags) == 0:
         raise ValueError('At least one tag must be passed')

-    def _cleanup(
-        _files_to_counts: Dict[pysam.AlignmentFile, int],
-        _files_to_names: Dict[pysam.AlignmentFile, str],
-        rm_all: bool = False,
-    ) -> None:
-        """Closes file handles and remove any empty files.
-
-        Parameters
-        ----------
-        _files_to_counts : dict
-            Dictionary of file objects to the number of reads each contains.
-        _files_to_names : dict
-            Dictionary of file objects to file handles.
-        rm_all : bool, optional
-            If True, indicates all files should be removed, regardless of count number, else only
-            empty files without counts are removed (default = False)
-
-        """
-        for bamfile, count in _files_to_counts.items():
-            # corner case: clean up files that were created, but didn't get data because
-            # n_cell < n_barcode
-            bamfile.close()
-            if count == 0 or rm_all:
-                os.remove(_files_to_names[bamfile])
-                del _files_to_names[bamfile]
+    if num_processes is None:
+        num_processes = multiprocessing.cpu_count()

     # find correct number of subfiles to spawn
-    bam_mb = os.path.getsize(in_bam) * 1e-6
+    bam_mb = sum(os.path.getsize(b) * 1e-6 for b in in_bams)
     n_subfiles = int(math.ceil(bam_mb / approx_mb_per_split))
     if n_subfiles > consts.MAX_BAM_SPLIT_SUBFILES_TO_WARN:
         warnings.warn(
@@ -307,59 +426,69 @@ def _cleanup(
             f'think about increasing max_mb_per_split.'
         )

-    # create all the output files
-    with pysam.AlignmentFile(in_bam, 'rb', check_sq=False) as input_alignments:
+    full_pool = multiprocessing.Pool(num_processes)

-        # map files to counts
-        files_to_counts: Dict[pysam.AlignmentFile, int] = {}
-        files_to_names: Dict[pysam.AlignmentFile, str] = {}
-        for i in range(n_subfiles):
-            out_bam_name = out_prefix + '_%d.bam' % i
-            open_bam = pysam.AlignmentFile(
-                out_bam_name, 'wb', template=input_alignments
-            )
-            files_to_counts[open_bam] = 0
-            files_to_names[open_bam] = out_bam_name
+    # Get all the barcodes over all the bams
+    os.write(STDERR, b'Retrieving barcodes from bams\n')
+    result = full_pool.map(
+        partial(get_barcodes_from_bam, tags=tags, raise_missing=raise_missing), in_bams
+    )

-        # cycler over files to assign new barcodes to next file
-        file_cycler = cycle(files_to_counts.keys())
+    barcodes_list = list(reduce(lambda set1, set2: set1.union(set2), result))
+    os.write(STDERR, b'Retrieved barcodes from bams\n')
+
+    # Create the barcodes to bin mapping
+    os.write(STDERR, b'Allocating bins\n')
+    barcodes_to_bins_dict = {}
+
+    # barcodes_list will always contain non-None elements from get_barcodes_from_bam
+    if len(barcodes_list) <= n_subfiles:
+        for barcode_index in range(len(barcodes_list)):
+            barcodes_to_bins_dict[barcodes_list[barcode_index]] = barcode_index
+    else:
+        for barcode_index in range(len(barcodes_list)):
+            file_index = barcode_index % n_subfiles
+            barcodes_to_bins_dict[barcodes_list[barcode_index]] = file_index
+
+    # Split the bams by barcode in parallel
+    os.write(STDERR, b'Splitting the bams by barcode\n')
+    # Samtools needs a thread for compression, so we leave half the given processes open.
+    write_pool_processes = math.ceil(num_processes / 2) if num_processes > 2 else 1
+    write_pool = multiprocessing.Pool(write_pool_processes)
+    scattered_split_result = write_pool.map(
+        partial(
+            write_barcodes_to_bins,
+            tags=list(tags),
+            raise_missing=raise_missing,
+            barcodes_to_bins=barcodes_to_bins_dict,
+        ),
+        in_bams,
+    )

-        # create an empty map for (tag, barcode) to files
-        tags_to_files = {}
+    bin_indices = list(set(barcodes_to_bins_dict.values()))
+    # Create a list of lists, where the first element of every sub-list is the name of the final output bam
+    bins = list([f"{out_prefix}_{index}"] for index in bin_indices)

-        # loop over input; check each tag in priority order and partition barcodes into files based
-        # on the highest priority tag that is identified
-        for alignment in input_alignments:
+    # A shard is the list of per-bin files produced from one input bam by write_barcodes_to_bins.
+    # Gather all the files for each bin into the same sub-list.
+    for shard_index in range(len(scattered_split_result)):
+        shard = scattered_split_result[shard_index]
+        for file_index in range(len(shard)):
+            bins[file_index].append(shard[file_index])

-            for tag in tags:
-                try:
-                    tag_content = tag, alignment.get_tag(tag)
-                    break
-                except KeyError:
-                    tag_content = None
-                    continue  # move on to next tag
-
-            # No provided tag was found on the record that had a non-null value
-            if tag_content is None:
-                if raise_missing:
-                    _cleanup(files_to_counts, files_to_names, rm_all=True)
-                    raise RuntimeError(
-                        'Alignment encountered that is missing {repr(tags)} tag(s).'
-                    )
-                else:
-                    continue  # move on to next alignment
-
-            # find or set the file associated with the tag and write the record to the correct file
-            try:
-                out_file = tags_to_files[tag_content]
-            except KeyError:
-                out_file = next(file_cycler)
-                tags_to_files[tag_content] = out_file
-            out_file.write(alignment)
-            files_to_counts[out_file] += 1
-
-    _cleanup(files_to_counts, files_to_names)
-    return list(files_to_names.values())
+    write_pool.close()
+
+    # Recombine the binned bams
+    os.write(STDERR, b'Merging temporary bam files\n')
+    merged_bams = full_pool.map(partial(merge_bams), bins)
+
+    os.write(STDERR, b'deleting temporary files\n')
+    for paths in scattered_split_result:
+        shutil.rmtree(os.path.dirname(paths[0]))
+
+    full_pool.close()
+
+    return merged_bams


 # todo change this to throw away "None" reads instead of appending them
diff --git a/src/sctools/platform.py b/src/sctools/platform.py
index 272b693..70e723d 100644
--- a/src/sctools/platform.py
+++ b/src/sctools/platform.py
@@ -169,7 +169,9 @@ def split_bam(cls, args: Iterable = None) -> int:
         """

         parser = argparse.ArgumentParser()
-        parser.add_argument('-b', '--bamfile', required=True, help='input bamfile')
+        parser.add_argument(
+            '-b', '--bamfile', nargs='+', required=True, help='input bamfile'
+        )
         parser.add_argument(
             '-p', '--output-prefix', required=True, help='prefix for output chunks'
         )
@@ -181,6 +183,13 @@ def split_bam(cls, args: Iterable = None) -> int:
             type=float,
             help='approximate size target for each subfile (in MB)',
         )
+        parser.add_argument(
+            '--num-processes',
+            required=False,
+            default=None,
+            type=int,
+            help='Number of processes to parallelize over',
+        )
         parser.add_argument(
             '-t',
             '--tags',
@@ -204,9 +213,10 @@ def split_bam(cls, args: Iterable = None) -> int:
         filenames = bam.split(
             args.bamfile,
             args.output_prefix,
-            *args.tags,
+            args.tags,
             approx_mb_per_split=args.subfile_size,
-            raise_missing=args.raise_missing,
+            raise_missing=args.drop_missing,
+            num_processes=args.num_processes,
         )

         print(' '.join(filenames))
diff --git a/src/sctools/test/test_bam.py b/src/sctools/test/test_bam.py
index 286536a..91695a8 100644
--- a/src/sctools/test/test_bam.py
+++ b/src/sctools/test/test_bam.py
@@ -1,6 +1,7 @@
 from copy import deepcopy
 import glob
 import os
+import shutil

 import pysam
 import pytest
@@ -136,9 +137,9 @@ def test_split_bam_raises_value_error_when_passed_bam_without_barcodes(bamfile):
     )  # our test data is very small, 0.01mb = ~10kb, which should yield 5 files.
     with pytest.raises(RuntimeError):
         bam.split(
-            bamfile,
+            [bamfile],
             'test_output',
-            consts.CELL_BARCODE_TAG_KEY,
+            [consts.CELL_BARCODE_TAG_KEY],
             approx_mb_per_split=split_size,
         )

@@ -164,10 +165,9 @@ def tagged_bam():
 def test_split_on_tagged_bam(tagged_bam):
     split_size = 0.005  # our test data is very small, this value should yield 3 files
     outputs = bam.split(
-        tagged_bam,
+        [tagged_bam],
         'test_output',
-        consts.CELL_BARCODE_TAG_KEY,
-        consts.RAW_CELL_BARCODE_TAG_KEY,
+        [consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY],
         approx_mb_per_split=split_size,
     )
     assert len(outputs) == 3
@@ -181,10 +181,9 @@ def test_split_on_tagged_bam(tagged_bam):
 def test_split_with_large_chunk_size_generates_one_file(tagged_bam):
     split_size = 1024  # our test data is very small, this value should yield 1 file
     outputs = bam.split(
-        tagged_bam,
+        [tagged_bam],
         'test_output',
-        consts.CELL_BARCODE_TAG_KEY,
-        consts.RAW_CELL_BARCODE_TAG_KEY,
+        [consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY],
         approx_mb_per_split=split_size,
     )
     assert len(outputs) == 1
@@ -204,10 +203,10 @@ def test_split_with_raise_missing_true_raises_warning_without_cr_barcode_passed(
 ):
     split_size = 1024  # our test data is very small, this value should yield 1 file
     with pytest.raises(RuntimeError):
-        outputs = bam.split(
-            tagged_bam,
+        bam.split(
+            [tagged_bam],
             'test_output',
-            consts.CELL_BARCODE_TAG_KEY,
+            [consts.CELL_BARCODE_TAG_KEY],
             approx_mb_per_split=split_size,
             raise_missing=True,
         )
@@ -221,9 +220,9 @@ def test_split_with_raise_missing_true_raises_warning_without_cr_barcode_passed(
 def test_split_succeeds_with_raise_missing_false_and_no_cr_barcode_passed(tagged_bam):
     split_size = 1024  # our test data is very small, this value should yield 1 file
     outputs = bam.split(
-        tagged_bam,
+        [tagged_bam],
         'test_output',
-        consts.CELL_BARCODE_TAG_KEY,
+        [consts.CELL_BARCODE_TAG_KEY],
         approx_mb_per_split=split_size,
         raise_missing=False,
     )
@@ -242,6 +241,68 @@ def test_split_succeeds_with_raise_missing_false_and_no_cr_barcode_passed(tagged
         os.remove(f)


+def test_get_barcodes_from_bam(tagged_bam):
+    outputs = bam.get_barcodes_from_bam(
+        tagged_bam,
+        [consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY],
+        raise_missing=True,
+    )
+    assert len(outputs) == 99
+
+
+def test_get_barcodes_from_bam_with_raise_missing_true_raises_warning_without_cr_barcode_passed(
+    tagged_bam
+):
+    with pytest.raises(RuntimeError):
+        bam.get_barcodes_from_bam(
+            tagged_bam, [consts.CELL_BARCODE_TAG_KEY], raise_missing=True
+        )
+
+
+def test_write_barcodes_to_bins(tagged_bam):
+    barcodes = bam.get_barcodes_from_bam(
+        tagged_bam,
+        [consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY],
+        raise_missing=True,
+    )
+
+    test_barcodes_to_bins = {}
+    for barcode in barcodes:
+        test_barcodes_to_bins[barcode] = 0
+
+    filenames = bam.write_barcodes_to_bins(
+        tagged_bam,
+        [consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY],
+        test_barcodes_to_bins,
+        raise_missing=False,
+    )
+
+    assert len(filenames) == 1
+
+    # cleanup
+    for f in filenames:
+        shutil.rmtree(os.path.dirname(f))
+
+
+def test_get_barcode_for_alignment(tagged_bam):
+    with pysam.AlignmentFile(tagged_bam, 'rb', check_sq=False) as input_alignments:
+        for alignment in input_alignments:
+            barcode = bam.get_barcode_for_alignment(
+                alignment,
+                [consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY],
+                raise_missing=False,
+            )
+            assert barcode == "NTAAGAGTCTGCAAGT"
+            break
+
+
+def test_get_barcode_for_alignment_raises_error_for_missing_tag(tagged_bam):
+    with pysam.AlignmentFile(tagged_bam, 'rb', check_sq=False) as input_alignments:
+        for alignment in input_alignments:
+            with pytest.raises(RuntimeError):
+                bam.get_barcode_for_alignment(alignment, TAG_KEYS, raise_missing=True)
+
+
 # TEST SORTING
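
Usage sketch (not part of the diff above): the reshaped bam.split API now takes a list of input bams and a list of tags instead of a single bam and *tags, plus an optional num_processes for the worker pools. The file paths, tag choice, and process count below are hypothetical placeholders; only the call shape follows the new signature introduced in src/sctools/bam.py.

    # Hypothetical illustration of calling the parallelized split API from this change.
    from sctools import bam, consts

    merged = bam.split(
        in_bams=['lane1.bam', 'lane2.bam'],  # placeholder input files
        out_prefix='sample_split',
        tags=[consts.CELL_BARCODE_TAG_KEY, consts.RAW_CELL_BARCODE_TAG_KEY],
        approx_mb_per_split=1000,
        raise_missing=True,
        num_processes=4,  # defaults to multiprocessing.cpu_count() when omitted
    )
    # merge_bams returns absolute paths, e.g. ['/abs/path/sample_split_0.bam', ...]
    print(merged)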
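
The bin-allocation rule inside the new split() (one bin per barcode when there are at most n_subfiles distinct barcodes, otherwise round-robin assignment by barcode index) can be restated as a standalone sketch. allocate_bins is a hypothetical helper written only for illustration; it is not a function added by this diff.

    # Hypothetical restatement of the barcode-to-bin mapping built inline in split().
    from typing import Dict, Iterable

    def allocate_bins(barcodes: Iterable[str], n_subfiles: int) -> Dict[str, int]:
        barcode_list = list(barcodes)
        if len(barcode_list) <= n_subfiles:
            # Few barcodes: each distinct barcode gets its own bin.
            return {bc: i for i, bc in enumerate(barcode_list)}
        # Many barcodes: assign round-robin so bins stay roughly even in size.
        return {bc: i % n_subfiles for i, bc in enumerate(barcode_list)}

    print(allocate_bins(['AAAC', 'CCGT', 'GGTA', 'TTAG'], 2))
    # {'AAAC': 0, 'CCGT': 1, 'GGTA': 0, 'TTAG': 1}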