From ca0f9faf684e0616d37bca77fb8e78c8c39a4f38 Mon Sep 17 00:00:00 2001 From: qirong Date: Thu, 2 Jan 2025 12:55:30 +0100 Subject: [PATCH 1/3] New feature: allow user to select whether collapse R1 and R2 separately in dynast consensus --- dynast/__init__.py | 2 +- dynast/consensus.py | 20 +- dynast/main.py | 60 +-- dynast/preprocessing/consensus.py | 561 +++++++++++++++----------- dynast/preprocessing/consensus.py.old | 497 +++++++++++++++++++++++ setup.cfg | 2 +- setup.py | 2 +- 7 files changed, 880 insertions(+), 264 deletions(-) mode change 100755 => 100644 dynast/preprocessing/consensus.py create mode 100755 dynast/preprocessing/consensus.py.old diff --git a/dynast/__init__.py b/dynast/__init__.py index cd7ca49..6ba6abe 100755 --- a/dynast/__init__.py +++ b/dynast/__init__.py @@ -1 +1 @@ -__version__ = '1.0.1' +__version__ = '1.0.2.beta' diff --git a/dynast/consensus.py b/dynast/consensus.py index b7bcdef..f1f68f1 100755 --- a/dynast/consensus.py +++ b/dynast/consensus.py @@ -24,8 +24,10 @@ def consensus( add_RS_RI: bool = False, n_threads: int = 8, temp_dir: Optional[str] = None, + collapse_r1_r2: bool = False # <-- NEW parameter ): - """Main interface for the `consensus` command. + """ + Main interface for the `consensus` command. Args: bam_path: Path to BAM to call consensus from @@ -40,15 +42,19 @@ def consensus( add_RS_RI: Add RS and RI tags to BAM. Mostly useful for debugging. n_threads: Number of threads to use temp_dir: Temporary directory + collapse_r1_r2: If True, reads from R1 and R2 for the same UMI + will be combined into a single consensus read. If False, + R1 and R2 will each get their own consensus. """ stats = Stats() stats.start() stats_path = os.path.join( - out_dir, f'{constants.STATS_PREFIX}_{dt.datetime.strftime(stats.start_time, "%Y%m%d_%H%M%S_%f")}.json' + out_dir, + f'{constants.STATS_PREFIX}_{dt.datetime.strftime(stats.start_time, "%Y%m%d_%H%M%S_%f")}.json' ) os.makedirs(out_dir, exist_ok=True) - # Sort and index bam. + # Sort and index BAM bam_path = preprocessing.sort_and_index_bam( bam_path, '{}.sortedByCoord{}'.format(*os.path.splitext(bam_path)), n_threads=n_threads ) @@ -105,6 +111,10 @@ def consensus( consensus_path = os.path.join(out_dir, constants.CONSENSUS_BAM_FILENAME) logger.info(f'Calling consensus sequences from BAM to {consensus_path}') + + # ---------------------------------------------------------------------- + # Pass collapse_r1_r2 into preprocessing.call_consensus + # ---------------------------------------------------------------------- preprocessing.call_consensus( bam_path, consensus_path, @@ -117,7 +127,9 @@ def consensus( quality=quality, add_RS_RI=add_RS_RI, temp_dir=temp_dir, - n_threads=n_threads + n_threads=n_threads, + collapse_r1_r2=collapse_r1_r2 ) + stats.end() stats.save(stats_path) diff --git a/dynast/main.py b/dynast/main.py index d106c4d..9b5ac0d 100755 --- a/dynast/main.py +++ b/dynast/main.py @@ -167,16 +167,7 @@ def setup_align_args(parser: argparse.ArgumentParser, parent: argparse.ArgumentP def setup_consensus_args(parser: argparse.ArgumentParser, parent: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Helper function to set up a subparser for the `consensus` command. - - Args: - parser: Argparse parser to add the `consensus` command to - parent: Argparse parser parent of the newly added subcommand. - Used to inherit shared commands/flags - - Returns: - The newly added parser - """ + """Helper function to set up a subparser for the `consensus` command.""" parser_consensus = parser.add_parser( 'consensus', description='Generate consensus sequences', @@ -258,6 +249,22 @@ def setup_consensus_args(parser: argparse.ArgumentParser, parent: argparse.Argum ), action='store_true' ) + + # ---------------------------------------------------------------- + # NEW ARGUMENT to toggle separate R1/R2. Default is False => + # (which means R1/R2 are collapsed together internally). + # ---------------------------------------------------------------- + parser_consensus.add_argument( + '--separate-r1-r2', + action='store_true', + default=False, + help=( + 'If specified, R1 and R2 from the same UMI will be collapsed ' + 'separately (i.e. produce two consensus reads). By default, ' + 'they are collapsed into one consensus read.' + ), + ) + parser_consensus.add_argument( 'bam', help=( @@ -268,7 +275,6 @@ def setup_consensus_args(parser: argparse.ArgumentParser, parent: argparse.Argum ) return parser_consensus - def setup_count_args(parser: argparse.ArgumentParser, parent: argparse.ArgumentParser) -> argparse.ArgumentParser: """Helper function to set up a subparser for the `count` command. @@ -727,12 +733,9 @@ def parse_align(parser, args, temp_dir=None): def parse_consensus(parser: argparse.ArgumentParser, args: argparse.Namespace, temp_dir: Optional[str] = None): - """Parser for the `consensus` command. - - Args: - parser: The parser - args: Command-line arguments dictionary, as parsed by argparse - temp_dir: Temporary directory + """ + Parser for the `consensus` command. We also handle whether to + separate R1/R2 (based on `args.separate_r1_r2`). """ # Check quality if args.quality < 0 or args.quality > 41: @@ -759,11 +762,19 @@ def parse_consensus(parser: argparse.ArgumentParser, args: argparse.Namespace, t else: logger.info(f'Auto-detected strandedness: {strand}. Use `--strand` to override.') + # ------------------------------------------------------------------ + # Convert user-specified --separate-r1-r2 to collapse_r1_r2 bool. + # If separate_r1_r2 == True => we do NOT collapse them together => + # collapse_r1_r2 = False + # If separate_r1_r2 == False => we collapse them => True + # ------------------------------------------------------------------ + collapse_r1_r2 = not args.separate_r1_r2 + from .consensus import consensus consensus( - args.bam, - args.g, - args.o, + bam_path=args.bam, + gtf_path=args.g, + out_dir=args.o, strand=strand, umi_tag=args.umi_tag, barcode_tag=args.barcode_tag, @@ -773,6 +784,7 @@ def parse_consensus(parser: argparse.ArgumentParser, args: argparse.Namespace, t add_RS_RI=args.add_RS_RI, n_threads=args.t, temp_dir=temp_dir, + collapse_r1_r2=collapse_r1_r2, # <-- pass the bool down ) @@ -995,19 +1007,20 @@ def main(): metavar='', ) - # Add common options to this parent parser + # Common parent parser parent = argparse.ArgumentParser(add_help=False) parent.add_argument('--tmp', metavar='TMP', help='Override default temporary directory', type=str, default='tmp') parent.add_argument('--keep-tmp', help='Do not delete the tmp directory', action='store_true') parent.add_argument('--verbose', help='Print debugging information', action='store_true') parent.add_argument('-t', metavar='THREADS', help='Number of threads to use (default: 8)', type=int, default=8) - # Command parsers + # Create each subcommand parser parser_ref = setup_ref_args(subparsers, parent) parser_align = setup_align_args(subparsers, parent) parser_consensus = setup_consensus_args(subparsers, parent) parser_count = setup_count_args(subparsers, parent) parser_estimate = setup_estimate_args(subparsers, parent) + command_to_parser = { 'ref': parser_ref, 'align': parser_align, @@ -1015,10 +1028,10 @@ def main(): 'count': parser_count, 'estimate': parser_estimate, } + if '--list' in sys.argv: print_technologies() - # Show help when no arguments are given if len(sys.argv) == 1: parser.print_help(sys.stderr) sys.exit(1) @@ -1045,6 +1058,7 @@ def main(): ) os.makedirs(args.tmp) os.environ['NUMEXPR_MAX_THREADS'] = str(args.t) + try: patch_mp_connection_bpo_17560() # Monkeypatch for python 3.7 with warnings.catch_warnings(): diff --git a/dynast/preprocessing/consensus.py b/dynast/preprocessing/consensus.py old mode 100755 new mode 100644 index 8be4313..5a9e65e --- a/dynast/preprocessing/consensus.py +++ b/dynast/preprocessing/consensus.py @@ -23,84 +23,58 @@ def call_consensus_from_reads( header: pysam.AlignmentHeader, quality: int = 27, tags: Optional[Dict[str, Any]] = None, + read_number: Optional[str] = None, + shared_qname: Optional[str] = None ) -> pysam.AlignedSegment: - """Call a single consensus alignment given a list of aligned reads. + """ + Call a single consensus alignment given a list of aligned reads. Reads must map to the same contig. Results are undefined otherwise. Additionally, consensus bases are called only for positions that match to the reference (i.e. no insertions allowed). - This function only sets the minimal amount of attributes such that the - alignment is valid. These include: - * read name -- SHA256 hash of the provided read names - * read sequence and qualities - * reference name and ID - * reference start - * mapping quality (MAPQ) - * cigarstring - * MD tag - * NM tag - * Not unmapped, paired, duplicate, qc fail, secondary, nor supplementary - - The caller is expected to further populate the alignment - with additional tags, flags, and name. - - Args: - reads: List of reads to call a consensus sequence from - header: header to use when creating the new pysam alignment - quality: quality threshold - tags: additional tags to set - - Returns: - New pysam alignment of the consensus sequence + If read_number is 'R1' or 'R2', we set the resulting consensus read + to is_paired = True and is_read1/is_read2 = True accordingly, + preserving the same QNAME for both consensus reads (if shared_qname is given). """ - if len(set(read.reference_name for read in reads)) > 1: - raise Exception("Can not call consensus from reads mapping to multiple contigs.") + if len(set(r.reference_name for r in reads)) > 1: + raise Exception("Cannot call consensus from reads mapping to multiple contigs.") - # Pysam coordinates are [start, end) - left_pos = min(read.reference_start for read in reads) - right_pos = max(read.reference_end for read in reads) + left_pos = min(r.reference_start for r in reads) + right_pos = max(r.reference_end for r in reads) length = right_pos - left_pos - # A consensus sequence is internally represented as a L x 4 matrix, - # where L is the length of the sequence and the columns correspond to - # each of the four bases. The values indicate the support of each base. - # It's possible to switch these to sparse matrices if memory becomes an issue. sequence = np.zeros((length, len(BASES)), dtype=np.uint32) reference = np.full(length, -1, dtype=np.int8) # -1 means unobserved - deletions = 0 for read in reads: - read_sequence = read.query_sequence.upper() - read_qualities = read.query_qualities - for read_i, genome_i, _genome_base in read.get_aligned_pairs(matches_only=False, with_seq=True): - # Insertion - if genome_i is None or _genome_base is None: + read_seq = read.query_sequence.upper() + read_quals = read.query_qualities + for read_i, ref_i, ref_base in read.get_aligned_pairs(matches_only=False, with_seq=True): + if ref_i is None or ref_base is None: continue - i = genome_i - left_pos - genome_base = _genome_base.upper() - if genome_base == 'N': + i = ref_i - left_pos + ref_base = ref_base.upper() + if ref_base == 'N': continue - # Deletion if read_i is None: + # Deletion if reference[i] < 0: - reference[i] = BASE_IDX[genome_base] - deletions += 1 + reference[i] = BASE_IDX[ref_base] continue - - read_base = read_sequence[read_i] + read_base = read_seq[read_i] if read_base == 'N': continue - if reference[i] < 0: - reference[i] = BASE_IDX[genome_base] - sequence[i, BASE_IDX[read_base]] += read_qualities[read_i] + reference[i] = BASE_IDX[ref_base] + sequence[i, BASE_IDX[read_base]] += read_quals[read_i] - # Determine consensus - # Note that we ignore any insertions - consensus_length = (sequence > 0).any(axis=1).sum() + # Build consensus + consensus_mask = (sequence > 0).any(axis=1) + consensus_length = consensus_mask.sum() consensus = np.zeros(consensus_length, dtype=np.uint8) qualities = np.zeros(consensus_length, dtype=np.uint8) + cigar = [] last_cigar_op = None cigar_n = 0 @@ -110,74 +84,69 @@ def call_consensus_from_reads( md_del = False nm = 0 consensus_i = 0 + for i in range(length): - ref = reference[i] - # Region not present in read. MD tag only deals with aligned - # regions, so nothing else needs to be done. + ref_idx = reference[i] cigar_op = 'N' - if ref >= 0: + if ref_idx >= 0: seq = sequence[i] - - # Deletion if (seq == 0).all(): + # Deletion cigar_op = 'D' if md_n > 0 or md_zero: md.append(str(md_n)) md_n = 0 - if not md_del: md.append('^') - md.append(BASES[ref]) + md.append(BASES[ref_idx]) md_del = True - - # Match else: md_del = False - - # On ties, select reference if present. Otherwise, choose lexicographically. base_q = seq.max() if base_q < quality: - base = ref + base = ref_idx else: - bases = (seq == base_q).nonzero()[0] - if len(bases) > 0 and ref in bases: - base = ref + candidates = (seq == base_q).nonzero()[0] + if ref_idx in candidates: + base = ref_idx else: - base = bases[0] - - # We use the STAR convention of using M cigar operation to mean - # both matches AND mismatches, ignoring the X cigar operation exists. + base = candidates[0] cigar_op = 'M' - - if ref == base: + if base == ref_idx: md_n += 1 md_zero = False else: if md_n > 0 or md_zero: md.append(str(md_n)) md_n = 0 - md.append(BASES[ref]) + md.append(BASES[ref_idx]) md_zero = True nm += 1 - - consensus[consensus_i] = base - qualities[consensus_i] = min(base_q, 42) # Clip to maximum PHRED score - consensus_i += 1 + if consensus_mask[i]: + consensus[consensus_i] = base + qualities[consensus_i] = min(base_q, 42) + consensus_i += 1 if cigar_op == last_cigar_op: cigar_n += 1 else: if last_cigar_op: - cigar.append(f'{cigar_n}{last_cigar_op}') + cigar.append(f"{cigar_n}{last_cigar_op}") last_cigar_op = cigar_op cigar_n = 1 - md.append(str(md_n)) # MD tag always ends with a number - cigar.append(f'{cigar_n}{last_cigar_op}') + md.append(str(md_n)) + cigar.append(f"{cigar_n}{last_cigar_op}") al = pysam.AlignedSegment(header) - al.query_name = sha256(''.join(read.query_name for read in reads).encode('utf-8')).hexdigest() - al.query_sequence = ''.join(BASES[i] for i in consensus) + if shared_qname is not None: + al.query_name = shared_qname + else: + # default: hash together original QNAMEs + all_names = ''.join(r.query_name for r in reads) + al.query_name = sha256(all_names.encode('utf-8')).hexdigest() + + al.query_sequence = ''.join(BASES[b] for b in consensus) al.query_qualities = array.array('B', qualities) al.reference_name = reads[0].reference_name al.reference_id = reads[0].reference_id @@ -185,43 +154,84 @@ def call_consensus_from_reads( al.mapping_quality = 255 al.cigarstring = ''.join(cigar) - # Set tags + # Add tags tags = tags or {} tags.update({'MD': ''.join(md), 'NM': nm}) al.set_tags(list(tags.items())) - # Make sure these are False + # Mark R1 or R2 if provided + if read_number == 'R1': + al.is_paired = True + al.is_read1 = True + al.is_read2 = False + elif read_number == 'R2': + al.is_paired = True + al.is_read1 = False + al.is_read2 = True + else: + al.is_paired = False + + # Force false for these al.is_unmapped = False - al.is_paired = False al.is_duplicate = False al.is_qcfail = False al.is_secondary = False al.is_supplementary = False + return al -def call_consensus_from_reads_process(reads, header, tags, strand=None, quality=27): - """Helper function to call :func:`call_consensus_from_reads` from a subprocess.""" +def call_consensus_from_reads_process( + reads, + header, + tags, + strand=None, + read_number=None, + shared_qname=None, + quality=27 +): + """ + Helper for multiprocessing calls. + """ header = pysam.AlignmentHeader.from_dict(header) - reads = [pysam.AlignedSegment.fromstring(read, header) for read in reads] - consensus = call_consensus_from_reads(reads, header, quality=quality, tags=tags) - consensus.is_paired = False + pysam_reads = [pysam.AlignedSegment.fromstring(r, header) for r in reads] + + aln = call_consensus_from_reads( + pysam_reads, + header, + quality=quality, + tags=tags, + read_number=read_number, + shared_qname=shared_qname + ) if strand == '-': - consensus.is_reverse = True - return consensus.to_string() + aln.is_reverse = True + return aln.to_string() -def consensus_worker(args_q, results_q, *args, **kwargs): - """Multiprocessing worker.""" +def consensus_worker(args_q, results_q, quality=27): + """ + Worker that reads tasks from args_q, calls call_consensus_from_reads_process. + """ while True: try: - _args = args_q.get(timeout=1) # None means we are done. + _args = args_q.get(timeout=1) except queue.Empty: continue if _args is None: return + results_q.put(call_consensus_from_reads_process(*_args, quality=quality)) - results_q.put(call_consensus_from_reads_process(*_args, *args, **kwargs)) + +def get_read_number(read: pysam.AlignedSegment) -> str: + """ + Return 'R1' if read.is_read1 is True, else 'R2'. + For single-end reads, default to 'R1'. + """ + if read.is_paired: + return 'R1' if read.is_read1 else 'R2' + else: + return 'R1' def call_consensus( @@ -236,34 +246,43 @@ def call_consensus( quality: int = 27, add_RS_RI: bool = False, temp_dir: Optional[str] = None, - n_threads: int = 8 + n_threads: int = 8, + collapse_r1_r2: bool = False # <-- ### ADDED ) -> str: - """Call consensus sequences from BAM. + """ + Call consensus sequences from BAM. + + If collapse_r1_r2 is True, then R1 and R2 from the same UMI + are stored together in one group, producing a single consensus. + + If collapse_r1_r2 is False, R1 and R2 from the same UMI + are stored separately, producing two consensus reads (one for R1, one for R2). Args: bam_path: Path to BAM out_path: Output BAM path - gene_infos: Gene information, as parsed from the GTF + gene_infos: Gene info from GTF strand: Protocol strandedness - umi_tag: BAM tag containing the UMI - barcode_tag: BAM tag containing the barcode - gene_tag: BAM tag containing the assigned gene - barcodes: List of barcodes to consider - quality: Quality threshold - add_RS_RI: Add RS and RI BAM tags for debugging - temp_dir: Temporary directory - n_threads: Number of threads - - Returns: - Path to sorted and indexed consensus BAM + umi_tag: BAM tag for UMI + barcode_tag: BAM tag for barcode + gene_tag: BAM tag for assigned gene + barcodes: optional filter for barcodes + quality: consensus base quality threshold + add_RS_RI: optional debug tags + temp_dir: optional temp dir + n_threads: number of threads + collapse_r1_r2: if True, combine R1/R2. if False, produce separate consensus. """ - def skip_alignment(read, tags): - return read.is_secondary or read.is_unmapped or any(not read.has_tag(tag) for tag in tags) + def skip_alignment(read, required): + return ( + read.is_secondary + or read.is_unmapped + or any(not read.has_tag(t) for t in required) + ) def find_genes(contig, start, end, read_strand=None): genes = [] - # NOTE: contigs may not have any genes for gene in contig_gene_order.get(contig, []): if read_strand and read_strand != gene_infos[gene]['strand']: continue @@ -274,20 +293,21 @@ def find_genes(contig, start, end, read_strand=None): genes.append(gene) return genes - def swap_gene_tags(read, gene): - tags = dict(read.get_tags()) - if gene_tag and read.has_tag(gene_tag): - del tags[gene_tag] - tags['GX'] = gene + def swap_gene_tags(r, gene): + t = dict(r.get_tags()) + if gene_tag and r.has_tag(gene_tag): + del t[gene_tag] + t['GX'] = gene gn = gene_infos.get(gene, {}).get('gene_name') if gn: - tags['GN'] = gn - read.set_tags(list(tags.items())) - return read + t['GN'] = gn + r.set_tags(list(t.items())) + return r - def create_tags_and_strand(barcode, umi, reads, gene_info): + def create_tags_and_strand(barcode, umi, reads, ginfo): + as_sum = sum(r.get_tag('AS') for r in reads if r.has_tag('AS')) tags = { - 'AS': sum(read.get_tag('AS') for read in reads), + 'AS': as_sum, 'NH': 1, 'HI': 1, config.BAM_CONSENSUS_READ_COUNT_TAG: len(reads), @@ -299,42 +319,53 @@ def create_tags_and_strand(barcode, umi, reads, gene_info): if gene_tag: gene_id = None - for read in reads: - if read.has_tag(gene_tag): - gene_id = read.get_tag(gene_tag) + for rr in reads: + if rr.has_tag(gene_tag): + gene_id = rr.get_tag(gene_tag) break - if gene_id: tags['GX'] = gene_id - gn = gene_info.get('gene_name') + gn = ginfo.get('gene_name') if gn: tags['GN'] = gn + if add_RS_RI: - tags.update({ - 'RS': ';'.join(read.query_name for read in reads), - 'RI': ';'.join(str(read.get_tag('HI')) for read in reads), - }) - - # Figure out what strand the consensus should map to - gene_strand = gene_info['strand'] - consensus_strand = gene_strand + tags['RS'] = ';'.join(r.query_name for r in reads) + tags['RI'] = ';'.join( + str(r.get_tag('HI')) if r.has_tag('HI') else '0' for r in reads + ) + + consensus_strand = None + gene_strand = ginfo['strand'] if strand == 'forward': consensus_strand = gene_strand elif strand == 'reverse': consensus_strand = '-' if gene_strand == '+' else '+' + return tags, consensus_strand if add_RS_RI: - logger.warning('RS and RI tags will be added to the BAM. This may dramatically increase the BAM size.') + logger.warning("RS and RI tags may greatly increase BAM size.") + # Build index of genes per contig contig_gene_order = {} - for gene_id, gene_info in gene_infos.items(): - contig_gene_order.setdefault(gene_info['chromosome'], []).append(gene_id) + for gene_id, ginfo in gene_infos.items(): + contig_gene_order.setdefault(ginfo['chromosome'], []).append(gene_id) for contig in list(contig_gene_order.keys()): contig_gene_order[contig] = sorted( - contig_gene_order[contig], key=lambda gene: tuple(gene_infos[gene]['segment']) + contig_gene_order[contig], + key=lambda g: tuple(gene_infos[g]['segment']) ) + # ### CHANGED + # Instead of storing R1 and R2 in separate keys unconditionally, + # we let the user decide via 'collapse_r1_r2'. + # + # Data structure: + # gx_barcode_umi_groups[gene][barcode][umi][subkey] -> list_of_reads + # where 'subkey' is either: + # 'ALL' (if collapse_r1_r2=True) + # 'R1' or 'R2' (if collapse_r1_r2=False) gx_barcode_umi_groups = {} paired = {} @@ -344,29 +375,33 @@ def create_tags_and_strand(barcode, umi, reads, gene_info): if barcode_tag: required_tags.append(barcode_tag) - # Start processes for consensus calling - logger.debug(f'Spawning {n_threads} processes') manager = multiprocessing.Manager() args_q = manager.Queue(1000 * n_threads) results_q = manager.Queue() + workers = [ multiprocessing.Process( - target=consensus_worker, args=(args_q, results_q), kwargs=dict(quality=quality), daemon=True - ) for _ in range(n_threads) + target=consensus_worker, + args=(args_q, results_q, quality), + daemon=True + ) + for _ in range(n_threads) ] - for worker in workers: - worker.start() + for w in workers: + w.start() temp_out_path = utils.mkstemp(dir=temp_dir) - with pysam.AlignmentFile(bam_path, 'rb') as f: - # Get header dict and update sort order to unsorted. - header_dict = f.header.to_dict() + with pysam.AlignmentFile(bam_path, 'rb') as f_in: + header_dict = f_in.header.to_dict() hd = header_dict.setdefault('HD', {'VN': '1.4', 'SO': 'unsorted'}) hd['SO'] = 'unsorted' header = pysam.AlignmentHeader.from_dict(header_dict) + + total_reads = ngs.bam.count_bam(bam_path) with pysam.AlignmentFile(temp_out_path, 'wb', header=header) as out: - for i, read in tqdm(enumerate(f.fetch()), total=ngs.bam.count_bam(bam_path), ascii=True, smoothing=0.01, - desc='Calling consensus'): + for i, read in tqdm(enumerate(f_in.fetch()), + total=total_reads, ascii=True, smoothing=0.01, + desc="Calling consensus"): if skip_alignment(read, required_tags): continue @@ -374,121 +409,179 @@ def create_tags_and_strand(barcode, umi, reads, gene_info): if barcode == '-' or (barcodes and barcode not in barcodes): continue - contig = read.reference_name umi = read.get_tag(umi_tag) if umi_tag else None - read_id = read.query_name - alignment_index = read.get_tag('HI') - start = read.reference_start - end = read.reference_end + alignment_index = read.get_tag('HI') if read.has_tag('HI') else 1 key = (read_id, alignment_index) + + # Handle pairing mate = None if read.is_paired: if key not in paired: paired[key] = read continue - mate = paired.pop(key) - # Use alignment start and end as UMI for paired reads without UMI if not umi: - start = mate.reference_start - umi = (start, end) + umi = (mate.reference_start, mate.reference_end) - # Determine read strand + # Determine the gene + start = read.reference_start + end = read.reference_end read_strand = None - if read.is_paired: - if read.is_read1: # R1 is mapped after R2 + if strand in ('forward', 'reverse'): + if read.is_paired: if strand == 'forward': - read_strand = '+' if read.is_reverse else '-' - elif strand == 'reverse': - read_strand = '-' if read.is_reverse else '+' - else: # R1 is mapped before R2 + read_strand = '+' if not read.is_reverse else '-' + else: # 'reverse' + read_strand = '-' if not read.is_reverse else '+' + else: if strand == 'forward': read_strand = '-' if read.is_reverse else '+' - elif strand == 'reverse': + else: read_strand = '+' if read.is_reverse else '-' - elif strand == 'forward': - read_strand = '-' if read.is_reverse else '+' - elif strand == 'reverse': - read_strand = '+' if read.is_reverse else '-' - # Find compatible genes - gx_assigned = read.has_tag(gene_tag) if gene_tag else False - genes = [read.get_tag(gene_tag)] if gx_assigned else find_genes(contig, start, end, read_strand) + if gene_tag and read.has_tag(gene_tag): + genes = [read.get_tag(gene_tag)] + else: + contig = read.reference_name + genes = find_genes(contig, start, end, read_strand) - # If there isn't exactly one compatible gene, do nothing and - # write to BAM. if len(genes) != 1: + # If not exactly one gene, write raw out.write(read) - if read.is_paired: + if mate: out.write(mate) continue - # Add read to group - gx_barcode_umi_groups.setdefault(genes[0], {}).setdefault(barcode, {}).setdefault(umi, []).append(read) - if read.is_paired: - gx_barcode_umi_groups[genes[0]][barcode][umi].append(mate) - - if i % 10000 == 0: - # Call consensus for gene's whose bodies we've fully passed. - leftmost_start = start if not paired else next(iter(paired.values())).reference_start - for gene in list(gx_barcode_umi_groups.keys()): - gene_info = gene_infos[gene] - gene_contig = gene_info['chromosome'] - gene_segment = gene_info['segment'] - if (gene_contig < contig) or (gene_contig == contig and gene_segment.end <= leftmost_start): - barcode_umi_groups = gx_barcode_umi_groups.pop(gene) - for barcode, umi_groups in barcode_umi_groups.items(): - for umi, reads in umi_groups.items(): - if len(reads) == 1: - out.write(swap_gene_tags(reads[0], gene)) - - tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_info) - - # Save for multiprocessing later. - args_q.put(([read.to_string() - for read in reads], header_dict, tags, consensus_strand)) - else: - break + gene = genes[0] - to_remove = 0 - for gene in contig_gene_order[contig]: - if gene_infos[gene]['segment'].end <= leftmost_start: - to_remove += 1 - else: - break - if to_remove > 0: - contig_gene_order[contig] = contig_gene_order[contig][to_remove:] + # ### CHANGED + # Decide subkey = 'ALL' (if collapsing) or 'R1'/'R2' (if separating) + if collapse_r1_r2: + subkey = 'ALL' + else: + subkey = get_read_number(read) + + gx_barcode_umi_groups \ + .setdefault(gene, {}) \ + .setdefault(barcode, {}) \ + .setdefault(umi, {}) \ + .setdefault(subkey, []) \ + .append(read) + + # Also store mate if present + if mate: + if collapse_r1_r2: + mate_subkey = 'ALL' + else: + mate_subkey = get_read_number(mate) + gx_barcode_umi_groups[gene][barcode][umi] \ + .setdefault(mate_subkey, []) \ + .append(mate) + # Periodically flush old genes + if i % 10000 == 0: + leftmost = min( + read.reference_start, + mate.reference_start if mate else read.reference_start + ) + for g in list(gx_barcode_umi_groups.keys()): + ginfo = gene_infos[g] + gene_contig = ginfo['chromosome'] + gene_segment = ginfo['segment'] + if (gene_contig < read.reference_name) or ( + gene_contig == read.reference_name and gene_segment.end <= leftmost + ): + bc_map = gx_barcode_umi_groups.pop(g) + for bc, umi_map in bc_map.items(): + for this_umi, sub_map in umi_map.items(): + # Build a stable QNAME for everything in sub_map + all_names = [] + for subkey_, reads_list in sub_map.items(): + all_names.extend(r.query_name for r in reads_list) + shared_qname = sha256(''.join(all_names).encode('utf-8')).hexdigest() + + for subk, reads_list in sub_map.items(): + if len(reads_list) == 1: + single_read = swap_gene_tags(reads_list[0], g) + single_read.query_name = shared_qname + if not collapse_r1_r2: + # If not collapsing, check if subk = 'R1'/'R2' + if subk == 'R1': + single_read.is_paired = True + single_read.is_read1 = True + single_read.is_read2 = False + elif subk == 'R2': + single_read.is_paired = True + single_read.is_read1 = False + single_read.is_read2 = True + out.write(single_read) + else: + tags, cstrand = create_tags_and_strand(bc, this_umi, reads_list, ginfo) + # subk might be 'ALL' or 'R1'/'R2' + args_q.put(( + [r.to_string() for r in reads_list], + header_dict, + tags, + cstrand, + subk if not collapse_r1_r2 else None, # read_number if separate + shared_qname + )) + + # Drain queue while True: try: result = results_q.get_nowait() if result: - consensus = pysam.AlignedSegment.fromstring(result, header) - out.write(consensus) + out.write(pysam.AlignedSegment.fromstring(result, header)) except queue.Empty: break - # Put remaining - for gene, barcode_umi_groups in gx_barcode_umi_groups.items(): - for barcode, umi_groups in barcode_umi_groups.items(): - for umi, reads in umi_groups.items(): - if len(reads) == 1: - out.write(swap_gene_tags(reads[0], gene)) - continue - - tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_infos[gene]) - args_q.put(([read.to_string() for read in reads], header_dict, tags, consensus_strand)) - - # Signal to workers to terminate once queue is depleted. + # Final flush + for g, bc_map in gx_barcode_umi_groups.items(): + ginfo = gene_infos[g] + for bc, umi_map in bc_map.items(): + for this_umi, sub_map in umi_map.items(): + all_names = [] + for subk, reads_list in sub_map.items(): + all_names.extend(r.query_name for r in reads_list) + shared_qname = sha256(''.join(all_names).encode('utf-8')).hexdigest() + + for subk, reads_list in sub_map.items(): + if len(reads_list) == 1: + single_read = swap_gene_tags(reads_list[0], g) + single_read.query_name = shared_qname + if not collapse_r1_r2: + if subk == 'R1': + single_read.is_paired = True + single_read.is_read1 = True + single_read.is_read2 = False + elif subk == 'R2': + single_read.is_paired = True + single_read.is_read1 = False + single_read.is_read2 = True + out.write(single_read) + else: + tags, cstrand = create_tags_and_strand(bc, this_umi, reads_list, ginfo) + args_q.put(( + [r.to_string() for r in reads_list], + header_dict, + tags, + cstrand, + subk if not collapse_r1_r2 else None, + shared_qname + )) + + # Signal termination for _ in range(len(workers)): args_q.put(None) - for worker in workers: - worker.join() + for w in workers: + w.join() + # Gather last results while not results_q.empty(): result = results_q.get() - consensus = pysam.AlignedSegment.fromstring(result, header) - out.write(consensus) - # Sort and index + out.write(pysam.AlignedSegment.fromstring(result, header)) + + # Sort and index as usual return bam.sort_and_index_bam(temp_out_path, out_path, n_threads=n_threads, temp_dir=temp_dir) diff --git a/dynast/preprocessing/consensus.py.old b/dynast/preprocessing/consensus.py.old new file mode 100755 index 0000000..90b462e --- /dev/null +++ b/dynast/preprocessing/consensus.py.old @@ -0,0 +1,497 @@ +import array +import multiprocessing +import queue +from hashlib import sha256 +from typing import Any, Dict, List, Optional + +import ngs_tools as ngs +import numpy as np +import pysam +from tqdm import tqdm +from typing_extensions import Literal + +from .. import config, utils +from ..logging import logger +from . import bam + +BASES = ('A', 'C', 'G', 'T') +BASE_IDX = {base: i for i, base in enumerate(BASES)} + + +def call_consensus_from_reads( + reads: List[pysam.AlignedSegment], + header: pysam.AlignmentHeader, + quality: int = 27, + tags: Optional[Dict[str, Any]] = None, +) -> pysam.AlignedSegment: + """Call a single consensus alignment given a list of aligned reads. + + Reads must map to the same contig. Results are undefined otherwise. + Additionally, consensus bases are called only for positions that match + to the reference (i.e. no insertions allowed). + + This function only sets the minimal amount of attributes such that the + alignment is valid. These include: + * read name -- SHA256 hash of the provided read names + * read sequence and qualities + * reference name and ID + * reference start + * mapping quality (MAPQ) + * cigarstring + * MD tag + * NM tag + * Not unmapped, paired, duplicate, qc fail, secondary, nor supplementary + + The caller is expected to further populate the alignment + with additional tags, flags, and name. + + Args: + reads: List of reads to call a consensus sequence from + header: header to use when creating the new pysam alignment + quality: quality threshold + tags: additional tags to set + + Returns: + New pysam alignment of the consensus sequence + """ + if len(set(read.reference_name for read in reads)) > 1: + raise Exception("Can not call consensus from reads mapping to multiple contigs.") + + # Pysam coordinates are [start, end) + left_pos = min(read.reference_start for read in reads) + right_pos = max(read.reference_end for read in reads) + length = right_pos - left_pos + + # A consensus sequence is internally represented as a L x 4 matrix, + # where L is the length of the sequence and the columns correspond to + # each of the four bases. The values indicate the support of each base. + # It's possible to switch these to sparse matrices if memory becomes an issue. + sequence = np.zeros((length, len(BASES)), dtype=np.uint32) + reference = np.full(length, -1, dtype=np.int8) # -1 means unobserved + deletions = 0 + + for read in reads: + read_sequence = read.query_sequence.upper() + read_qualities = read.query_qualities + for read_i, genome_i, _genome_base in read.get_aligned_pairs(matches_only=False, with_seq=True): + # Insertion + if genome_i is None or _genome_base is None: + continue + i = genome_i - left_pos + genome_base = _genome_base.upper() + if genome_base == 'N': + continue + # Deletion + if read_i is None: + if reference[i] < 0: + reference[i] = BASE_IDX[genome_base] + deletions += 1 + continue + + read_base = read_sequence[read_i] + if read_base == 'N': + continue + + if reference[i] < 0: + reference[i] = BASE_IDX[genome_base] + sequence[i, BASE_IDX[read_base]] += read_qualities[read_i] + + # Determine consensus + # Note that we ignore any insertions + consensus_length = (sequence > 0).any(axis=1).sum() + consensus = np.zeros(consensus_length, dtype=np.uint8) + qualities = np.zeros(consensus_length, dtype=np.uint8) + cigar = [] + last_cigar_op = None + cigar_n = 0 + md = [] + md_n = 0 + md_zero = True + md_del = False + nm = 0 + consensus_i = 0 + for i in range(length): + ref = reference[i] + # Region not present in read. MD tag only deals with aligned + # regions, so nothing else needs to be done. + cigar_op = 'N' + if ref >= 0: + seq = sequence[i] + + # Deletion + if (seq == 0).all(): + cigar_op = 'D' + if md_n > 0 or md_zero: + md.append(str(md_n)) + md_n = 0 + + if not md_del: + md.append('^') + md.append(BASES[ref]) + md_del = True + + # Match + else: + md_del = False + + # On ties, select reference if present. Otherwise, choose lexicographically. + base_q = seq.max() + if base_q < quality: + base = ref + else: + bases = (seq == base_q).nonzero()[0] + if len(bases) > 0 and ref in bases: + base = ref + else: + base = bases[0] + + # We use the STAR convention of using M cigar operation to mean + # both matches AND mismatches, ignoring the X cigar operation exists. + cigar_op = 'M' + + if ref == base: + md_n += 1 + md_zero = False + else: + if md_n > 0 or md_zero: + md.append(str(md_n)) + md_n = 0 + md.append(BASES[ref]) + md_zero = True + nm += 1 + + consensus[consensus_i] = base + qualities[consensus_i] = min(base_q, 42) # Clip to maximum PHRED score + consensus_i += 1 + + if cigar_op == last_cigar_op: + cigar_n += 1 + else: + if last_cigar_op: + cigar.append(f'{cigar_n}{last_cigar_op}') + last_cigar_op = cigar_op + cigar_n = 1 + + md.append(str(md_n)) # MD tag always ends with a number + cigar.append(f'{cigar_n}{last_cigar_op}') + + al = pysam.AlignedSegment(header) + barcode_group = tags.get('CB', 'UNKNOWN_CB') + umi_group = tags.get('UB', 'UNKNOWN_UMI') + gene_group = tags.get('GX', 'UNKNOWN_GENE') + al.query_name = f"{barcode_group}_{umi_group}_{gene_group}" + al.query_sequence = ''.join(BASES[i] for i in consensus) + al.query_qualities = array.array('B', qualities) + al.reference_name = reads[0].reference_name + al.reference_id = reads[0].reference_id + al.reference_start = left_pos + al.mapping_quality = 255 + al.cigarstring = ''.join(cigar) + + # Set tags + tags = tags or {} + tags.update({'MD': ''.join(md), 'NM': nm}) + al.set_tags(list(tags.items())) + + # Make sure these are False + al.is_unmapped = False + al.is_paired = False + al.is_duplicate = False + al.is_qcfail = False + al.is_secondary = False + al.is_supplementary = False + return al + + +def call_consensus_from_reads_process(reads, header, tags, strand=None, quality=27): + """Helper function to call :func:`call_consensus_from_reads` from a subprocess.""" + header = pysam.AlignmentHeader.from_dict(header) + reads = [pysam.AlignedSegment.fromstring(read, header) for read in reads] + consensus = call_consensus_from_reads(reads, header, quality=quality, tags=tags) + consensus.is_paired = False + if strand == '-': + consensus.is_reverse = True + return consensus.to_string() + + +def consensus_worker(args_q, results_q, *args, **kwargs): + """Multiprocessing worker.""" + while True: + try: + _args = args_q.get(timeout=1) # None means we are done. + except queue.Empty: + continue + if _args is None: + return + + results_q.put(call_consensus_from_reads_process(*_args, *args, **kwargs)) + + +def call_consensus( + bam_path: str, + out_path: str, + gene_infos: dict, + strand: Literal['forward', 'reverse', 'unstranded'] = 'forward', + umi_tag: Optional[str] = None, + barcode_tag: Optional[str] = None, + gene_tag: str = 'GX', + barcodes: Optional[List[str]] = None, + quality: int = 27, + add_RS_RI: bool = False, + temp_dir: Optional[str] = None, + n_threads: int = 8 +) -> str: + """Call consensus sequences from BAM. + + Args: + bam_path: Path to BAM + out_path: Output BAM path + gene_infos: Gene information, as parsed from the GTF + strand: Protocol strandedness + umi_tag: BAM tag containing the UMI + barcode_tag: BAM tag containing the barcode + gene_tag: BAM tag containing the assigned gene + barcodes: List of barcodes to consider + quality: Quality threshold + add_RS_RI: Add RS and RI BAM tags for debugging + temp_dir: Temporary directory + n_threads: Number of threads + + Returns: + Path to sorted and indexed consensus BAM + """ + + def skip_alignment(read, tags): + return read.is_secondary or read.is_unmapped or any(not read.has_tag(tag) for tag in tags) + + def find_genes(contig, start, end, read_strand=None): + genes = [] + # NOTE: contigs may not have any genes + for gene in contig_gene_order.get(contig, []): + if read_strand and read_strand != gene_infos[gene]['strand']: + continue + gene_segment = gene_infos[gene]['segment'] + if end <= gene_segment.start: + break + if start >= gene_segment.start and end <= gene_segment.end: + genes.append(gene) + return genes + + def swap_gene_tags(read, gene): + tags = dict(read.get_tags()) + if gene_tag and read.has_tag(gene_tag): + del tags[gene_tag] + tags['GX'] = gene + gn = gene_infos.get(gene, {}).get('gene_name') + if gn: + tags['GN'] = gn + read.set_tags(list(tags.items())) + return read + + def create_tags_and_strand(barcode, umi, reads, gene_info): + tags = { + 'AS': sum(read.get_tag('AS') for read in reads), + 'NH': 1, + 'HI': 1, + config.BAM_CONSENSUS_READ_COUNT_TAG: len(reads), + } + if barcode_tag: + tags[barcode_tag] = barcode + if umi_tag: + tags[umi_tag] = umi + + if gene_tag: + gene_id = None + for read in reads: + if read.has_tag(gene_tag): + gene_id = read.get_tag(gene_tag) + break + + if gene_id: + tags['GX'] = gene_id + gn = gene_info.get('gene_name') + if gn: + tags['GN'] = gn + if add_RS_RI: + tags.update({ + 'RS': ';'.join(read.query_name for read in reads), + 'RI': ';'.join(str(read.get_tag('HI')) for read in reads), + }) + + # Figure out what strand the consensus should map to + gene_strand = gene_info['strand'] + consensus_strand = gene_strand + if strand == 'forward': + consensus_strand = gene_strand + elif strand == 'reverse': + consensus_strand = '-' if gene_strand == '+' else '+' + return tags, consensus_strand + + if add_RS_RI: + logger.warning('RS and RI tags will be added to the BAM. This may dramatically increase the BAM size.') + + contig_gene_order = {} + for gene_id, gene_info in gene_infos.items(): + contig_gene_order.setdefault(gene_info['chromosome'], []).append(gene_id) + for contig in list(contig_gene_order.keys()): + contig_gene_order[contig] = sorted( + contig_gene_order[contig], key=lambda gene: tuple(gene_infos[gene]['segment']) + ) + + gx_barcode_umi_groups = {} + paired = {} + + required_tags = [] + if umi_tag: + required_tags.append(umi_tag) + if barcode_tag: + required_tags.append(barcode_tag) + + # Start processes for consensus calling + logger.debug(f'Spawning {n_threads} processes') + manager = multiprocessing.Manager() + args_q = manager.Queue(1000 * n_threads) + results_q = manager.Queue() + workers = [ + multiprocessing.Process( + target=consensus_worker, args=(args_q, results_q), kwargs=dict(quality=quality), daemon=True + ) for _ in range(n_threads) + ] + for worker in workers: + worker.start() + + temp_out_path = utils.mkstemp(dir=temp_dir) + with pysam.AlignmentFile(bam_path, 'rb') as f: + # Get header dict and update sort order to unsorted. + header_dict = f.header.to_dict() + hd = header_dict.setdefault('HD', {'VN': '1.4', 'SO': 'unsorted'}) + hd['SO'] = 'unsorted' + header = pysam.AlignmentHeader.from_dict(header_dict) + with pysam.AlignmentFile(temp_out_path, 'wb', header=header) as out: + for i, read in tqdm(enumerate(f.fetch()), total=ngs.bam.count_bam(bam_path), ascii=True, smoothing=0.01, + desc='Calling consensus'): + if skip_alignment(read, required_tags): + continue + + barcode = read.get_tag(barcode_tag) if barcode_tag else None + if barcode == '-' or (barcodes and barcode not in barcodes): + continue + + contig = read.reference_name + umi = read.get_tag(umi_tag) if umi_tag else None + + read_id = read.query_name + alignment_index = read.get_tag('HI') + start = read.reference_start + end = read.reference_end + key = (read_id, alignment_index) + mate = None + if read.is_paired: + if key not in paired: + paired[key] = read + continue + + mate = paired.pop(key) + # Use alignment start and end as UMI for paired reads without UMI + if not umi: + start = mate.reference_start + umi = (start, end) + + # Determine read strand + read_strand = None + if read.is_paired: + if read.is_read1: # R1 is mapped after R2 + if strand == 'forward': + read_strand = '+' if read.is_reverse else '-' + elif strand == 'reverse': + read_strand = '-' if read.is_reverse else '+' + else: # R1 is mapped before R2 + if strand == 'forward': + read_strand = '-' if read.is_reverse else '+' + elif strand == 'reverse': + read_strand = '+' if read.is_reverse else '-' + elif strand == 'forward': + read_strand = '-' if read.is_reverse else '+' + elif strand == 'reverse': + read_strand = '+' if read.is_reverse else '-' + + # Find compatible genes + gx_assigned = read.has_tag(gene_tag) if gene_tag else False + genes = [read.get_tag(gene_tag)] if gx_assigned else find_genes(contig, start, end, read_strand) + + # If there isn't exactly one compatible gene, do nothing and + # write to BAM. + if len(genes) != 1: + out.write(read) + if read.is_paired: + out.write(mate) + continue + + # Add read to group + gx_barcode_umi_groups.setdefault(genes[0], {}).setdefault(barcode, {}).setdefault(umi, []).append(read) + if read.is_paired: + gx_barcode_umi_groups[genes[0]][barcode][umi].append(mate) + + if i % 10000 == 0: + # Call consensus for gene's whose bodies we've fully passed. + leftmost_start = start if not paired else next(iter(paired.values())).reference_start + for gene in list(gx_barcode_umi_groups.keys()): + gene_info = gene_infos[gene] + gene_contig = gene_info['chromosome'] + gene_segment = gene_info['segment'] + if (gene_contig < contig) or (gene_contig == contig and gene_segment.end <= leftmost_start): + barcode_umi_groups = gx_barcode_umi_groups.pop(gene) + for barcode, umi_groups in barcode_umi_groups.items(): + for umi, reads in umi_groups.items(): + if len(reads) == 1: + out.write(swap_gene_tags(reads[0], gene)) + + tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_info) + + # Save for multiprocessing later. + args_q.put(([read.to_string() + for read in reads], header_dict, tags, consensus_strand)) + else: + break + + to_remove = 0 + for gene in contig_gene_order[contig]: + if gene_infos[gene]['segment'].end <= leftmost_start: + to_remove += 1 + else: + break + if to_remove > 0: + contig_gene_order[contig] = contig_gene_order[contig][to_remove:] + + while True: + try: + result = results_q.get_nowait() + if result: + consensus = pysam.AlignedSegment.fromstring(result, header) + out.write(consensus) + except queue.Empty: + break + + # Put remaining + for gene, barcode_umi_groups in gx_barcode_umi_groups.items(): + for barcode, umi_groups in barcode_umi_groups.items(): + for umi, reads in umi_groups.items(): + if len(reads) == 1: + out.write(swap_gene_tags(reads[0], gene)) + continue + + tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_infos[gene]) + args_q.put(([read.to_string() for read in reads], header_dict, tags, consensus_strand)) + + # Signal to workers to terminate once queue is depleted. + for _ in range(len(workers)): + args_q.put(None) + for worker in workers: + worker.join() + + while not results_q.empty(): + result = results_q.get() + consensus = pysam.AlignedSegment.fromstring(result, header) + out.write(consensus) + # Sort and index + return bam.sort_and_index_bam(temp_out_path, out_path, n_threads=n_threads, temp_dir=temp_dir) diff --git a/setup.cfg b/setup.cfg index 5cf6f2b..ad2e2e2 100755 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.1 +current_version = 1.0.2.beta commit = True tag = True diff --git a/setup.py b/setup.py index 98bd733..59b7fb8 100755 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ def read(path): setup( name='dynast-release', - version='1.0.1', + version='1.0.2.beta', url='https://github.com/aristoteleo/dynast-release', author='Kyung Hoi (Joseph) Min', author_email='phoenixter96@gmail.com', From b6b453f786a14c0c97a15fc6a6e3ff7bf9c27698 Mon Sep 17 00:00:00 2001 From: qirong Date: Thu, 2 Jan 2025 12:58:37 +0100 Subject: [PATCH 2/3] Removed: delete temp file --- dynast/preprocessing/consensus.py.old | 497 -------------------------- 1 file changed, 497 deletions(-) delete mode 100755 dynast/preprocessing/consensus.py.old diff --git a/dynast/preprocessing/consensus.py.old b/dynast/preprocessing/consensus.py.old deleted file mode 100755 index 90b462e..0000000 --- a/dynast/preprocessing/consensus.py.old +++ /dev/null @@ -1,497 +0,0 @@ -import array -import multiprocessing -import queue -from hashlib import sha256 -from typing import Any, Dict, List, Optional - -import ngs_tools as ngs -import numpy as np -import pysam -from tqdm import tqdm -from typing_extensions import Literal - -from .. import config, utils -from ..logging import logger -from . import bam - -BASES = ('A', 'C', 'G', 'T') -BASE_IDX = {base: i for i, base in enumerate(BASES)} - - -def call_consensus_from_reads( - reads: List[pysam.AlignedSegment], - header: pysam.AlignmentHeader, - quality: int = 27, - tags: Optional[Dict[str, Any]] = None, -) -> pysam.AlignedSegment: - """Call a single consensus alignment given a list of aligned reads. - - Reads must map to the same contig. Results are undefined otherwise. - Additionally, consensus bases are called only for positions that match - to the reference (i.e. no insertions allowed). - - This function only sets the minimal amount of attributes such that the - alignment is valid. These include: - * read name -- SHA256 hash of the provided read names - * read sequence and qualities - * reference name and ID - * reference start - * mapping quality (MAPQ) - * cigarstring - * MD tag - * NM tag - * Not unmapped, paired, duplicate, qc fail, secondary, nor supplementary - - The caller is expected to further populate the alignment - with additional tags, flags, and name. - - Args: - reads: List of reads to call a consensus sequence from - header: header to use when creating the new pysam alignment - quality: quality threshold - tags: additional tags to set - - Returns: - New pysam alignment of the consensus sequence - """ - if len(set(read.reference_name for read in reads)) > 1: - raise Exception("Can not call consensus from reads mapping to multiple contigs.") - - # Pysam coordinates are [start, end) - left_pos = min(read.reference_start for read in reads) - right_pos = max(read.reference_end for read in reads) - length = right_pos - left_pos - - # A consensus sequence is internally represented as a L x 4 matrix, - # where L is the length of the sequence and the columns correspond to - # each of the four bases. The values indicate the support of each base. - # It's possible to switch these to sparse matrices if memory becomes an issue. - sequence = np.zeros((length, len(BASES)), dtype=np.uint32) - reference = np.full(length, -1, dtype=np.int8) # -1 means unobserved - deletions = 0 - - for read in reads: - read_sequence = read.query_sequence.upper() - read_qualities = read.query_qualities - for read_i, genome_i, _genome_base in read.get_aligned_pairs(matches_only=False, with_seq=True): - # Insertion - if genome_i is None or _genome_base is None: - continue - i = genome_i - left_pos - genome_base = _genome_base.upper() - if genome_base == 'N': - continue - # Deletion - if read_i is None: - if reference[i] < 0: - reference[i] = BASE_IDX[genome_base] - deletions += 1 - continue - - read_base = read_sequence[read_i] - if read_base == 'N': - continue - - if reference[i] < 0: - reference[i] = BASE_IDX[genome_base] - sequence[i, BASE_IDX[read_base]] += read_qualities[read_i] - - # Determine consensus - # Note that we ignore any insertions - consensus_length = (sequence > 0).any(axis=1).sum() - consensus = np.zeros(consensus_length, dtype=np.uint8) - qualities = np.zeros(consensus_length, dtype=np.uint8) - cigar = [] - last_cigar_op = None - cigar_n = 0 - md = [] - md_n = 0 - md_zero = True - md_del = False - nm = 0 - consensus_i = 0 - for i in range(length): - ref = reference[i] - # Region not present in read. MD tag only deals with aligned - # regions, so nothing else needs to be done. - cigar_op = 'N' - if ref >= 0: - seq = sequence[i] - - # Deletion - if (seq == 0).all(): - cigar_op = 'D' - if md_n > 0 or md_zero: - md.append(str(md_n)) - md_n = 0 - - if not md_del: - md.append('^') - md.append(BASES[ref]) - md_del = True - - # Match - else: - md_del = False - - # On ties, select reference if present. Otherwise, choose lexicographically. - base_q = seq.max() - if base_q < quality: - base = ref - else: - bases = (seq == base_q).nonzero()[0] - if len(bases) > 0 and ref in bases: - base = ref - else: - base = bases[0] - - # We use the STAR convention of using M cigar operation to mean - # both matches AND mismatches, ignoring the X cigar operation exists. - cigar_op = 'M' - - if ref == base: - md_n += 1 - md_zero = False - else: - if md_n > 0 or md_zero: - md.append(str(md_n)) - md_n = 0 - md.append(BASES[ref]) - md_zero = True - nm += 1 - - consensus[consensus_i] = base - qualities[consensus_i] = min(base_q, 42) # Clip to maximum PHRED score - consensus_i += 1 - - if cigar_op == last_cigar_op: - cigar_n += 1 - else: - if last_cigar_op: - cigar.append(f'{cigar_n}{last_cigar_op}') - last_cigar_op = cigar_op - cigar_n = 1 - - md.append(str(md_n)) # MD tag always ends with a number - cigar.append(f'{cigar_n}{last_cigar_op}') - - al = pysam.AlignedSegment(header) - barcode_group = tags.get('CB', 'UNKNOWN_CB') - umi_group = tags.get('UB', 'UNKNOWN_UMI') - gene_group = tags.get('GX', 'UNKNOWN_GENE') - al.query_name = f"{barcode_group}_{umi_group}_{gene_group}" - al.query_sequence = ''.join(BASES[i] for i in consensus) - al.query_qualities = array.array('B', qualities) - al.reference_name = reads[0].reference_name - al.reference_id = reads[0].reference_id - al.reference_start = left_pos - al.mapping_quality = 255 - al.cigarstring = ''.join(cigar) - - # Set tags - tags = tags or {} - tags.update({'MD': ''.join(md), 'NM': nm}) - al.set_tags(list(tags.items())) - - # Make sure these are False - al.is_unmapped = False - al.is_paired = False - al.is_duplicate = False - al.is_qcfail = False - al.is_secondary = False - al.is_supplementary = False - return al - - -def call_consensus_from_reads_process(reads, header, tags, strand=None, quality=27): - """Helper function to call :func:`call_consensus_from_reads` from a subprocess.""" - header = pysam.AlignmentHeader.from_dict(header) - reads = [pysam.AlignedSegment.fromstring(read, header) for read in reads] - consensus = call_consensus_from_reads(reads, header, quality=quality, tags=tags) - consensus.is_paired = False - if strand == '-': - consensus.is_reverse = True - return consensus.to_string() - - -def consensus_worker(args_q, results_q, *args, **kwargs): - """Multiprocessing worker.""" - while True: - try: - _args = args_q.get(timeout=1) # None means we are done. - except queue.Empty: - continue - if _args is None: - return - - results_q.put(call_consensus_from_reads_process(*_args, *args, **kwargs)) - - -def call_consensus( - bam_path: str, - out_path: str, - gene_infos: dict, - strand: Literal['forward', 'reverse', 'unstranded'] = 'forward', - umi_tag: Optional[str] = None, - barcode_tag: Optional[str] = None, - gene_tag: str = 'GX', - barcodes: Optional[List[str]] = None, - quality: int = 27, - add_RS_RI: bool = False, - temp_dir: Optional[str] = None, - n_threads: int = 8 -) -> str: - """Call consensus sequences from BAM. - - Args: - bam_path: Path to BAM - out_path: Output BAM path - gene_infos: Gene information, as parsed from the GTF - strand: Protocol strandedness - umi_tag: BAM tag containing the UMI - barcode_tag: BAM tag containing the barcode - gene_tag: BAM tag containing the assigned gene - barcodes: List of barcodes to consider - quality: Quality threshold - add_RS_RI: Add RS and RI BAM tags for debugging - temp_dir: Temporary directory - n_threads: Number of threads - - Returns: - Path to sorted and indexed consensus BAM - """ - - def skip_alignment(read, tags): - return read.is_secondary or read.is_unmapped or any(not read.has_tag(tag) for tag in tags) - - def find_genes(contig, start, end, read_strand=None): - genes = [] - # NOTE: contigs may not have any genes - for gene in contig_gene_order.get(contig, []): - if read_strand and read_strand != gene_infos[gene]['strand']: - continue - gene_segment = gene_infos[gene]['segment'] - if end <= gene_segment.start: - break - if start >= gene_segment.start and end <= gene_segment.end: - genes.append(gene) - return genes - - def swap_gene_tags(read, gene): - tags = dict(read.get_tags()) - if gene_tag and read.has_tag(gene_tag): - del tags[gene_tag] - tags['GX'] = gene - gn = gene_infos.get(gene, {}).get('gene_name') - if gn: - tags['GN'] = gn - read.set_tags(list(tags.items())) - return read - - def create_tags_and_strand(barcode, umi, reads, gene_info): - tags = { - 'AS': sum(read.get_tag('AS') for read in reads), - 'NH': 1, - 'HI': 1, - config.BAM_CONSENSUS_READ_COUNT_TAG: len(reads), - } - if barcode_tag: - tags[barcode_tag] = barcode - if umi_tag: - tags[umi_tag] = umi - - if gene_tag: - gene_id = None - for read in reads: - if read.has_tag(gene_tag): - gene_id = read.get_tag(gene_tag) - break - - if gene_id: - tags['GX'] = gene_id - gn = gene_info.get('gene_name') - if gn: - tags['GN'] = gn - if add_RS_RI: - tags.update({ - 'RS': ';'.join(read.query_name for read in reads), - 'RI': ';'.join(str(read.get_tag('HI')) for read in reads), - }) - - # Figure out what strand the consensus should map to - gene_strand = gene_info['strand'] - consensus_strand = gene_strand - if strand == 'forward': - consensus_strand = gene_strand - elif strand == 'reverse': - consensus_strand = '-' if gene_strand == '+' else '+' - return tags, consensus_strand - - if add_RS_RI: - logger.warning('RS and RI tags will be added to the BAM. This may dramatically increase the BAM size.') - - contig_gene_order = {} - for gene_id, gene_info in gene_infos.items(): - contig_gene_order.setdefault(gene_info['chromosome'], []).append(gene_id) - for contig in list(contig_gene_order.keys()): - contig_gene_order[contig] = sorted( - contig_gene_order[contig], key=lambda gene: tuple(gene_infos[gene]['segment']) - ) - - gx_barcode_umi_groups = {} - paired = {} - - required_tags = [] - if umi_tag: - required_tags.append(umi_tag) - if barcode_tag: - required_tags.append(barcode_tag) - - # Start processes for consensus calling - logger.debug(f'Spawning {n_threads} processes') - manager = multiprocessing.Manager() - args_q = manager.Queue(1000 * n_threads) - results_q = manager.Queue() - workers = [ - multiprocessing.Process( - target=consensus_worker, args=(args_q, results_q), kwargs=dict(quality=quality), daemon=True - ) for _ in range(n_threads) - ] - for worker in workers: - worker.start() - - temp_out_path = utils.mkstemp(dir=temp_dir) - with pysam.AlignmentFile(bam_path, 'rb') as f: - # Get header dict and update sort order to unsorted. - header_dict = f.header.to_dict() - hd = header_dict.setdefault('HD', {'VN': '1.4', 'SO': 'unsorted'}) - hd['SO'] = 'unsorted' - header = pysam.AlignmentHeader.from_dict(header_dict) - with pysam.AlignmentFile(temp_out_path, 'wb', header=header) as out: - for i, read in tqdm(enumerate(f.fetch()), total=ngs.bam.count_bam(bam_path), ascii=True, smoothing=0.01, - desc='Calling consensus'): - if skip_alignment(read, required_tags): - continue - - barcode = read.get_tag(barcode_tag) if barcode_tag else None - if barcode == '-' or (barcodes and barcode not in barcodes): - continue - - contig = read.reference_name - umi = read.get_tag(umi_tag) if umi_tag else None - - read_id = read.query_name - alignment_index = read.get_tag('HI') - start = read.reference_start - end = read.reference_end - key = (read_id, alignment_index) - mate = None - if read.is_paired: - if key not in paired: - paired[key] = read - continue - - mate = paired.pop(key) - # Use alignment start and end as UMI for paired reads without UMI - if not umi: - start = mate.reference_start - umi = (start, end) - - # Determine read strand - read_strand = None - if read.is_paired: - if read.is_read1: # R1 is mapped after R2 - if strand == 'forward': - read_strand = '+' if read.is_reverse else '-' - elif strand == 'reverse': - read_strand = '-' if read.is_reverse else '+' - else: # R1 is mapped before R2 - if strand == 'forward': - read_strand = '-' if read.is_reverse else '+' - elif strand == 'reverse': - read_strand = '+' if read.is_reverse else '-' - elif strand == 'forward': - read_strand = '-' if read.is_reverse else '+' - elif strand == 'reverse': - read_strand = '+' if read.is_reverse else '-' - - # Find compatible genes - gx_assigned = read.has_tag(gene_tag) if gene_tag else False - genes = [read.get_tag(gene_tag)] if gx_assigned else find_genes(contig, start, end, read_strand) - - # If there isn't exactly one compatible gene, do nothing and - # write to BAM. - if len(genes) != 1: - out.write(read) - if read.is_paired: - out.write(mate) - continue - - # Add read to group - gx_barcode_umi_groups.setdefault(genes[0], {}).setdefault(barcode, {}).setdefault(umi, []).append(read) - if read.is_paired: - gx_barcode_umi_groups[genes[0]][barcode][umi].append(mate) - - if i % 10000 == 0: - # Call consensus for gene's whose bodies we've fully passed. - leftmost_start = start if not paired else next(iter(paired.values())).reference_start - for gene in list(gx_barcode_umi_groups.keys()): - gene_info = gene_infos[gene] - gene_contig = gene_info['chromosome'] - gene_segment = gene_info['segment'] - if (gene_contig < contig) or (gene_contig == contig and gene_segment.end <= leftmost_start): - barcode_umi_groups = gx_barcode_umi_groups.pop(gene) - for barcode, umi_groups in barcode_umi_groups.items(): - for umi, reads in umi_groups.items(): - if len(reads) == 1: - out.write(swap_gene_tags(reads[0], gene)) - - tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_info) - - # Save for multiprocessing later. - args_q.put(([read.to_string() - for read in reads], header_dict, tags, consensus_strand)) - else: - break - - to_remove = 0 - for gene in contig_gene_order[contig]: - if gene_infos[gene]['segment'].end <= leftmost_start: - to_remove += 1 - else: - break - if to_remove > 0: - contig_gene_order[contig] = contig_gene_order[contig][to_remove:] - - while True: - try: - result = results_q.get_nowait() - if result: - consensus = pysam.AlignedSegment.fromstring(result, header) - out.write(consensus) - except queue.Empty: - break - - # Put remaining - for gene, barcode_umi_groups in gx_barcode_umi_groups.items(): - for barcode, umi_groups in barcode_umi_groups.items(): - for umi, reads in umi_groups.items(): - if len(reads) == 1: - out.write(swap_gene_tags(reads[0], gene)) - continue - - tags, consensus_strand = create_tags_and_strand(barcode, umi, reads, gene_infos[gene]) - args_q.put(([read.to_string() for read in reads], header_dict, tags, consensus_strand)) - - # Signal to workers to terminate once queue is depleted. - for _ in range(len(workers)): - args_q.put(None) - for worker in workers: - worker.join() - - while not results_q.empty(): - result = results_q.get() - consensus = pysam.AlignedSegment.fromstring(result, header) - out.write(consensus) - # Sort and index - return bam.sort_and_index_bam(temp_out_path, out_path, n_threads=n_threads, temp_dir=temp_dir) From 2abe62e2dd07e1991f14199068ace7721e73c931 Mon Sep 17 00:00:00 2001 From: qirong Date: Wed, 22 Jan 2025 07:24:15 +0100 Subject: [PATCH 3/3] bug fixed: corrected sam flag after R1/R2 UMI collapsing --- dynast/preprocessing/consensus.py | 166 ++++++++++++++++++------------ 1 file changed, 100 insertions(+), 66 deletions(-) diff --git a/dynast/preprocessing/consensus.py b/dynast/preprocessing/consensus.py index 5a9e65e..719878a 100644 --- a/dynast/preprocessing/consensus.py +++ b/dynast/preprocessing/consensus.py @@ -18,6 +18,20 @@ BASE_IDX = {base: i for i, base in enumerate(BASES)} +def _majority_flag(reads: List[pysam.AlignedSegment], attr: str) -> bool: + """ + Return True if more than half of the reads have getattr(r, attr) == True. + Otherwise return False. + + E.g.: + _majority_flag(reads, 'is_proper_pair') + _majority_flag(reads, 'is_reverse') + _majority_flag(reads, 'mate_is_reverse') + """ + count_true = sum(1 for r in reads if getattr(r, attr)) + return count_true > (len(reads) / 2) + + def call_consensus_from_reads( reads: List[pysam.AlignedSegment], header: pysam.AlignmentHeader, @@ -36,14 +50,22 @@ def call_consensus_from_reads( If read_number is 'R1' or 'R2', we set the resulting consensus read to is_paired = True and is_read1/is_read2 = True accordingly, preserving the same QNAME for both consensus reads (if shared_qname is given). + + We also aggregate the following flags by majority vote across `reads`: + - is_proper_pair + - is_reverse + - mate_is_reverse """ + # Validate all reads on the same reference if len(set(r.reference_name for r in reads)) > 1: raise Exception("Cannot call consensus from reads mapping to multiple contigs.") + # Determine leftmost and rightmost positions left_pos = min(r.reference_start for r in reads) right_pos = max(r.reference_end for r in reads) length = right_pos - left_pos + # Build a matrix for base counts sequence = np.zeros((length, len(BASES)), dtype=np.uint32) reference = np.full(length, -1, dtype=np.int8) # -1 means unobserved @@ -52,13 +74,14 @@ def call_consensus_from_reads( read_quals = read.query_qualities for read_i, ref_i, ref_base in read.get_aligned_pairs(matches_only=False, with_seq=True): if ref_i is None or ref_base is None: + # Insertion or soft-clip, etc. continue i = ref_i - left_pos ref_base = ref_base.upper() if ref_base == 'N': continue if read_i is None: - # Deletion + # This is a deletion from the read if reference[i] < 0: reference[i] = BASE_IDX[ref_base] continue @@ -69,7 +92,7 @@ def call_consensus_from_reads( reference[i] = BASE_IDX[ref_base] sequence[i, BASE_IDX[read_base]] += read_quals[read_i] - # Build consensus + # Now determine the consensus sequence consensus_mask = (sequence > 0).any(axis=1) consensus_length = consensus_mask.sum() consensus = np.zeros(consensus_length, dtype=np.uint8) @@ -104,9 +127,11 @@ def call_consensus_from_reads( md_del = False base_q = seq.max() if base_q < quality: + # If no base surpasses the threshold, choose reference base = ref_idx else: candidates = (seq == base_q).nonzero()[0] + # On tie: if reference is in that tie, choose reference if ref_idx in candidates: base = ref_idx else: @@ -124,7 +149,7 @@ def call_consensus_from_reads( nm += 1 if consensus_mask[i]: consensus[consensus_i] = base - qualities[consensus_i] = min(base_q, 42) + qualities[consensus_i] = min(base_q, 42) # cap at PHRED 42 consensus_i += 1 if cigar_op == last_cigar_op: @@ -135,14 +160,16 @@ def call_consensus_from_reads( last_cigar_op = cigar_op cigar_n = 1 + # Finish up the last cigar chunk md.append(str(md_n)) cigar.append(f"{cigar_n}{last_cigar_op}") + # Create a new pysam AlignedSegment al = pysam.AlignedSegment(header) if shared_qname is not None: al.query_name = shared_qname else: - # default: hash together original QNAMEs + # default: hash all read names all_names = ''.join(r.query_name for r in reads) al.query_name = sha256(all_names.encode('utf-8')).hexdigest() @@ -154,12 +181,12 @@ def call_consensus_from_reads( al.mapping_quality = 255 al.cigarstring = ''.join(cigar) - # Add tags + # Set tags tags = tags or {} tags.update({'MD': ''.join(md), 'NM': nm}) al.set_tags(list(tags.items())) - # Mark R1 or R2 if provided + # Mark read1 or read2 if provided if read_number == 'R1': al.is_paired = True al.is_read1 = True @@ -171,6 +198,20 @@ def call_consensus_from_reads( else: al.is_paired = False + # -------------------------------------------------------------- + # NEW: Aggregate certain flags with a "majority vote" approach: + # is_proper_pair + # is_reverse + # mate_is_reverse + # -------------------------------------------------------------- + al.is_proper_pair = _majority_flag(reads, 'is_proper_pair') + # If this consensus read is designated as reversed by majority: + al.is_reverse = _majority_flag(reads, 'is_reverse') + + # Only relevant if read is paired + if al.is_paired: + al.mate_is_reverse = _majority_flag(reads, 'mate_is_reverse') + # Force false for these al.is_unmapped = False al.is_duplicate = False @@ -191,7 +232,7 @@ def call_consensus_from_reads_process( quality=27 ): """ - Helper for multiprocessing calls. + Helper function to call :func:`call_consensus_from_reads`. """ header = pysam.AlignmentHeader.from_dict(header) pysam_reads = [pysam.AlignedSegment.fromstring(r, header) for r in reads] @@ -206,12 +247,14 @@ def call_consensus_from_reads_process( ) if strand == '-': aln.is_reverse = True + return aln.to_string() def consensus_worker(args_q, results_q, quality=27): """ - Worker that reads tasks from args_q, calls call_consensus_from_reads_process. + Multiprocessing worker that reads tasks from args_q and calls + call_consensus_from_reads_process. """ while True: try: @@ -247,7 +290,7 @@ def call_consensus( add_RS_RI: bool = False, temp_dir: Optional[str] = None, n_threads: int = 8, - collapse_r1_r2: bool = False # <-- ### ADDED + collapse_r1_r2: bool = False ) -> str: """ Call consensus sequences from BAM. @@ -258,6 +301,9 @@ def call_consensus( If collapse_r1_r2 is False, R1 and R2 from the same UMI are stored separately, producing two consensus reads (one for R1, one for R2). + The final consensus read also aggregates is_proper_pair, is_reverse, + and mate_is_reverse from the source reads by majority vote. + Args: bam_path: Path to BAM out_path: Output BAM path @@ -306,16 +352,16 @@ def swap_gene_tags(r, gene): def create_tags_and_strand(barcode, umi, reads, ginfo): as_sum = sum(r.get_tag('AS') for r in reads if r.has_tag('AS')) - tags = { + tags_ = { 'AS': as_sum, 'NH': 1, 'HI': 1, config.BAM_CONSENSUS_READ_COUNT_TAG: len(reads), } if barcode_tag: - tags[barcode_tag] = barcode + tags_[barcode_tag] = barcode if umi_tag: - tags[umi_tag] = umi + tags_[umi_tag] = umi if gene_tag: gene_id = None @@ -324,16 +370,14 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): gene_id = rr.get_tag(gene_tag) break if gene_id: - tags['GX'] = gene_id + tags_['GX'] = gene_id gn = ginfo.get('gene_name') if gn: - tags['GN'] = gn + tags_['GN'] = gn if add_RS_RI: - tags['RS'] = ';'.join(r.query_name for r in reads) - tags['RI'] = ';'.join( - str(r.get_tag('HI')) if r.has_tag('HI') else '0' for r in reads - ) + tags_['RS'] = ';'.join(r.query_name for r in reads) + tags_['RI'] = ';'.join(str(r.get_tag('HI')) if r.has_tag('HI') else '0' for r in reads) consensus_strand = None gene_strand = ginfo['strand'] @@ -342,30 +386,23 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): elif strand == 'reverse': consensus_strand = '-' if gene_strand == '+' else '+' - return tags, consensus_strand + return tags_, consensus_strand if add_RS_RI: logger.warning("RS and RI tags may greatly increase BAM size.") # Build index of genes per contig contig_gene_order = {} - for gene_id, ginfo in gene_infos.items(): - contig_gene_order.setdefault(ginfo['chromosome'], []).append(gene_id) + for gid, ginfo in gene_infos.items(): + contig_gene_order.setdefault(ginfo['chromosome'], []).append(gid) for contig in list(contig_gene_order.keys()): contig_gene_order[contig] = sorted( contig_gene_order[contig], key=lambda g: tuple(gene_infos[g]['segment']) ) - # ### CHANGED - # Instead of storing R1 and R2 in separate keys unconditionally, - # we let the user decide via 'collapse_r1_r2'. - # - # Data structure: - # gx_barcode_umi_groups[gene][barcode][umi][subkey] -> list_of_reads - # where 'subkey' is either: - # 'ALL' (if collapse_r1_r2=True) - # 'R1' or 'R2' (if collapse_r1_r2=False) + # We'll store groups as: gx_barcode_umi_groups[gene][barcode][umi][subkey] -> list_of_reads + # Where subkey is either 'ALL' (if collapsing) or 'R1'/'R2' (if separate) gx_barcode_umi_groups = {} paired = {} @@ -399,9 +436,11 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): total_reads = ngs.bam.count_bam(bam_path) with pysam.AlignmentFile(temp_out_path, 'wb', header=header) as out: - for i, read in tqdm(enumerate(f_in.fetch()), - total=total_reads, ascii=True, smoothing=0.01, - desc="Calling consensus"): + for i, read in tqdm( + enumerate(f_in.fetch()), + total=total_reads, ascii=True, smoothing=0.01, + desc="Calling consensus" + ): if skip_alignment(read, required_tags): continue @@ -414,7 +453,6 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): alignment_index = read.get_tag('HI') if read.has_tag('HI') else 1 key = (read_id, alignment_index) - # Handle pairing mate = None if read.is_paired: if key not in paired: @@ -447,20 +485,18 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): genes = find_genes(contig, start, end, read_strand) if len(genes) != 1: - # If not exactly one gene, write raw + # If not exactly one gene, we just write the read directly out.write(read) if mate: out.write(mate) continue gene = genes[0] - - # ### CHANGED - # Decide subkey = 'ALL' (if collapsing) or 'R1'/'R2' (if separating) if collapse_r1_r2: subkey = 'ALL' else: - subkey = get_read_number(read) + from_read = 'R1' if read.is_read1 else 'R2' + subkey = from_read gx_barcode_umi_groups \ .setdefault(gene, {}) \ @@ -469,17 +505,16 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): .setdefault(subkey, []) \ .append(read) - # Also store mate if present if mate: if collapse_r1_r2: mate_subkey = 'ALL' else: - mate_subkey = get_read_number(mate) + mate_subkey = 'R1' if mate.is_read1 else 'R2' gx_barcode_umi_groups[gene][barcode][umi] \ .setdefault(mate_subkey, []) \ .append(mate) - # Periodically flush old genes + # Every 10k reads, flush old genes if i % 10000 == 0: leftmost = min( read.reference_start, @@ -495,18 +530,18 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): bc_map = gx_barcode_umi_groups.pop(g) for bc, umi_map in bc_map.items(): for this_umi, sub_map in umi_map.items(): - # Build a stable QNAME for everything in sub_map + # Build a stable QNAME from everything in sub_map all_names = [] - for subkey_, reads_list in sub_map.items(): - all_names.extend(r.query_name for r in reads_list) + for k_, subreads in sub_map.items(): + all_names.extend(r.query_name for r in subreads) shared_qname = sha256(''.join(all_names).encode('utf-8')).hexdigest() - for subk, reads_list in sub_map.items(): - if len(reads_list) == 1: - single_read = swap_gene_tags(reads_list[0], g) + for subk, subreads in sub_map.items(): + if len(subreads) == 1: + single_read = swap_gene_tags(subreads[0], g) single_read.query_name = shared_qname + # If not collapsing, see if subk=R1/R2 to set flags if not collapse_r1_r2: - # If not collapsing, check if subk = 'R1'/'R2' if subk == 'R1': single_read.is_paired = True single_read.is_read1 = True @@ -517,18 +552,16 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): single_read.is_read2 = True out.write(single_read) else: - tags, cstrand = create_tags_and_strand(bc, this_umi, reads_list, ginfo) - # subk might be 'ALL' or 'R1'/'R2' + tags_, cstrand = create_tags_and_strand(bc, this_umi, subreads, ginfo) + read_num = subk if not collapse_r1_r2 else None args_q.put(( - [r.to_string() for r in reads_list], + [r.to_string() for r in subreads], header_dict, - tags, + tags_, cstrand, - subk if not collapse_r1_r2 else None, # read_number if separate + read_num, shared_qname )) - - # Drain queue while True: try: result = results_q.get_nowait() @@ -543,13 +576,13 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): for bc, umi_map in bc_map.items(): for this_umi, sub_map in umi_map.items(): all_names = [] - for subk, reads_list in sub_map.items(): - all_names.extend(r.query_name for r in reads_list) + for subk, subreads in sub_map.items(): + all_names.extend(r.query_name for r in subreads) shared_qname = sha256(''.join(all_names).encode('utf-8')).hexdigest() - for subk, reads_list in sub_map.items(): - if len(reads_list) == 1: - single_read = swap_gene_tags(reads_list[0], g) + for subk, subreads in sub_map.items(): + if len(subreads) == 1: + single_read = swap_gene_tags(subreads[0], g) single_read.query_name = shared_qname if not collapse_r1_r2: if subk == 'R1': @@ -562,13 +595,14 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): single_read.is_read2 = True out.write(single_read) else: - tags, cstrand = create_tags_and_strand(bc, this_umi, reads_list, ginfo) + tags_, cstrand = create_tags_and_strand(bc, this_umi, subreads, ginfo) + read_num = subk if not collapse_r1_r2 else None args_q.put(( - [r.to_string() for r in reads_list], + [r.to_string() for r in subreads], header_dict, - tags, + tags_, cstrand, - subk if not collapse_r1_r2 else None, + read_num, shared_qname )) @@ -583,5 +617,5 @@ def create_tags_and_strand(barcode, umi, reads, ginfo): result = results_q.get() out.write(pysam.AlignedSegment.fromstring(result, header)) - # Sort and index as usual + # Sort and index return bam.sort_and_index_bam(temp_out_path, out_path, n_threads=n_threads, temp_dir=temp_dir)