Skip to content

Commit

Permalink
Add ability to sort bams by arbitrary list of tags plus queryname, an…
Browse files Browse the repository at this point in the history
…d to verify sorting. (170) (#46)
  • Loading branch information
David Shiga authored Aug 24, 2018
1 parent 9bff1c6 commit 63a7932
Show file tree
Hide file tree
Showing 8 changed files with 508 additions and 29 deletions.
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
'MergeCellMetrics = sctools.platform:GenericPlatform.merge_cell_metrics',
'CreateCountMatrix = sctools.platform:GenericPlatform.bam_to_count_matrix',
'MergeCountMatrices = sctools.platform:GenericPlatform.merge_count_matrices',
'TagSortBam = sctools.platform:GenericPlatform.tag_sort_bam',
'VerifyBamSort = sctools.platform:GenericPlatform.verify_bam_sort'
]
},
classifiers=CLASSIFIERS,
Expand Down
110 changes: 86 additions & 24 deletions src/sctools/bam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,31 @@
iter_cell_barcodes wrapper for iter_tag_groups that iterates over cell barcode tags
iter_genes wrapper for iter_tag_groups that iterates over gene tags
iter_molecules wrapper for iter_tag_groups that iterates over molecule tags
sort_by_tags_and_queryname sort bam by given list of zero or more tags, followed by query name
verify_sort verifies whether bam is correctly sorted by given list of tags, then query name
Classes
-------
SubsetAlignments class to extract reads specific to requested chromosome(s)
Tagger class to add tags to sam/bam records from paired fastq records
AlignmentSortOrder abstract class to represent alignment sort orders
QueryNameSortOrder alignment sort order by query name
CellMoleculeGeneQueryNameSortOrder alignment sort order hierarchically cell > molecule > gene > query name
TagSortableRecord class to facilitate sorting of pysam.AlignedSegments
SortError error raised when sorting is incorrect
References
----------
htslib : https://github.com/samtools/htslib
"""

import functools
import math
import os
import warnings
from abc import abstractmethod
from itertools import cycle
from typing import Iterator, Generator, List, Dict, Union, Tuple, Callable, Any, Optional
from typing import Iterator, Iterable, Generator, List, Dict, Union, Tuple, Callable, Any, Optional

import pysam

Expand Down Expand Up @@ -464,29 +468,87 @@ def __repr__(self) -> str:
return 'query_name'


class CellMoleculeGeneQueryNameSortOrder(AlignmentSortOrder):
"""Hierarchical alignment record sort order (cell barcode >= molecule barcode >= gene name >= query name)."""
@functools.total_ordering
class TagSortableRecord(object):
"""Wrapper for pysam.AlignedSegment that facilitates sorting by tags and query name."""

def __init__(
self,
cell_barcode_tag_key: str = consts.CELL_BARCODE_TAG_KEY,
molecule_barcode_tag_key: str = consts.MOLECULE_BARCODE_TAG_KEY,
gene_name_tag_key: str = consts.GENE_NAME_TAG_KEY) -> None:
assert cell_barcode_tag_key, "Cell barcode tag key can not be None"
assert molecule_barcode_tag_key, "Molecule barcode tag key can not be None"
assert gene_name_tag_key, "Gene name tag key can not be None"
self.cell_barcode_tag_key = cell_barcode_tag_key
self.molecule_barcode_tag_key = molecule_barcode_tag_key
self.gene_name_tag_key = gene_name_tag_key

def _get_sort_key(self, alignment: pysam.AlignedSegment) -> Tuple[str, str, str, str]:
return (get_tag_or_default(alignment, self.cell_barcode_tag_key, default='N'),
get_tag_or_default(alignment, self.molecule_barcode_tag_key, default='N'),
get_tag_or_default(alignment, self.gene_name_tag_key, default='N'),
alignment.query_name)

@property
def key_generator(self) -> Callable[[pysam.AlignedSegment], Tuple[str, str, str, str]]:
return self._get_sort_key
tag_keys: Iterable[str],
tag_values: Iterable[str],
query_name: str,
record: pysam.AlignedSegment = None) -> None:
self.tag_keys = tag_keys
self.tag_values = tag_values
self.query_name = query_name
self.record = record

@classmethod
def from_aligned_segment(cls, record: pysam.AlignedSegment, tag_keys: Iterable[str]) -> 'TagSortableRecord':
"""Create a TagSortableRecord from a pysam.AlignedSegment and list of tag keys"""
assert record is not None
tag_values = [get_tag_or_default(record, key, '') for key in tag_keys]
query_name = record.query_name
return cls(tag_keys, tag_values, query_name, record)

def __lt__(self, other: object) -> bool:
if not isinstance(other, TagSortableRecord):
return NotImplemented
self.__verify_tag_keys_match(other)
for (self_tag_value, other_tag_value) in zip(self.tag_values, other.tag_values):
if self_tag_value < other_tag_value:
return True
elif self_tag_value > other_tag_value:
return False
return self.query_name < other.query_name

def __eq__(self, other: object) -> bool:
# TODO: Add more error checking
if not isinstance(other, TagSortableRecord):
return NotImplemented
self.__verify_tag_keys_match(other)
for (self_tag_value, other_tag_value) in zip(self.tag_values, other.tag_values):
if self_tag_value != other_tag_value:
return False
return self.query_name == other.query_name

def __verify_tag_keys_match(self, other) -> None:
if self.tag_keys != other.tag_keys:
format_str = 'Cannot compare records using different tag lists: {0}, {1}'
raise ValueError(format_str.format(self.tag_keys, other.tag_keys))

def __str__(self) -> str:
return self.__repr__()

def __repr__(self) -> str:
return 'hierarchical__cell_molecule_gene_query_name'
format_str = 'TagSortableRecord(tags: {0}, tag_values: {1}, query_name: {2}'
return format_str.format(self.tag_keys, self.tag_values, self.query_name)


def sort_by_tags_and_queryname(
records: Iterable[pysam.AlignedSegment],
tag_keys: Iterable[str]) -> Iterable[pysam.AlignedSegment]:
"""Sorts the given bam records by the given tags, followed by query name.
If no tags are given, just sorts by query name.
"""
tag_sortable_records = (TagSortableRecord.from_aligned_segment(r, tag_keys) for r in records)
sorted_records = sorted(tag_sortable_records)
aligned_segments = (r.record for r in sorted_records)
return aligned_segments


def verify_sort(records: Iterable[TagSortableRecord], tag_keys: Iterable[str]) -> None:
"""Raise AssertionError if the given records are not correctly sorted by the given tags and query name"""
# Setting tag values and query name to empty string ensures first record will never be less than old_record
old_record = TagSortableRecord(tag_keys=tag_keys, tag_values=['' for _ in tag_keys], query_name='', record=None)
i = 0
for record in records:
i += 1
if not record >= old_record:
msg = 'Records {0} and {1} are not in correct order:\n{1}:{2} \nis less than \n{0}:{3}'
raise SortError(msg.format(i - 1, i, record, old_record))
old_record = record


class SortError(Exception):
pass
96 changes: 95 additions & 1 deletion src/sctools/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
"""

import argparse
from typing import Iterable, List, Dict
from typing import Iterable, List, Dict, Optional, Sequence
from itertools import chain

import pysam
from sctools import fastq, bam, metrics, count, consts, gtf


Expand All @@ -28,6 +30,10 @@ class GenericPlatform:
Platform-Agnostic Methods
-------------------------
tag_sort_bam():
sort a bam file by zero or more tags and then by queryname
verify_bam_sort():
verifies whether bam file is correctly sorted by given list of zero or more tags, then queryname
split_bam()
split a bam file into subfiles of equal size
calculate_gene_metrics()
Expand All @@ -44,6 +50,94 @@ class GenericPlatform:
merge multiple csr-format count matrices into a single csr matrix
"""

@classmethod
def tag_sort_bam(cls, args: Iterable=None) -> int:
"""Command line entrypoint for sorting a bam file by zero or more tags, followed by queryname.
Parameters
----------
args : Iterable[str], optional
arguments list, for testing (see test/test_entrypoints.py for example). The default
value of None, when passed to `parser.parse_args` causes the parser to
read `sys.argv`
Returns
-------
return_call : 0
return call if the program completes successfully
"""
description = 'Sorts bam by list of zero or more tags, followed by query name'
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
'-i', '--input_bam', required=True,
help='input bamfile')
parser.add_argument(
'-o', '--output_bam', required=True,
help='output bamfile')
parser.add_argument('-t', '--tags', nargs='+', action='append',
help='tag(s) to sort by, separated by space, e.g. -t CB GE UB')
if args is not None:
args = parser.parse_args(args)
else:
args = parser.parse_args()

tags = cls.get_tags(args.tags)
with pysam.AlignmentFile(args.input_bam, 'rb') as f:
header = f.header
records = f.fetch(until_eof=True)
sorted_records = bam.sort_by_tags_and_queryname(records, tags)
with pysam.AlignmentFile(args.output_bam, 'wb', header=header) as f:
for record in sorted_records:
f.write(record)

return 0

@classmethod
def verify_bam_sort(cls, args: Iterable=None) -> int:
"""Command line entrypoint for verifying bam is properly sorted by zero or more tags, followed by queryname.
Parameters
----------
args : Iterable[str], optional
arguments list, for testing (see test/test_entrypoints.py for example). The default
value of None, when passed to `parser.parse_args` causes the parser to
read `sys.argv`
Returns
-------
return_call : 0
return call if the program completes successfully
"""
description = 'Verifies whether bam is sorted by the list of zero or more tags, followed by query name'
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
'-i', '--input_bam', required=True,
help='input bamfile')
parser.add_argument('-t', '--tags', nargs='+', action='append',
help='tag(s) to use to verify sorting, separated by space, e.g. -t CB GE UB')
if args is not None:
args = parser.parse_args(args)
else:
args = parser.parse_args()

tags = cls.get_tags(args.tags)
with pysam.AlignmentFile(args.input_bam, 'rb') as f:
aligned_segments = f.fetch(until_eof=True)
sortable_records = (bam.TagSortableRecord.from_aligned_segment(r, tags) for r in aligned_segments)
bam.verify_sort(sortable_records, tags)

print('{0} is correctly sorted by {1} and query name'.format(args.input_bam, tags))
return 0

@classmethod
def get_tags(cls, raw_tags: Optional[Sequence[str]]) -> Iterable[str]:
if raw_tags is None:
raw_tags = []
# Flattens into single list when tags specified like -t A -t B -t C
return [t for t in chain.from_iterable(raw_tags)]

@classmethod
def split_bam(cls, args: Iterable=None) -> int:
"""Command line entrypoint for splitting a bamfile into subfiles of equal size.
Expand Down
Binary file not shown.
Binary file added src/sctools/test/data/unsorted.bam
Binary file not shown.
Loading

0 comments on commit 63a7932

Please sign in to comment.