diff --git a/dynast/__init__.py b/dynast/__init__.py index cd7ca49..6ba6abe 100755 --- a/dynast/__init__.py +++ b/dynast/__init__.py @@ -1 +1 @@ -__version__ = '1.0.1' +__version__ = '1.0.2.beta' diff --git a/dynast/consensus.py b/dynast/consensus.py index b7bcdef..f1f68f1 100755 --- a/dynast/consensus.py +++ b/dynast/consensus.py @@ -24,8 +24,10 @@ def consensus( add_RS_RI: bool = False, n_threads: int = 8, temp_dir: Optional[str] = None, + collapse_r1_r2: bool = False # <-- NEW parameter ): - """Main interface for the `consensus` command. + """ + Main interface for the `consensus` command. Args: bam_path: Path to BAM to call consensus from @@ -40,15 +42,19 @@ def consensus( add_RS_RI: Add RS and RI tags to BAM. Mostly useful for debugging. n_threads: Number of threads to use temp_dir: Temporary directory + collapse_r1_r2: If True, reads from R1 and R2 for the same UMI + will be combined into a single consensus read. If False, + R1 and R2 will each get their own consensus. """ stats = Stats() stats.start() stats_path = os.path.join( - out_dir, f'{constants.STATS_PREFIX}_{dt.datetime.strftime(stats.start_time, "%Y%m%d_%H%M%S_%f")}.json' + out_dir, + f'{constants.STATS_PREFIX}_{dt.datetime.strftime(stats.start_time, "%Y%m%d_%H%M%S_%f")}.json' ) os.makedirs(out_dir, exist_ok=True) - # Sort and index bam. + # Sort and index BAM bam_path = preprocessing.sort_and_index_bam( bam_path, '{}.sortedByCoord{}'.format(*os.path.splitext(bam_path)), n_threads=n_threads ) @@ -105,6 +111,10 @@ def consensus( consensus_path = os.path.join(out_dir, constants.CONSENSUS_BAM_FILENAME) logger.info(f'Calling consensus sequences from BAM to {consensus_path}') + + # ---------------------------------------------------------------------- + # Pass collapse_r1_r2 into preprocessing.call_consensus + # ---------------------------------------------------------------------- preprocessing.call_consensus( bam_path, consensus_path, @@ -117,7 +127,9 @@ def consensus( quality=quality, add_RS_RI=add_RS_RI, temp_dir=temp_dir, - n_threads=n_threads + n_threads=n_threads, + collapse_r1_r2=collapse_r1_r2 ) + stats.end() stats.save(stats_path) diff --git a/dynast/main.py b/dynast/main.py index d106c4d..9b5ac0d 100755 --- a/dynast/main.py +++ b/dynast/main.py @@ -167,16 +167,7 @@ def setup_align_args(parser: argparse.ArgumentParser, parent: argparse.ArgumentP def setup_consensus_args(parser: argparse.ArgumentParser, parent: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Helper function to set up a subparser for the `consensus` command. - - Args: - parser: Argparse parser to add the `consensus` command to - parent: Argparse parser parent of the newly added subcommand. - Used to inherit shared commands/flags - - Returns: - The newly added parser - """ + """Helper function to set up a subparser for the `consensus` command.""" parser_consensus = parser.add_parser( 'consensus', description='Generate consensus sequences', @@ -258,6 +249,22 @@ def setup_consensus_args(parser: argparse.ArgumentParser, parent: argparse.Argum ), action='store_true' ) + + # ---------------------------------------------------------------- + # NEW ARGUMENT to toggle separate R1/R2. Default is False => + # (which means R1/R2 are collapsed together internally). + # ---------------------------------------------------------------- + parser_consensus.add_argument( + '--separate-r1-r2', + action='store_true', + default=False, + help=( + 'If specified, R1 and R2 from the same UMI will be collapsed ' + 'separately (i.e. produce two consensus reads). By default, ' + 'they are collapsed into one consensus read.' + ), + ) + parser_consensus.add_argument( 'bam', help=( @@ -268,7 +275,6 @@ def setup_consensus_args(parser: argparse.ArgumentParser, parent: argparse.Argum ) return parser_consensus - def setup_count_args(parser: argparse.ArgumentParser, parent: argparse.ArgumentParser) -> argparse.ArgumentParser: """Helper function to set up a subparser for the `count` command. @@ -727,12 +733,9 @@ def parse_align(parser, args, temp_dir=None): def parse_consensus(parser: argparse.ArgumentParser, args: argparse.Namespace, temp_dir: Optional[str] = None): - """Parser for the `consensus` command. - - Args: - parser: The parser - args: Command-line arguments dictionary, as parsed by argparse - temp_dir: Temporary directory + """ + Parser for the `consensus` command. We also handle whether to + separate R1/R2 (based on `args.separate_r1_r2`). """ # Check quality if args.quality < 0 or args.quality > 41: @@ -759,11 +762,19 @@ def parse_consensus(parser: argparse.ArgumentParser, args: argparse.Namespace, t else: logger.info(f'Auto-detected strandedness: {strand}. Use `--strand` to override.') + # ------------------------------------------------------------------ + # Convert user-specified --separate-r1-r2 to collapse_r1_r2 bool. + # If separate_r1_r2 == True => we do NOT collapse them together => + # collapse_r1_r2 = False + # If separate_r1_r2 == False => we collapse them => True + # ------------------------------------------------------------------ + collapse_r1_r2 = not args.separate_r1_r2 + from .consensus import consensus consensus( - args.bam, - args.g, - args.o, + bam_path=args.bam, + gtf_path=args.g, + out_dir=args.o, strand=strand, umi_tag=args.umi_tag, barcode_tag=args.barcode_tag, @@ -773,6 +784,7 @@ def parse_consensus(parser: argparse.ArgumentParser, args: argparse.Namespace, t add_RS_RI=args.add_RS_RI, n_threads=args.t, temp_dir=temp_dir, + collapse_r1_r2=collapse_r1_r2, # <-- pass the bool down ) @@ -995,19 +1007,20 @@ def main(): metavar='', ) - # Add common options to this parent parser + # Common parent parser parent = argparse.ArgumentParser(add_help=False) parent.add_argument('--tmp', metavar='TMP', help='Override default temporary directory', type=str, default='tmp') parent.add_argument('--keep-tmp', help='Do not delete the tmp directory', action='store_true') parent.add_argument('--verbose', help='Print debugging information', action='store_true') parent.add_argument('-t', metavar='THREADS', help='Number of threads to use (default: 8)', type=int, default=8) - # Command parsers + # Create each subcommand parser parser_ref = setup_ref_args(subparsers, parent) parser_align = setup_align_args(subparsers, parent) parser_consensus = setup_consensus_args(subparsers, parent) parser_count = setup_count_args(subparsers, parent) parser_estimate = setup_estimate_args(subparsers, parent) + command_to_parser = { 'ref': parser_ref, 'align': parser_align, @@ -1015,10 +1028,10 @@ def main(): 'count': parser_count, 'estimate': parser_estimate, } + if '--list' in sys.argv: print_technologies() - # Show help when no arguments are given if len(sys.argv) == 1: parser.print_help(sys.stderr) sys.exit(1) @@ -1045,6 +1058,7 @@ def main(): ) os.makedirs(args.tmp) os.environ['NUMEXPR_MAX_THREADS'] = str(args.t) + try: patch_mp_connection_bpo_17560() # Monkeypatch for python 3.7 with warnings.catch_warnings(): diff --git a/dynast/preprocessing/consensus.py b/dynast/preprocessing/consensus.py old mode 100755 new mode 100644 index 8be4313..719878a --- a/dynast/preprocessing/consensus.py +++ b/dynast/preprocessing/consensus.py @@ -18,89 +18,86 @@ BASE_IDX = {base: i for i, base in enumerate(BASES)} +def _majority_flag(reads: List[pysam.AlignedSegment], attr: str) -> bool: + """ + Return True if more than half of the reads have getattr(r, attr) == True. + Otherwise return False. + + E.g.: + _majority_flag(reads, 'is_proper_pair') + _majority_flag(reads, 'is_reverse') + _majority_flag(reads, 'mate_is_reverse') + """ + count_true = sum(1 for r in reads if getattr(r, attr)) + return count_true > (len(reads) / 2) + + def call_consensus_from_reads( reads: List[pysam.AlignedSegment], header: pysam.AlignmentHeader, quality: int = 27, tags: Optional[Dict[str, Any]] = None, + read_number: Optional[str] = None, + shared_qname: Optional[str] = None ) -> pysam.AlignedSegment: - """Call a single consensus alignment given a list of aligned reads. + """ + Call a single consensus alignment given a list of aligned reads. Reads must map to the same contig. Results are undefined otherwise. Additionally, consensus bases are called only for positions that match to the reference (i.e. no insertions allowed). - This function only sets the minimal amount of attributes such that the - alignment is valid. These include: - * read name -- SHA256 hash of the provided read names - * read sequence and qualities - * reference name and ID - * reference start - * mapping quality (MAPQ) - * cigarstring - * MD tag - * NM tag - * Not unmapped, paired, duplicate, qc fail, secondary, nor supplementary - - The caller is expected to further populate the alignment - with additional tags, flags, and name. - - Args: - reads: List of reads to call a consensus sequence from - header: header to use when creating the new pysam alignment - quality: quality threshold - tags: additional tags to set + If read_number is 'R1' or 'R2', we set the resulting consensus read + to is_paired = True and is_read1/is_read2 = True accordingly, + preserving the same QNAME for both consensus reads (if shared_qname is given). - Returns: - New pysam alignment of the consensus sequence + We also aggregate the following flags by majority vote across `reads`: + - is_proper_pair + - is_reverse + - mate_is_reverse """ - if len(set(read.reference_name for read in reads)) > 1: - raise Exception("Can not call consensus from reads mapping to multiple contigs.") + # Validate all reads on the same reference + if len(set(r.reference_name for r in reads)) > 1: + raise Exception("Cannot call consensus from reads mapping to multiple contigs.") - # Pysam coordinates are [start, end) - left_pos = min(read.reference_start for read in reads) - right_pos = max(read.reference_end for read in reads) + # Determine leftmost and rightmost positions + left_pos = min(r.reference_start for r in reads) + right_pos = max(r.reference_end for r in reads) length = right_pos - left_pos - # A consensus sequence is internally represented as a L x 4 matrix, - # where L is the length of the sequence and the columns correspond to - # each of the four bases. The values indicate the support of each base. - # It's possible to switch these to sparse matrices if memory becomes an issue. + # Build a matrix for base counts sequence = np.zeros((length, len(BASES)), dtype=np.uint32) reference = np.full(length, -1, dtype=np.int8) # -1 means unobserved - deletions = 0 for read in reads: - read_sequence = read.query_sequence.upper() - read_qualities = read.query_qualities - for read_i, genome_i, _genome_base in read.get_aligned_pairs(matches_only=False, with_seq=True): - # Insertion - if genome_i is None or _genome_base is None: + read_seq = read.query_sequence.upper() + read_quals = read.query_qualities + for read_i, ref_i, ref_base in read.get_aligned_pairs(matches_only=False, with_seq=True): + if ref_i is None or ref_base is None: + # Insertion or soft-clip, etc. continue - i = genome_i - left_pos - genome_base = _genome_base.upper() - if genome_base == 'N': + i = ref_i - left_pos + ref_base = ref_base.upper() + if ref_base == 'N': continue - # Deletion if read_i is None: + # This is a deletion from the read if reference[i] < 0: - reference[i] = BASE_IDX[genome_base] - deletions += 1 + reference[i] = BASE_IDX[ref_base] continue - - read_base = read_sequence[read_i] + read_base = read_seq[read_i] if read_base == 'N': continue - if reference[i] < 0: - reference[i] = BASE_IDX[genome_base] - sequence[i, BASE_IDX[read_base]] += read_qualities[read_i] + reference[i] = BASE_IDX[ref_base] + sequence[i, BASE_IDX[read_base]] += read_quals[read_i] - # Determine consensus - # Note that we ignore any insertions - consensus_length = (sequence > 0).any(axis=1).sum() + # Now determine the consensus sequence + consensus_mask = (sequence > 0).any(axis=1) + consensus_length = consensus_mask.sum() consensus = np.zeros(consensus_length, dtype=np.uint8) qualities = np.zeros(consensus_length, dtype=np.uint8) + cigar = [] last_cigar_op = None cigar_n = 0 @@ -110,74 +107,73 @@ def call_consensus_from_reads( md_del = False nm = 0 consensus_i = 0 + for i in range(length): - ref = reference[i] - # Region not present in read. MD tag only deals with aligned - # regions, so nothing else needs to be done. + ref_idx = reference[i] cigar_op = 'N' - if ref >= 0: + if ref_idx >= 0: seq = sequence[i] - - # Deletion if (seq == 0).all(): + # Deletion cigar_op = 'D' if md_n > 0 or md_zero: md.append(str(md_n)) md_n = 0 - if not md_del: md.append('^') - md.append(BASES[ref]) + md.append(BASES[ref_idx]) md_del = True - - # Match else: md_del = False - - # On ties, select reference if present. Otherwise, choose lexicographically. base_q = seq.max() if base_q < quality: - base = ref + # If no base surpasses the threshold, choose reference + base = ref_idx else: - bases = (seq == base_q).nonzero()[0] - if len(bases) > 0 and ref in bases: - base = ref + candidates = (seq == base_q).nonzero()[0] + # On tie: if reference is in that tie, choose reference + if ref_idx in candidates: + base = ref_idx else: - base = bases[0] - - # We use the STAR convention of using M cigar operation to mean - # both matches AND mismatches, ignoring the X cigar operation exists. + base = candidates[0] cigar_op = 'M' - - if ref == base: + if base == ref_idx: md_n += 1 md_zero = False else: if md_n > 0 or md_zero: md.append(str(md_n)) md_n = 0 - md.append(BASES[ref]) + md.append(BASES[ref_idx]) md_zero = True nm += 1 - - consensus[consensus_i] = base - qualities[consensus_i] = min(base_q, 42) # Clip to maximum PHRED score - consensus_i += 1 + if consensus_mask[i]: + consensus[consensus_i] = base + qualities[consensus_i] = min(base_q, 42) # cap at PHRED 42 + consensus_i += 1 if cigar_op == last_cigar_op: cigar_n += 1 else: if last_cigar_op: - cigar.append(f'{cigar_n}{last_cigar_op}') + cigar.append(f"{cigar_n}{last_cigar_op}") last_cigar_op = cigar_op cigar_n = 1 - md.append(str(md_n)) # MD tag always ends with a number - cigar.append(f'{cigar_n}{last_cigar_op}') + # Finish up the last cigar chunk + md.append(str(md_n)) + cigar.append(f"{cigar_n}{last_cigar_op}") + # Create a new pysam AlignedSegment al = pysam.AlignedSegment(header) - al.query_name = sha256(''.join(read.query_name for read in reads).encode('utf-8')).hexdigest() - al.query_sequence = ''.join(BASES[i] for i in consensus) + if shared_qname is not None: + al.query_name = shared_qname + else: + # default: hash all read names + all_names = ''.join(r.query_name for r in reads) + al.query_name = sha256(all_names.encode('utf-8')).hexdigest() + + al.query_sequence = ''.join(BASES[b] for b in consensus) al.query_qualities = array.array('B', qualities) al.reference_name = reads[0].reference_name al.reference_id = reads[0].reference_id @@ -190,38 +186,95 @@ def call_consensus_from_reads( tags.update({'MD': ''.join(md), 'NM': nm}) al.set_tags(list(tags.items())) - # Make sure these are False + # Mark read1 or read2 if provided + if read_number == 'R1': + al.is_paired = True + al.is_read1 = True + al.is_read2 = False + elif read_number == 'R2': + al.is_paired = True + al.is_read1 = False + al.is_read2 = True + else: + al.is_paired = False + + # -------------------------------------------------------------- + # NEW: Aggregate certain flags with a "majority vote" approach: + # is_proper_pair + # is_reverse + # mate_is_reverse + # -------------------------------------------------------------- + al.is_proper_pair = _majority_flag(reads, 'is_proper_pair') + # If this consensus read is designated as reversed by majority: + al.is_reverse = _majority_flag(reads, 'is_reverse') + + # Only relevant if read is paired + if al.is_paired: + al.mate_is_reverse = _majority_flag(reads, 'mate_is_reverse') + + # Force false for these al.is_unmapped = False - al.is_paired = False al.is_duplicate = False al.is_qcfail = False al.is_secondary = False al.is_supplementary = False + return al -def call_consensus_from_reads_process(reads, header, tags, strand=None, quality=27): - """Helper function to call :func:`call_consensus_from_reads` from a subprocess.""" +def call_consensus_from_reads_process( + reads, + header, + tags, + strand=None, + read_number=None, + shared_qname=None, + quality=27 +): + """ + Helper function to call :func:`call_consensus_from_reads`. + """ header = pysam.AlignmentHeader.from_dict(header) - reads = [pysam.AlignedSegment.fromstring(read, header) for read in reads] - consensus = call_consensus_from_reads(reads, header, quality=quality, tags=tags) - consensus.is_paired = False + pysam_reads = [pysam.AlignedSegment.fromstring(r, header) for r in reads] + + aln = call_consensus_from_reads( + pysam_reads, + header, + quality=quality, + tags=tags, + read_number=read_number, + shared_qname=shared_qname + ) if strand == '-': - consensus.is_reverse = True - return consensus.to_string() + aln.is_reverse = True + return aln.to_string() -def consensus_worker(args_q, results_q, *args, **kwargs): - """Multiprocessing worker.""" + +def consensus_worker(args_q, results_q, quality=27): + """ + Multiprocessing worker that reads tasks from args_q and calls + call_consensus_from_reads_process. + """ while True: try: - _args = args_q.get(timeout=1) # None means we are done. + _args = args_q.get(timeout=1) except queue.Empty: continue if _args is None: return + results_q.put(call_consensus_from_reads_process(*_args, quality=quality)) - results_q.put(call_consensus_from_reads_process(*_args, *args, **kwargs)) + +def get_read_number(read: pysam.AlignedSegment) -> str: + """ + Return 'R1' if read.is_read1 is True, else 'R2'. + For single-end reads, default to 'R1'. + """ + if read.is_paired: + return 'R1' if read.is_read1 else 'R2' + else: + return 'R1' def call_consensus( @@ -236,34 +289,46 @@ def call_consensus( quality: int = 27, add_RS_RI: bool = False, temp_dir: Optional[str] = None, - n_threads: int = 8 + n_threads: int = 8, + collapse_r1_r2: bool = False ) -> str: - """Call consensus sequences from BAM. + """ + Call consensus sequences from BAM. + + If collapse_r1_r2 is True, then R1 and R2 from the same UMI + are stored together in one group, producing a single consensus. + + If collapse_r1_r2 is False, R1 and R2 from the same UMI + are stored separately, producing two consensus reads (one for R1, one for R2). + + The final consensus read also aggregates is_proper_pair, is_reverse, + and mate_is_reverse from the source reads by majority vote. Args: bam_path: Path to BAM out_path: Output BAM path - gene_infos: Gene information, as parsed from the GTF + gene_infos: Gene info from GTF strand: Protocol strandedness - umi_tag: BAM tag containing the UMI - barcode_tag: BAM tag containing the barcode - gene_tag: BAM tag containing the assigned gene - barcodes: List of barcodes to consider - quality: Quality threshold - add_RS_RI: Add RS and RI BAM tags for debugging - temp_dir: Temporary directory - n_threads: Number of threads - - Returns: - Path to sorted and indexed consensus BAM + umi_tag: BAM tag for UMI + barcode_tag: BAM tag for barcode + gene_tag: BAM tag for assigned gene + barcodes: optional filter for barcodes + quality: consensus base quality threshold + add_RS_RI: optional debug tags + temp_dir: optional temp dir + n_threads: number of threads + collapse_r1_r2: if True, combine R1/R2. if False, produce separate consensus. """ - def skip_alignment(read, tags): - return read.is_secondary or read.is_unmapped or any(not read.has_tag(tag) for tag in tags) + def skip_alignment(read, required): + return ( + read.is_secondary + or read.is_unmapped + or any(not read.has_tag(t) for t in required) + ) def find_genes(contig, start, end, read_strand=None): genes = [] - # NOTE: contigs may not have any genes for gene in contig_gene_order.get(contig, []): if read_strand and read_strand != gene_infos[gene]['strand']: continue @@ -274,67 +339,70 @@ def find_genes(contig, start, end, read_strand=None): genes.append(gene) return genes - def swap_gene_tags(read, gene): - tags = dict(read.get_tags()) - if gene_tag and read.has_tag(gene_tag): - del tags[gene_tag] - tags['GX'] = gene + def swap_gene_tags(r, gene): + t = dict(r.get_tags()) + if gene_tag and r.has_tag(gene_tag): + del t[gene_tag] + t['GX'] = gene gn = gene_infos.get(gene, {}).get('gene_name') if gn: - tags['GN'] = gn - read.set_tags(list(tags.items())) - return read - - def create_tags_and_strand(barcode, umi, reads, gene_info): - tags = { - 'AS': sum(read.get_tag('AS') for read in reads), + t['GN'] = gn + r.set_tags(list(t.items())) + return r + + def create_tags_and_strand(barcode, umi, reads, ginfo): + as_sum = sum(r.get_tag('AS') for r in reads if r.has_tag('AS')) + tags_ = { + 'AS': as_sum, 'NH': 1, 'HI': 1, config.BAM_CONSENSUS_READ_COUNT_TAG: len(reads), } if barcode_tag: - tags[barcode_tag] = barcode + tags_[barcode_tag] = barcode if umi_tag: - tags[umi_tag] = umi + tags_[umi_tag] = umi if gene_tag: gene_id = None - for read in reads: - if read.has_tag(gene_tag): - gene_id = read.get_tag(gene_tag) + for rr in reads: + if rr.has_tag(gene_tag): + gene_id = rr.get_tag(gene_tag) break - if gene_id: - tags['GX'] = gene_id - gn = gene_info.get('gene_name') + tags_['GX'] = gene_id + gn = ginfo.get('gene_name') if gn: - tags['GN'] = gn + tags_['GN'] = gn + if add_RS_RI: - tags.update({ - 'RS': ';'.join(read.query_name for read in reads), - 'RI': ';'.join(str(read.get_tag('HI')) for read in reads), - }) - - # Figure out what strand the consensus should map to - gene_strand = gene_info['strand'] - consensus_strand = gene_strand + tags_['RS'] = ';'.join(r.query_name for r in reads) + tags_['RI'] = ';'.join(str(r.get_tag('HI')) if r.has_tag('HI') else '0' for r in reads) + + consensus_strand = None + gene_strand = ginfo['strand'] if strand == 'forward': consensus_strand = gene_strand elif strand == 'reverse': consensus_strand = '-' if gene_strand == '+' else '+' - return tags, consensus_strand + + return tags_, consensus_strand if add_RS_RI: - logger.warning('RS and RI tags will be added to the BAM. This may dramatically increase the BAM size.') + logger.warning("RS and RI tags may greatly increase BAM size.") + # Build index of genes per contig contig_gene_order = {} - for gene_id, gene_info in gene_infos.items(): - contig_gene_order.setdefault(gene_info['chromosome'], []).append(gene_id) + for gid, ginfo in gene_infos.items(): + contig_gene_order.setdefault(ginfo['chromosome'], []).append(gid) for contig in list(contig_gene_order.keys()): contig_gene_order[contig] = sorted( - contig_gene_order[contig], key=lambda gene: tuple(gene_infos[gene]['segment']) + contig_gene_order[contig], + key=lambda g: tuple(gene_infos[g]['segment']) ) + # We'll store groups as: gx_barcode_umi_groups[gene][barcode][umi][subkey] -> list_of_reads + # Where subkey is either 'ALL' (if collapsing) or 'R1'/'R2' (if separate) gx_barcode_umi_groups = {} paired = {} @@ -344,29 +412,35 @@ def create_tags_and_strand(barcode, umi, reads, gene_info): if barcode_tag: required_tags.append(barcode_tag) - # Start processes for consensus calling - logger.debug(f'Spawning {n_threads} processes') manager = multiprocessing.Manager() args_q = manager.Queue(1000 * n_threads) results_q = manager.Queue() + workers = [ multiprocessing.Process( - target=consensus_worker, args=(args_q, results_q), kwargs=dict(quality=quality), daemon=True - ) for _ in range(n_threads) + target=consensus_worker, + args=(args_q, results_q, quality), + daemon=True + ) + for _ in range(n_threads) ] - for worker in workers: - worker.start() + for w in workers: + w.start() temp_out_path = utils.mkstemp(dir=temp_dir) - with pysam.AlignmentFile(bam_path, 'rb') as f: - # Get header dict and update sort order to unsorted. - header_dict = f.header.to_dict() + with pysam.AlignmentFile(bam_path, 'rb') as f_in: + header_dict = f_in.header.to_dict() hd = header_dict.setdefault('HD', {'VN': '1.4', 'SO': 'unsorted'}) hd['SO'] = 'unsorted' header = pysam.AlignmentHeader.from_dict(header_dict) + + total_reads = ngs.bam.count_bam(bam_path) with pysam.AlignmentFile(temp_out_path, 'wb', header=header) as out: - for i, read in tqdm(enumerate(f.fetch()), total=ngs.bam.count_bam(bam_path), ascii=True, smoothing=0.01, - desc='Calling consensus'): + for i, read in tqdm( + enumerate(f_in.fetch()), + total=total_reads, ascii=True, smoothing=0.01, + desc="Calling consensus" + ): if skip_alignment(read, required_tags): continue @@ -374,121 +448,174 @@ def create_tags_and_strand(barcode, umi, reads, gene_info): if barcode == '-' or (barcodes and barcode not in barcodes): continue - contig = read.reference_name umi = read.get_tag(umi_tag) if umi_tag else None - read_id = read.query_name - alignment_index = read.get_tag('HI') - start = read.reference_start - end = read.reference_end + alignment_index = read.get_tag('HI') if read.has_tag('HI') else 1 key = (read_id, alignment_index) + mate = None if read.is_paired: if key not in paired: paired[key] = read continue - mate = paired.pop(key) - # Use alignment start and end as UMI for paired reads without UMI if not umi: - start = mate.reference_start - umi = (start, end) + umi = (mate.reference_start, mate.reference_end) - # Determine read strand + # Determine the gene + start = read.reference_start + end = read.reference_end read_strand = None - if read.is_paired: - if read.is_read1: # R1 is mapped after R2 + if strand in ('forward', 'reverse'): + if read.is_paired: if strand == 'forward': - read_strand = '+' if read.is_reverse else '-' - elif strand == 'reverse': - read_strand = '-' if read.is_reverse else '+' - else: # R1 is mapped before R2 + read_strand = '+' if not read.is_reverse else '-' + else: # 'reverse' + read_strand = '-' if not read.is_reverse else '+' + else: if strand == 'forward': read_strand = '-' if read.is_reverse else '+' - elif strand == 'reverse': + else: read_strand = '+' if read.is_reverse else '-' - elif strand == 'forward': - read_strand = '-' if read.is_reverse else '+' - elif strand == 'reverse': - read_strand = '+' if read.is_reverse else '-' - # Find compatible genes - gx_assigned = read.has_tag(gene_tag) if gene_tag else False - genes = [read.get_tag(gene_tag)] if gx_assigned else find_genes(contig, start, end, read_strand) + if gene_tag and read.has_tag(gene_tag): + genes = [read.get_tag(gene_tag)] + else: + contig = read.reference_name + genes = find_genes(contig, start, end, read_strand) - # If there isn't exactly one compatible gene, do nothing and - # write to BAM. if len(genes) != 1: + # If not exactly one gene, we just write the read directly out.write(read) - if read.is_paired: + if mate: out.write(mate) continue - # Add read to group - gx_barcode_umi_groups.setdefault(genes[0], {}).setdefault(barcode, {}).setdefault(umi, []).append(read) - if read.is_paired: - gx_barcode_umi_groups[genes[0]][barcode][umi].append(mate) + gene = genes[0] + if collapse_r1_r2: + subkey = 'ALL' + else: + from_read = 'R1' if read.is_read1 else 'R2' + subkey = from_read + + gx_barcode_umi_groups \ + .setdefault(gene, {}) \ + .setdefault(barcode, {}) \ + .setdefault(umi, {}) \ + .setdefault(subkey, []) \ + .append(read) + + if mate: + if collapse_r1_r2: + mate_subkey = 'ALL' + else: + mate_subkey = 'R1' if mate.is_read1 else 'R2' + gx_barcode_umi_groups[gene][barcode][umi] \ + .setdefault(mate_subkey, []) \ + .append(mate) + # Every 10k reads, flush old genes if i % 10000 == 0: - # Call consensus for gene's whose bodies we've fully passed. - leftmost_start = start if not paired else next(iter(paired.values())).reference_start - for gene in list(gx_barcode_umi_groups.keys()): - gene_info = gene_infos[gene] - gene_contig = gene_info['chromosome'] - gene_segment = gene_info['segment'] - if (gene_contig < contig) or (gene_contig == contig and gene_segment.end <= leftmost_start): - barcode_umi_groups = gx_barcode_umi_groups.pop(gene) - for barcode, umi_groups in barcode_umi_groups.items(): - for umi, reads in umi_groups.items(): - if len(reads) == 1: - out.write(swap_gene_tags(reads[0], gene)) - - tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_info) - - # Save for multiprocessing later. - args_q.put(([read.to_string() - for read in reads], header_dict, tags, consensus_strand)) - else: - break - - to_remove = 0 - for gene in contig_gene_order[contig]: - if gene_infos[gene]['segment'].end <= leftmost_start: - to_remove += 1 - else: - break - if to_remove > 0: - contig_gene_order[contig] = contig_gene_order[contig][to_remove:] - + leftmost = min( + read.reference_start, + mate.reference_start if mate else read.reference_start + ) + for g in list(gx_barcode_umi_groups.keys()): + ginfo = gene_infos[g] + gene_contig = ginfo['chromosome'] + gene_segment = ginfo['segment'] + if (gene_contig < read.reference_name) or ( + gene_contig == read.reference_name and gene_segment.end <= leftmost + ): + bc_map = gx_barcode_umi_groups.pop(g) + for bc, umi_map in bc_map.items(): + for this_umi, sub_map in umi_map.items(): + # Build a stable QNAME from everything in sub_map + all_names = [] + for k_, subreads in sub_map.items(): + all_names.extend(r.query_name for r in subreads) + shared_qname = sha256(''.join(all_names).encode('utf-8')).hexdigest() + + for subk, subreads in sub_map.items(): + if len(subreads) == 1: + single_read = swap_gene_tags(subreads[0], g) + single_read.query_name = shared_qname + # If not collapsing, see if subk=R1/R2 to set flags + if not collapse_r1_r2: + if subk == 'R1': + single_read.is_paired = True + single_read.is_read1 = True + single_read.is_read2 = False + elif subk == 'R2': + single_read.is_paired = True + single_read.is_read1 = False + single_read.is_read2 = True + out.write(single_read) + else: + tags_, cstrand = create_tags_and_strand(bc, this_umi, subreads, ginfo) + read_num = subk if not collapse_r1_r2 else None + args_q.put(( + [r.to_string() for r in subreads], + header_dict, + tags_, + cstrand, + read_num, + shared_qname + )) while True: try: result = results_q.get_nowait() if result: - consensus = pysam.AlignedSegment.fromstring(result, header) - out.write(consensus) + out.write(pysam.AlignedSegment.fromstring(result, header)) except queue.Empty: break - # Put remaining - for gene, barcode_umi_groups in gx_barcode_umi_groups.items(): - for barcode, umi_groups in barcode_umi_groups.items(): - for umi, reads in umi_groups.items(): - if len(reads) == 1: - out.write(swap_gene_tags(reads[0], gene)) - continue - - tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_infos[gene]) - args_q.put(([read.to_string() for read in reads], header_dict, tags, consensus_strand)) - - # Signal to workers to terminate once queue is depleted. + # Final flush + for g, bc_map in gx_barcode_umi_groups.items(): + ginfo = gene_infos[g] + for bc, umi_map in bc_map.items(): + for this_umi, sub_map in umi_map.items(): + all_names = [] + for subk, subreads in sub_map.items(): + all_names.extend(r.query_name for r in subreads) + shared_qname = sha256(''.join(all_names).encode('utf-8')).hexdigest() + + for subk, subreads in sub_map.items(): + if len(subreads) == 1: + single_read = swap_gene_tags(subreads[0], g) + single_read.query_name = shared_qname + if not collapse_r1_r2: + if subk == 'R1': + single_read.is_paired = True + single_read.is_read1 = True + single_read.is_read2 = False + elif subk == 'R2': + single_read.is_paired = True + single_read.is_read1 = False + single_read.is_read2 = True + out.write(single_read) + else: + tags_, cstrand = create_tags_and_strand(bc, this_umi, subreads, ginfo) + read_num = subk if not collapse_r1_r2 else None + args_q.put(( + [r.to_string() for r in subreads], + header_dict, + tags_, + cstrand, + read_num, + shared_qname + )) + + # Signal termination for _ in range(len(workers)): args_q.put(None) - for worker in workers: - worker.join() + for w in workers: + w.join() + # Gather last results while not results_q.empty(): result = results_q.get() - consensus = pysam.AlignedSegment.fromstring(result, header) - out.write(consensus) + out.write(pysam.AlignedSegment.fromstring(result, header)) + # Sort and index return bam.sort_and_index_bam(temp_out_path, out_path, n_threads=n_threads, temp_dir=temp_dir) diff --git a/setup.cfg b/setup.cfg index 5cf6f2b..ad2e2e2 100755 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.1 +current_version = 1.0.2.beta commit = True tag = True diff --git a/setup.py b/setup.py index 98bd733..59b7fb8 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ def read(path): setup( name='dynast-release', - version='1.0.1', + version='1.0.2.beta', url='https://github.com/aristoteleo/dynast-release', author='Kyung Hoi (Joseph) Min', author_email='phoenixter96@gmail.com',