|
| 1 | +import sys |
| 2 | +from concurrent.futures import ThreadPoolExecutor |
| 3 | + |
| 4 | +import polars as pl |
| 5 | +from tqdm import tqdm |
| 6 | +import pyarrow.parquet as pq |
| 7 | + |
| 8 | +import tdb |
| 9 | + |
def process_sample(path):
    """
    Read one sample's parquet file and tally its allele observations.

    Returns a polars DataFrame with columns LocusID, allele_number, AC,
    where AC is the number of rows seen for that (locus, allele) pair
    in this sample.
    """
    table = pq.ParquetFile(path).read(columns=['LocusID', 'allele_number'])
    per_allele = pl.from_arrow(table).group_by(['LocusID', 'allele_number']).len()
    return per_allele.rename({"len": "AC"})
| 17 | + |
def merge_in(data, results):
    """
    Fold a batch of per-sample allele counts into the running totals.

    `results` is a list of DataFrames with columns LocusID, allele_number,
    AC. Their counts are summed per (LocusID, allele_number) and added onto
    `data`'s AC column; pairs absent from the batch contribute zero.
    """
    batch_totals = (pl.concat(results, how="vertical")
                      .group_by(['LocusID', 'allele_number'])
                      .sum())
    # Left join keeps every row of `data`; the batch's AC arrives as
    # AC_right, with nulls (no observations this batch) zero-filled.
    merged = (data.join(batch_totals, on=['LocusID', 'allele_number'], how="left")
                  .fill_null(0))
    merged = merged.with_columns([
        (pl.col("AC") + pl.col("AC_right")).alias("AC"),
    ])
    return merged.drop(["AC_right"])
| 28 | + |
| 29 | + |
if __name__ == '__main__':
    # Example input: a tdb directory such as "../../AoU_TRs.v0.1.tdb/"
    db_path = sys.argv[1]
    batch_size = 250   # fold per-sample counts into the totals this often
    out_prefix = "result"
    min_af = 0.01      # alleles below this frequency are excluded from LPS
    lps_norm = 100     # LPS is expressed per this many allele observations
    names = tdb.get_tdb_filenames(db_path)

    key_cols = ["LocusID", "allele_number"]

    # Base allele table: one row per (locus, allele), AC initialized to 0.
    counts = pl.from_arrow(pq.read_table(names['allele'], columns=key_cols))
    counts = counts.with_columns(pl.lit(0).alias('AC'))

    # Read samples on a small thread pool; merge counts into the totals in
    # batches so the list of pending per-sample frames stays bounded.
    sample_paths = list(names['sample'].values())
    pending = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        for sample_counts in tqdm(executor.map(process_sample, sample_paths),
                                  total=len(sample_paths), desc="Processing"):
            pending.append(sample_counts)
            if len(pending) >= batch_size:
                counts = merge_in(counts, pending)
                pending = []

    # Fold in whatever is left from the final partial batch.
    if pending:
        counts = merge_in(counts, pending)

    # Per-locus allele number (AN) and per-allele frequency (AF).
    an = counts.group_by("LocusID").agg(pl.col('AC').sum().alias('AN'))
    counts = counts.join(an, on="LocusID", how='left')
    counts = counts.with_columns((pl.col('AC') / pl.col('AN')).alias('AF'))

    counts.write_csv(f"{out_prefix}.allele_seq.txt", separator='\t')

    # Re-read the allele table with lengths and collapse counts by length.
    bylen = pl.from_arrow(pq.read_table(names['allele'],
                                        columns=["LocusID", "allele_number",
                                                 'allele_length']))
    bylen = (bylen.join(counts, on=["LocusID", "allele_number"], how="left")
                  .group_by(['LocusID', 'allele_length'])
                  .agg(pl.col('AC').sum(), pl.col('AN').first()))
    bylen = bylen.with_columns((pl.col('AC') / pl.col('AN')).alias('AF'))

    bylen.write_csv(f"{out_prefix}.allele_len.txt", separator='\t')

    # Length polymorphism score: number of distinct common allele lengths,
    # normalized per lps_norm allele observations at the locus.
    lps = (bylen.filter(bylen['AF'] >= min_af)
                .group_by('LocusID')
                .agg((pl.col('allele_length').n_unique()
                      / (pl.col('AN').first() / lps_norm)).alias('LPS'))
                .select(['LocusID', 'LPS']))

    lps.write_csv(f"{out_prefix}.length_polymorphism.txt", separator='\t')
0 commit comments