Skip to content

Commit

Permalink
introduce discard pile for reconstructed haplotypes based on cutoff a…
Browse files Browse the repository at this point in the history
…nd adjust metrics to account for it, introduce avg number of discarded haps per region, precision and f1 score, adjust output to include new metrics and exclude old ones
  • Loading branch information
IoannaNika committed Oct 21, 2024
1 parent 9e63c99 commit b013c50
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 13 deletions.
77 changes: 68 additions & 9 deletions src/benchmarking_scripts/evaluate_hrt_output_lumc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import editdistance
import os
from typing import List, Tuple
from utils.evaluation import calculate_average_edit_distance, calculate_average_number_of_haplotypes, calculate_recall, calculate_duplication_ratio, closest_haplotype, calculate_relative_absolute_abundance_error
from utils.evaluation import calculate_average_edit_distance, calculate_average_number_of_haplotypes, calculate_recall, calculate_duplication_ratio, closest_haplotype, calculate_relative_absolute_abundance_error, normalised_edit_distance, calculate_precision, calculate_f1_score

def is_the_coverage_sufficient(reads: pd.DataFrame, gr: str, low_limit: int = 100) -> bool:
"""
Expand Down Expand Up @@ -238,6 +238,45 @@ def update_abundance_per_region(abundance_per_region:dict, haplotype:str, region
abundance_per_region[region][haplotype] += rel_ab
return abundance_per_region

def init_discarded_haplotypes_per_region(genomic_regions:List[Tuple[int, int]]) -> dict:
"""
Initializes the discarded haplotypes per region dictionary. The dictionary contains the number of discarded haplotypes per region.
Discarded haplotypes are the ones that are not sufficiently similar to the Wuhan or Omicron consensus sequences.
Args:
genomic_regions: list of tuples, genomic regions
Returns:
discarded_haplotypes_per_region: dict, discarded haplotypes per region dictionary
"""

discarded_haplotypes_per_region = {}

for region in genomic_regions:
gr_key = str(region[0]) + "_" + str(region[1])
discarded_haplotypes_per_region[gr_key] = 0

return discarded_haplotypes_per_region

def normalize_abundance_per_region(abundance_per_region:dict) -> dict:
"""
Normalizes the abundance per region such that the sum of the relative abundances of the Wuhan and Omicron haplotypes is 1
Args:
abundance_per_region: dict, abundance per region dictionary
Returns:
abundance_per_region: dict, normalized abundance per region dictionary
"""

for region in abundance_per_region.keys():
total_ab = abundance_per_region[region]['Wuhan'] + abundance_per_region[region]['Omicron']
abundance_per_region[region]['Wuhan'] = abundance_per_region[region]['Wuhan'] / total_ab
abundance_per_region[region]['Omicron'] = abundance_per_region[region]['Omicron'] / total_ab

assert (abundance_per_region[region]['Wuhan'] + abundance_per_region[region]['Omicron']) == 1

return abundance_per_region

def main():
parser = argparse.ArgumentParser(description='Evaluate HRT output, standard format')
Expand All @@ -261,6 +300,8 @@ def main():
input_path = args.input
input_df = pd.read_csv(input_path, sep='\t', header=0)

# normalised edit distance cutoff, defines the threshold for which a haplotype ends up in the discarded haplotypes
dissimilarity_cutoff = 0.01

# remove sequences that are shorter than 800 bp
input_df = input_df[input_df['sequence'].apply(lambda x: len(x) > 800)]
Expand All @@ -287,8 +328,9 @@ def main():
genomic_regions = [region for region in genomic_regions if is_the_coverage_sufficient(reads_tsv, f"{region[0]}_{region[1]}")]

true_abs_per_hap_per_sample_region, true_n_haps_per_sample_region = true_abundances_and_haplotypes_per_sample_per_region(genomic_regions, wuhan_consensus, omicron_consensus, ref_seq)

##### metrics tracking #####
discarded_haplotypes_per_region = init_discarded_haplotypes_per_region(genomic_regions)
edit_distance_from_closest_consensus_per_region = init_edit_distance_from_closest_consensus_per_region(genomic_regions)
number_of_haplotypes_per_region = init_number_of_haplotypes_per_region(genomic_regions)
number_of_exact_haplotypes_per_region = init_number_of_haplotypes_per_region(genomic_regions)
Expand Down Expand Up @@ -317,31 +359,48 @@ def main():
closest_hap, ed_from_hap = closest_haplotype(seq, rel_ab, omicron_amplicon, wuhan_amplicon, true_abs_per_hap_per_sample_region, region, abundance_per_region, args.sample_name.split("-")[0])
print(closest_hap, " is the closest hap with edit distance ", ed_from_hap," with rel ab ", rel_ab, " region ", region)

if closest_hap == 'Wuhan':
norm_ed = normalised_edit_distance(seq, wuhan_amplicon)
else:
norm_ed = normalised_edit_distance(seq, omicron_amplicon)

# update metrics
if norm_ed > dissimilarity_cutoff:
print("Discarding haplotype ", haplotype_id, " with normalised edit distance ", norm_ed, " region ", region, "and length ", len(seq))
# if the normalised edit distance is higher than the cutoff, the haplotype is discarded
discarded_haplotypes_per_region[region] += 1

edit_distance_from_closest_consensus_per_region = update_edit_distance_from_closest_consensus_per_region(edit_distance_from_closest_consensus_per_region, closest_hap, region, ed_from_hap)
number_of_haplotypes_per_region = update_number_of_haplotypes_per_region(number_of_haplotypes_per_region, closest_hap, region)
abundance_per_region = update_abundance_per_region(abundance_per_region, closest_hap, region, rel_ab)

if ed_from_hap == 0:
# if the edit distance is 0, the haplotype is exactly the same as the closest consensus
number_of_exact_haplotypes_per_region = update_number_of_haplotypes_per_region(number_of_haplotypes_per_region, closest_hap, region)


# normalize the abundance per region to account for the discarded haplotypes
abundance_per_region = normalize_abundance_per_region(abundance_per_region)

# calculate summary statistics for the whole sample
average_edit_distance = calculate_average_edit_distance(edit_distance_from_closest_consensus_per_region)
average_number_of_haplotypes = calculate_average_number_of_haplotypes(number_of_haplotypes_per_region)
recall_omicron, recall_wuhan = calculate_recall(number_of_haplotypes_per_region, true_n_haps_per_sample_region,args.sample_name.split("-")[0])
recall = round((recall_omicron + recall_wuhan) / 2, 2)
duplication_ratio = calculate_duplication_ratio(number_of_haplotypes_per_region, true_n_haps_per_sample_region, args.sample_name.split("-")[0])
recall_wuhan, recall_omicron = calculate_recall(number_of_haplotypes_per_region, true_n_haps_per_sample_region,args.sample_name.split("-")[0])
recall = round((recall_omicron + recall_wuhan) / 2,3)
precision_omicron, precision_wuhan = calculate_precision(number_of_haplotypes_per_region, true_n_haps_per_sample_region, args.sample_name.split("-")[0])
precision = round((precision_omicron + precision_wuhan) / 2, 3)
f1_score = calculate_f1_score(precision, recall)
# duplication_ratio = calculate_duplication_ratio(number_of_haplotypes_per_region, true_n_haps_per_sample_region, args.sample_name.split("-")[0])
average_number_of_haplotypes_discard = sum(discarded_haplotypes_per_region.values()) / len(discarded_haplotypes_per_region.keys())
avg_rel_abs_ab_error = calculate_relative_absolute_abundance_error(abundance_per_region, true_abs_per_hap_per_sample_region, args.sample_name.split("-")[0])

# if the output file does not exist, create it
if not os.path.exists(args.output):
with open(args.output, 'w') as f:
f.write("sample_name\taverage_edit_distance\taverage_number_of_haplotypes\trecall\trecall_wuhan\trecall_omicron\tduplication_ratio\tavg_rel_abs_ab_error\n")
f.write("sample_name\tf1_score\trecall\tprecision\taverage_edit_distance\tavg_num_haps_in_discard\tavg_rel_abs_ab_error\n")

# write the summary statistics to the output file
with open(args.output, 'a') as f:
f.write(f"{args.sample_name}\t{average_edit_distance}\t{average_number_of_haplotypes}\t{recall}\t{recall_wuhan}\t{recall_omicron}\t{duplication_ratio}\t{avg_rel_abs_ab_error}\n")
f.write(f"{args.sample_name}\t{f1_score}\t{recall}\t{precision}\t{average_edit_distance}\t{average_number_of_haplotypes_discard}\t{avg_rel_abs_ab_error}\n")

if __name__ == '__main__':
main()
Expand Down
110 changes: 106 additions & 4 deletions src/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def calculate_average_number_of_haplotypes(number_of_haplotypes_per_region:dict)

return average_number_of_haplotypes


def calculate_recall(number_of_haplotypes_per_region:dict, true_number_haps_per_sample_region:dict, sample_name:str) -> Tuple[float, float]:
"""
Calculates the recall for Wuhan and Omicron haplotypes. We define as recall the ratio of the number of true reconstructed haplotypes to the number of true haplotypes.
Expand Down Expand Up @@ -85,6 +84,59 @@ def calculate_recall(number_of_haplotypes_per_region:dict, true_number_haps_per_
return recall_wuhan, recall_omicron


def calculate_precision(number_of_exact_haplotypes_per_region:dict, number_of_haplotypes_per_region:dict) -> Tuple[float, float]:
"""
Calculates the precision for Wuhan and Omicron haplotypes. How many of the reconstructed haplotypes assigned to Wuhan and Omicron are exact matches to the true haplotypes?
Args:
number_of_exact_haplotypes_per_region: dict, number of exact haplotypes per region dictionary
number_of_haplotypes_per_region: dict, number of haplotypes per region dictionary
Returns:
precision_wuhan: float, precision for Wuhan haplotypes
precision_omicron: float, precision for Omicron haplotypes
"""

precision_wuhan = []
precision_omicron = []

for region in number_of_exact_haplotypes_per_region:
if number_of_haplotypes_per_region[region]['Wuhan'] == 0:
precision_wuhan_region = 1
else:
precision_wuhan_region = number_of_exact_haplotypes_per_region[region]['Wuhan'] / number_of_haplotypes_per_region[region]['Wuhan']
if number_of_haplotypes_per_region[region]['Omicron'] == 0:
precision_omicron_region = 1
else:
precision_omicron_region = number_of_exact_haplotypes_per_region[region]['Omicron'] / number_of_haplotypes_per_region[region]['Omicron']

precision_wuhan.append(precision_wuhan_region)
precision_omicron.append(precision_omicron_region)

precision_wuhan = round(sum(precision_wuhan) / len(precision_wuhan), 2)
precision_omicron = round(sum(precision_omicron) / len(precision_omicron), 2)

return precision_wuhan, precision_omicron


def calculate_f1_score(precision:float, recall:float) -> float:
"""
Calculates the F1 score. The harmonic mean of precision and recall.
Args:
precision: float, precision
recall: float, recall
Returns:
f1_score: float, F1 score
"""

f1_score = round(2 * (precision * recall) / (precision + recall + 1e-6), 3)

return f1_score



def calculate_duplication_ratio(number_of_haplotypes_per_region:dict, true_number_haps_per_sample_region:dict, sample_name:str) -> float:
"""
Calculates the duplication ratio. The ratio of the number of reconstructed haplotypes to the number of true haplotypes.
Expand Down Expand Up @@ -180,9 +232,20 @@ def closest_haplotype(seq:str, rel_ab:float, omicron_amplicon:str, wuhan_amplico
return 'Omicron', omicron_distance


def edit_distance_on_overlap(seq1:str, seq2:str) -> Tuple[str, str]:
# align the two sequences using mafft
def find_overlapping_region(seq1:str, seq2:str) -> Tuple[str, str]:
"""
Finds the overlapping region between two sequences
Args:
seq1: str, input sequence 1
seq2: str, input sequence 2
Returns:
overlap_seq1: str, overlapping region of sequence 1
overlap_seq2: str, overlapping region of sequence 2
"""

# align the two sequences using mafft
with open("temp.fasta", "w") as f:
f.write(">seq1\n{}\n".format(seq1))
f.write(">seq2\n{}\n".format(seq2))
Expand Down Expand Up @@ -221,6 +284,45 @@ def edit_distance_on_overlap(seq1:str, seq2:str) -> Tuple[str, str]:
overlap_seq1 = aligned_seq1[start:end].replace("-", "").strip().strip("N")
overlap_seq2 = aligned_seq2[start:end].replace("-", "").strip().strip("N")

return overlap_seq1, overlap_seq2


def edit_distance_on_overlap(seq1:str, seq2:str) -> Tuple[str, str]:
"""
Calculates the edit distance between two sequences on the overlapping region
Args:
seq1: str, input sequence 1
seq2: str, input sequence 2
Returns:
edit_distance: int, edit distance between the two sequences
"""
overlap_seq1, overlap_seq2 = find_overlapping_region(seq1, seq2)

edit_distance = editdistance.eval(overlap_seq1, overlap_seq2)

return edit_distance
return edit_distance


def normalised_edit_distance_on_overlap(seq1:str, seq2:str) -> Tuple[str, str]:
"""
Calculates the normalised edit distance between two sequences on the overlapping region
Args:
seq1: str, input sequence 1
seq2: str, input sequence 2
Returns:
normalised_edit_distance: float, normalised edit distance between the two sequences
"""

overlap_seq1, overlap_seq2 = find_overlapping_region(seq1, seq2)

edit_distance = editdistance.eval(overlap_seq1, overlap_seq2)

normalised_edit_distance = round(edit_distance / len(overlap_seq1), 3)

return normalised_edit_distance


0 comments on commit b013c50

Please sign in to comment.