From 9e198aaf7f15150cf8b9f56b451636c7780ce959 Mon Sep 17 00:00:00 2001 From: Kishori M Konwar <43380010+kishorikonwar@users.noreply.github.com> Date: Mon, 18 May 2020 15:45:32 -0400 Subject: [PATCH] Kmk mt counts (#76) * added n_mitochondrial_genes * fixed a test * formatted * fixed flake8 errors --- src/sctools/gtf.py | 43 ++++++++++++++++++++++++++++++- src/sctools/metrics/aggregator.py | 10 +++++-- src/sctools/metrics/gatherer.py | 15 ++++++++--- src/sctools/platform.py | 19 ++++++++++++-- 4 files changed, 79 insertions(+), 8 deletions(-) diff --git a/src/sctools/gtf.py b/src/sctools/gtf.py index cf5ff57..fd075ec 100644 --- a/src/sctools/gtf.py +++ b/src/sctools/gtf.py @@ -18,7 +18,8 @@ import logging import string -from typing import List, Dict, Generator, Iterable, Union +import re +from typing import List, Dict, Generator, Iterable, Union, Set from . import reader @@ -260,6 +261,46 @@ def _resolve_multiple_gene_names(gene_name: str): ) +def get_mitochondrial_gene_names( + files: Union[str, List[str]] = "-", mode: str = "r", header_comment_char: str = "#" +) -> Set[str]: + """Extract mitochondrial gene names from GTF file(s) and return a set of mitochondrial + gene ids occurring in the given file(s). + + Parameters + ---------- + files : Union[str, List], optional + File(s) to read. If '-', read sys.stdin (default = '-') + mode : {'r', 'rb'}, optional + Open mode. If 'r', read strings. If 'rb', read bytes (default = 'r'). + header_comment_char : str, optional + lines beginning with this character are skipped (default = '#') + + Returns + ------- + Set(str) + A set of the mitochondrial gene ids + """ + + mitochondrial_gene_ids: Set[str] = set() + for record in Reader(files, mode, header_comment_char).filter( + retain_types=["gene"] + ): + gene_name = record.get_attribute("gene_name") + gene_id = record.get_attribute("gene_id") + + if gene_name is None: + raise ValueError( + f"Malformed GTF file detected. 
Record is of type gene but does not have a " + f'"gene_name" field: {record}' + ) + if re.match('^mt-', gene_name, re.IGNORECASE): + if gene_id not in mitochondrial_gene_ids: + mitochondrial_gene_ids.add(gene_id) + + return mitochondrial_gene_ids + + def extract_gene_names( files: Union[str, List[str]] = "-", mode: str = "r", header_comment_char: str = "#" ) -> Dict[str, int]: diff --git a/src/sctools/metrics/aggregator.py b/src/sctools/metrics/aggregator.py index 72afedf..151c743 100644 --- a/src/sctools/metrics/aggregator.py +++ b/src/sctools/metrics/aggregator.py @@ -412,6 +412,8 @@ class CellMetrics(MetricAggregator): The number of genes detected by this cell genes_detected_multiple_observations : int The number of genes that are observed by more than one read in this cell + n_mitochondrial_genes: int + The number of mitochondrial genes detected by this cell """ @@ -450,8 +452,9 @@ def __init__(self): self.cell_barcode_fraction_bases_above_30_mean: float = None self.n_genes: int = None self.genes_detected_multiple_observations: int = None + self.n_mitochondrial_genes: int = None - def finalize(self): + def finalize(self, mitochondrial_genes=set()): super().finalize() self.cell_barcode_fraction_bases_above_30_mean: float = self._cell_barcode_fraction_bases_above_30.mean @@ -464,6 +467,10 @@ def finalize(self): 1 for v in self._genes_histogram.values() if v > 1 ) + self.n_mitochondrial_genes: int = sum( + 1 for g in self._genes_histogram.keys() if g in mitochondrial_genes + ) + def parse_extra_fields( self, tags: Sequence[str], record: pysam.AlignedSegment ) -> None: @@ -502,7 +509,6 @@ def parse_extra_fields( self.reads_unmapped += 1 # todo track reads_mapped_too_many_loci after multi-alignment is done - self._genes_histogram[tags[2]] += 1 # note that no gene == None diff --git a/src/sctools/metrics/gatherer.py b/src/sctools/metrics/gatherer.py index b207ccc..10d47c5 100644 --- a/src/sctools/metrics/gatherer.py +++ b/src/sctools/metrics/gatherer.py @@ -28,6 
+28,7 @@ from contextlib import closing import pysam +from typing import Set from sctools.bam import iter_cell_barcodes, iter_genes, iter_molecule_barcodes from sctools.metrics.aggregator import CellMetrics, GeneMetrics @@ -55,10 +56,17 @@ class MetricGatherer: """ - def __init__(self, bam_file: str, output_stem: str, compress: bool = True): + def __init__( + self, + bam_file: str, + output_stem: str, + mitochondrial_gene_ids: Set[str] = set(), + compress: bool = True, + ): self._bam_file = bam_file self._output_stem = output_stem self._compress = compress + self._mitochondrial_gene_ids = mitochondrial_gene_ids @property def bam_file(self) -> str: @@ -114,7 +122,6 @@ def extract_metrics(self, mode: str = 'rb') -> None: Open mode for self.bam. 'r' -> sam, 'rb' -> bam (default = 'rb'). """ - # open the files with pysam.AlignmentFile(self.bam_file, mode=mode) as bam_iterator, closing( MetricCSVWriter(self._output_stem, self._compress) @@ -146,7 +153,9 @@ def extract_metrics(self, mode: str = 'rb') -> None: ) # write a record for each cell - metric_aggregator.finalize() + metric_aggregator.finalize( + mitochondrial_genes=self._mitochondrial_gene_ids + ) cell_metrics_output.write(cell_tag, vars(metric_aggregator)) diff --git a/src/sctools/platform.py b/src/sctools/platform.py index b9c31db..460f26a 100644 --- a/src/sctools/platform.py +++ b/src/sctools/platform.py @@ -18,7 +18,7 @@ """ import argparse -from typing import Iterable, List, Dict, Optional, Sequence +from typing import Iterable, List, Dict, Set, Optional, Sequence from itertools import chain import pysam @@ -286,13 +286,28 @@ def calculate_cell_metrics(cls, args: Iterable[str] = None) -> int: parser.add_argument( "-o", "--output-filestem", required=True, help="Output file stem." 
) + parser.add_argument( + "-a", + "--gtf-annotation-file", + required=False, + default=None, + help="gtf annotation file that bam_file was aligned against", + ) if args is not None: args = parser.parse_args(args) else: args = parser.parse_args() + + # load mitochondrial gene ids from the annotation file + mitochondrial_gene_ids: Set(str) = set() + if args.gtf_annotation_file: + mitochondrial_gene_ids = gtf.get_mitochondrial_gene_names( + args.gtf_annotation_file + ) + cell_metric_gatherer = metrics.gatherer.GatherCellMetrics( - args.input_bam, args.output_filestem + args.input_bam, args.output_filestem, mitochondrial_gene_ids ) cell_metric_gatherer.extract_metrics() return 0