Skip to content

Commit

Permalink
fix: ref_dist calculated using wf_ed, not variants.
Browse files Browse the repository at this point in the history
  • Loading branch information
TimD1 committed Nov 18, 2024
1 parent 237d0e7 commit a69d0bd
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 22 deletions.
166 changes: 145 additions & 21 deletions src/dist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ void calc_prec_recall(
graph->qseqs[graph->qnodes-1].length()-1,
graph->tseqs[graph->tnodes-1].length()-1);
idx4 curr = end;
int prev_query_ref_pos = graph->ref.size();
int prev_truth_pos = graph->truth.size();
int sync_group = 0;
int ref_dist = 0;
int query_dist = 0;
std::vector<int> sync_tvars;
std::vector<int> sync_qvars;
Expand Down Expand Up @@ -288,22 +289,6 @@ void calc_prec_recall(
if (print) printf("new truth variant\n");
int tvar_idx = graph->tidxs[curr.tni];
sync_tvars.push_back(tvar_idx);
switch (graph->ttypes[curr.tni]) {
case TYPE_INS:
ref_dist += tvars->alts[tvar_idx].size();
break;
case TYPE_DEL:
ref_dist += tvars->refs[tvar_idx].size();
break;
case TYPE_SUB:
ref_dist += 1;
break;
default:
ERROR("Unexpected truth var type '%s' in calc_prec_recall() at %s:%d",
type_strs[graph->ttypes[curr.tni]].data(),
tvars->ctg.data(), tvars->poss[tvar_idx]);
break;
}
}
// if the alignment is a substitution, insertion, or deletion
if (prev.qni == curr.qni && prev.tni == curr.tni && // same matrix
Expand All @@ -315,7 +300,9 @@ void calc_prec_recall(
}

// check if this movement is a sync point (ref main diag mvmt or between var main diag mvmt)
bool on_main_diag = graph->tbegs[curr.tni] + curr.ti == graph->qbegs[curr.qni] + curr.qi;
int truth_ref_pos = graph->tbegs[curr.tni] + curr.ti;
int query_ref_pos = graph->qbegs[curr.qni] + curr.qi;
bool on_main_diag = (truth_ref_pos == query_ref_pos);
bool ref_query_move = graph->qtypes[curr.qni] == TYPE_REF
&& prev.qni == curr.qni
&& prev.qi+1 == curr.qi;
Expand All @@ -331,6 +318,14 @@ void calc_prec_recall(

// add sync point
if (sync_tvars.size() || sync_qvars.size()) {

// calculate edit distance without variants
int ref_dist = 0;
std::vector< std::vector<int> > ed_offs, ed_ptrs;
int truth_pos = graph->get_truth_pos(curr.tni, curr.ti);
wf_ed(graph->ref.substr(query_ref_pos, prev_query_ref_pos - query_ref_pos),
graph->truth.substr(truth_pos, prev_truth_pos - truth_pos),
ref_dist, ed_offs, ed_ptrs);
if (print) printf("syncing\n");
float credit = 0;
if (ref_dist == 0) {
Expand Down Expand Up @@ -369,12 +364,13 @@ void calc_prec_recall(
tvars->credit[truth_hap][tvar_idx] = credit;
}

// reset
// reset sync group
sync_group++;
sync_qvars.clear();
sync_tvars.clear();
prev_query_ref_pos = query_ref_pos;
prev_truth_pos = truth_pos;
query_dist = 0;
ref_dist = 0;
}
}
curr = prev;
Expand Down Expand Up @@ -607,6 +603,7 @@ void precision_recall_threads_wrapper(

/******************************************************************************/


void Graph::print() {
printf("QUERY:\n");
for (int qn = 0; qn < this->qnodes; qn++) {
Expand All @@ -630,8 +627,22 @@ void Graph::print() {
}
}


/******************************************************************************/


int Graph::get_truth_pos(int truth_node_idx, int truth_idx) {
int truth_pos = 0;
for (int tn = 0; tn < truth_node_idx; tn++) {
truth_pos += this->tseqs[tn].size() - 1; // each tseq starts with '_'
}
truth_pos += truth_idx;
return truth_pos;
}

/******************************************************************************/


/* Build a graph containing all possible paths through query, and selected truth path
NOTE: start off simple O(V*V) and make sure it's correct (worry about efficiency later)
*/
Expand Down Expand Up @@ -747,6 +758,7 @@ Graph::Graph(

// iterate through all the variants
this->tnodes = 0;
this->truth = "";
for (int var_idx = tvar_beg; var_idx < tvar_end; var_idx++) {

if (tvars->var_on_hap(var_idx, truth_hap)) { // ignore variant if on other hap
Expand All @@ -758,6 +770,7 @@ Graph::Graph(
this->tbegs.push_back(ref_pos - ref_beg);
this->ttypes.push_back(TYPE_REF);
this->tidxs.push_back(-1);
this->truth += ref->fasta.at(ctg).substr(ref_pos, var_pos-ref_pos);
}

// add the truth variant
Expand All @@ -766,16 +779,18 @@ Graph::Graph(
this->tbegs.push_back(var_pos - ref_beg);
this->ttypes.push_back(tvars->types[var_idx]);
this->tidxs.push_back(var_idx);
this->truth += tvars->alts[var_idx];
ref_pos = tvars->poss[var_idx] + tvars->rlens[var_idx];
}
}

// add the remainder of the truth
this->tnodes++;
this->tseqs.push_back("_" + ref->fasta.at(ctg).substr(ref_pos, ref_end+1 -ref_pos));
this->tseqs.push_back("_" + ref->fasta.at(ctg).substr(ref_pos, ref_end+1 - ref_pos));
this->tbegs.push_back(ref_pos - ref_beg);
this->ttypes.push_back(TYPE_REF);
this->tidxs.push_back(-1);
this->truth += ref->fasta.at(ctg).substr(ref_pos, ref_end+1 - ref_pos);

/////////////////////////////////////
// STEP 4: SET TRUTH NODE POINTERS //
Expand All @@ -790,6 +805,8 @@ Graph::Graph(
this->tprevs[tni+1].push_back(tni);
}
}

this->ref = ref->fasta.at(ctg).substr(ref_beg, ref_end+1 - ref_beg);
}


Expand Down Expand Up @@ -1416,3 +1433,110 @@ std::vector<int> wf_swg_backtrack(
}
return cigar;
}


/******************************************************************************/


void wf_ed(
const std::string & query, const std::string & truth, int & s,
std::vector< std::vector<int> > & offs,
std::vector< std::vector<int> > & ptrs, bool print
) {

// alignment
int query_len = query.size();
int truth_len = truth.size();

// early exit if either string is empty
if (!query_len) { s = truth_len; return; }
if (!truth_len) { s = query_len; return; }
s = 0;

int mat_len = query_len + truth_len - 1;
offs.push_back(std::vector<int>(mat_len, -2));
offs[0][query_len-1] = -1;
ptrs.push_back(std::vector<int>(mat_len, PTR_NONE));
bool done = false;
while (true) {

// EXTEND WAVEFRONT
for (int d = 0; d < mat_len; d++) {
int off = offs[s][d];
int diag = d + 1 - query_len;

// don't allow starting from untouched cells
if (off == -2) continue;

// check that it's within matrix
if (diag + off + 1 < 0) continue;
if (off > query_len - 1) continue;
if (diag + off > truth_len - 1) continue;

// extend
while (off < query_len - 1 &&
diag + off < truth_len - 1) {
if (query[off+1] == truth[diag+off+1]) off++;
else break;
}
if (off > offs[s][d])
if(print) printf("(%d, %d) extend\n", off, off+diag);
offs[s][d] = off;

// finish if done
if (off == query_len - 1 && off + diag == truth_len - 1)
{ done = true; break; }

}
if (done) break;

// debug print
if(print) printf("offs %d:", s);
for (int di = 0; di < int(query.size() + truth.size()-1); di++) {
if(print) printf("\t%d", offs[s][di]);
}
if(print) printf("\n");

// NEXT WAVEFRONT
offs.push_back(std::vector<int>(mat_len, -2));
ptrs.push_back(std::vector<int>(mat_len, PTR_NONE));
s++;
if(print) printf("\nscore = %d\n", s);
for (int d = 0; d < mat_len; d++) {
int diag = d + 1 - query_len;

// SUB
if (s-1 >= 0 && offs[s-1][d] != -2 &&
offs[s-1][d]+1 < query_len &&
diag + offs[s-1][d]+1 < truth_len &&
offs[s-1][d]+1 >= offs[s][d]) {
offs[s][d] = offs[s-1][d] + 1;
ptrs[s][d] |= PTR_SUB;
if(print) printf("(%d, %d) sub\n", offs[s][d], offs[s][d]+diag);
}

// DEL
if (s-1 >= 0 && d > 0 &&
offs[s-1][d-1] != -2 &&
diag + offs[s-1][d-1] < truth_len &&
offs[s-1][d-1] >= offs[s][d]) {
offs[s][d] = offs[s-1][d-1];
ptrs[s][d] |= PTR_DEL;
if(print) printf("(%d, %d) del\n", offs[s][d], offs[s][d]+diag);
}

// INS
if (s-1 >= 0 && d < mat_len-1 &&
offs[s-1][d+1] != -2 &&
offs[s-1][d+1]+1 < query_len &&
diag + offs[s-1][d+1]+1 < truth_len &&
diag + offs[s-1][d+1]+1 >= -1 &&
offs[s-1][d+1]+1 >= offs[s][d]) {
offs[s][d] = offs[s-1][d+1]+1;
ptrs[s][d] |= PTR_INS;
if(print) printf("(%d, %d) ins\n", offs[s][d], offs[s][d]+diag);
}
}
}
}

11 changes: 10 additions & 1 deletion src/dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ class Graph {

// supercluster data used to generate this graph
// query variant data is at sc->callset_vars[QUERY]->{fieldname}[this->qidxs[qni]]
// query variant data is at sc->callset_vars[TRUTH]->{fieldname}[this->tidxs[tni]]
// truth variant data is at sc->callset_vars[TRUTH]->{fieldname}[this->tidxs[tni]]
std::shared_ptr<ctgSuperclusters> sc;
int sc_idx;
std::string ref; // for calculating original edit distance
std::string truth;

// graph data for each node
int qnodes; // each qvector is of size qnodes
Expand Down Expand Up @@ -44,6 +46,7 @@ class Graph {

// methods
void print();
int get_truth_pos(int truth_node_idx, int truth_idx);
};

/******************************************************************************/
Expand Down Expand Up @@ -198,6 +201,12 @@ std::vector<int> wf_swg_backtrack(
const std::vector< std::vector< std::vector<int> > > & offs,
int s, int sub, int open, int extend, bool print = false);


void wf_ed(
const std::string & query, const std::string & truth, int & s,
std::vector< std::vector<int> > & offs,
std::vector< std::vector<int> > & ptrs, bool print = false);

/******************************************************************************/

void precision_recall_threads_wrapper(
Expand Down

0 comments on commit a69d0bd

Please sign in to comment.