Commit 7a477a7

phase/hap complete

Also beginning the query refactor to polars; tests are still going to fail.
1 parent: 84cb2be

5 files changed: +99 −2 lines

setup.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ def read(rel_path):
     with codecs.open(os.path.join(here, rel_path), 'r') as fp:
         return fp.read()
 
-VERSION = "0.3.0"
+VERSION = "1.0.0-dev"
 
 setup(
     name="tdb",

tdb/bigmerge.py

Lines changed: 3 additions & 0 deletions
@@ -360,12 +360,15 @@ def sample_puller(con, dbname, output_dir, compress):
     files = tdb.get_tdb_filenames(dbname)
     for _, sample_pq in files['sample'].items():
         out_name = os.path.join(output_dir, os.path.basename(sample_pq))
+        # TODO: new fields need to be pulled
         query = f"""
         COPY (
             SELECT
                 allele_lookup.to_LocusID as LocusID,
                 allele_lookup.to_allele_number_new as allele_number,
                 sample.spanning_reads,
+                sample.phase_set,
+                sample.haplotype,
                 sample.length_range_lower,
                 sample.length_range_upper,
                 sample.average_methylation,
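
A quick sanity check after running sample_puller is to open one extracted parquet with polars and confirm the two new columns made it through. This snippet is not part of the commit, and the file path is hypothetical.

import polars as pl

# Hypothetical output of sample_puller's COPY for one sample.
df = pl.read_parquet("extracted/sample_A.pq")
# The two fields this commit threads through the extraction query:
assert {"phase_set", "haplotype"}.issubset(df.columns)
print(df.select("LocusID", "allele_number", "phase_set", "haplotype").head())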

tdb/create.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,7 @@
     "sequence": pa.binary(),
     "spanning_reads": pa.uint16(),
     "phase_set": pa.uint32(),
-    "haplotype": pa.uint16(),
+    "haplotype": pa.uint8(),
     "length_range_lower": pa.uint16(),
     "length_range_upper": pa.uint16(),
     "average_methylation": pa.float32()}

@@ -147,6 +147,7 @@ def convert_buffer(vcf, samples, stats, seen_loci, avail_mem=4e9, force=False):
             logging.warning("Unable to convert %s", str(entry))
             logging.warning("Error: %s", str(e))
             if not force:
+                logging.error("Exiting")
                 sys.exit(1)
             continue
         # pylint: enable=broad-exception-caught
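
For context on the uint16 to uint8 change: haplotype labels are small integers (e.g. 0/1/2 for unassigned/hap1/hap2), so one byte is plenty and halves that column's storage. Below is a minimal sketch of turning a type mapping like the one in the hunk into a pyarrow schema; the dict name and the exact field set are assumptions based on the hunk, not the file's full contents.

import pyarrow as pa

# Assumed shape of the sample-table type mapping shown in the hunk above.
SAMPLE_TYPES = {
    "spanning_reads": pa.uint16(),
    "phase_set": pa.uint32(),
    "haplotype": pa.uint8(),  # labels like 0/1/2 fit comfortably in one byte
    "length_range_lower": pa.uint16(),
    "length_range_upper": pa.uint16(),
    "average_methylation": pa.float32(),
}
schema = pa.schema(list(SAMPLE_TYPES.items()))
print(schema)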

tdb/merge.py

Lines changed: 2 additions & 0 deletions
@@ -328,6 +328,8 @@ def update_sample_table(second_sample, sample_lookup, compress):
         lookup.to_LocusID AS LocusID,
         lookup.to_allele_number_new AS allele_number,
         sample.spanning_reads AS spanning_reads,
+        sample.phase_set AS phase_set,
+        sample.haplotype AS haplotype,
         sample.length_range_lower AS length_range_lower,
         sample.length_range_upper AS length_range_upper,
         sample.average_methylation AS average_methylation
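
As a toy illustration (not the repository's code) of how the two added columns simply ride along through the allele-renumbering join, here is a self-contained DuckDB snippet; the table layouts, join keys, and values are all assumptions for demonstration.

import duckdb

con = duckdb.connect()
# Toy stand-ins for a sample table and an allele-renumbering lookup.
con.execute("""
    CREATE TABLE sample AS SELECT * FROM (VALUES
        (1, 0, 30, 7, 1),
        (1, 1, 28, 7, 2)
    ) t(LocusID, allele_number, spanning_reads, phase_set, haplotype)
""")
con.execute("""
    CREATE TABLE lookup AS SELECT * FROM (VALUES
        (1, 0, 5, 0),
        (1, 1, 5, 1)
    ) t(LocusID, allele_number, to_LocusID, to_allele_number_new)
""")
# phase_set and haplotype pass through the join unchanged.
print(con.execute("""
    SELECT lookup.to_LocusID AS LocusID,
           lookup.to_allele_number_new AS allele_number,
           sample.spanning_reads AS spanning_reads,
           sample.phase_set AS phase_set,
           sample.haplotype AS haplotype
    FROM sample
    JOIN lookup USING (LocusID, allele_number)
""").pl())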

tdb/queries/allele_stats.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import sys
from concurrent.futures import ThreadPoolExecutor

import polars as pl
from tqdm import tqdm
import pyarrow.parquet as pq

import tdb

def process_sample(path):
    """
    Open a sample and generate its allele count
    """
    pf = pq.ParquetFile(path)
    sample_data = pl.from_arrow(pf.read(columns=['LocusID', 'allele_number']))
    return sample_data.group_by(['LocusID', 'allele_number']).len().rename({"len": "AC"})

def merge_in(data, results):
    """
    Given a batch of samples' counts, consolidate into main data
    """
    combined = pl.concat(results, how="vertical").group_by(['LocusID', 'allele_number']).sum()
    data = data.join(combined, on=['LocusID', 'allele_number'], how="left").fill_null(0)
    data = data.with_columns([
        (pl.col("AC") + pl.col("AC_right")).alias("AC"),
    ])
    return data.drop(["AC_right"])


if __name__ == '__main__':
    # fn = "../../AoU_TRs.v0.1.tdb/"
    fn = sys.argv[1]
    batch_size = 250
    out_prefix = "result"
    min_af = 0.01
    lps_norm = 100
    names = tdb.get_tdb_filenames(fn)
    columns = ["LocusID", "allele_number"]

    # Read base allele table efficiently
    counts = pl.from_arrow(
        pq.read_table(names['allele'], columns=columns)
    )

    counts = counts.with_columns(
        pl.lit(0).alias('AC'),
    )

    # Use multiple threads for reading + batch merge
    sample_paths = list(names['sample'].values())

    with ThreadPoolExecutor(max_workers=4) as executor:  # Adjust thread count
        results = []
        for sample_counts in tqdm(executor.map(process_sample, sample_paths),
                                  total=len(sample_paths), desc="Processing"):
            results.append(sample_counts)

            # Merge in batches
            if len(results) >= batch_size:
                counts = merge_in(counts, results)
                results = []  # Reset batch

    # Merge remaining results
    if results:
        counts = merge_in(counts, results)

    AN = counts.group_by("LocusID").agg(pl.col('AC').sum().alias('AN'))
    counts = counts.join(AN, on="LocusID", how='left')
    counts = counts.with_columns((pl.col('AC') / pl.col('AN')).alias('AF'))

    counts.write_csv(f"{out_prefix}.allele_seq.txt", separator='\t')

    columns = ["LocusID", "allele_number", 'allele_length']
    bylen = pl.from_arrow(pq.read_table(names['allele'], columns=columns))

    bylen = bylen.join(counts, on=["LocusID", "allele_number"], how="left")  # Use "inner" if necessary
    bylen = bylen.group_by(['LocusID', 'allele_length']).agg(pl.col('AC').sum(), pl.col('AN').first())
    bylen = bylen.with_columns((pl.col('AC') / pl.col('AN')).alias('AF'))

    bylen.write_csv(f"{out_prefix}.allele_len.txt", separator='\t')

    lps = (bylen.filter(bylen['AF'] >= min_af)
           .group_by('LocusID')
           .agg((pl.col('allele_length').n_unique() /
                 (pl.col('AN').first() / lps_norm)
                 ).alias('LPS')
                )
           ).select(['LocusID', 'LPS'])

    lps.write_csv(f"{out_prefix}.length_polymorphism.txt", separator='\t')
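
The script takes a .tdb path as its only argument (fn = sys.argv[1]). To make the LPS output concrete, here is a tiny worked example with synthetic numbers (not from any real .tdb): LPS counts the distinct allele lengths whose AF clears min_af, normalized per lps_norm (100) alleles observed.

import polars as pl

# Synthetic per-length counts for one locus: AN = 1000 alleles observed.
bylen = pl.DataFrame({
    "LocusID":       [1, 1, 1],
    "allele_length": [10, 12, 14],
    "AC":            [900, 95, 5],
    "AN":            [1000, 1000, 1000],
}).with_columns((pl.col("AC") / pl.col("AN")).alias("AF"))

# AFs are 0.9, 0.095, 0.005 -> two lengths clear min_af = 0.01.
lps = (bylen.filter(pl.col("AF") >= 0.01)
            .group_by("LocusID")
            .agg((pl.col("allele_length").n_unique() /
                  (pl.col("AN").first() / 100)).alias("LPS")))
print(lps)  # 2 common lengths per (1000 / 100) -> LPS = 0.2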
