Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 123 additions & 11 deletions include/valik/search/local_prefilter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,40 @@
namespace valik
{

/**
 * @brief Function that samples patterns on a query and averages their threshold corrections.
 *
 * Sampling normally starts at position `pattern_size` and advances in steps of
 * `query_every * pattern_size`. If the query is too short for that first sample to fit,
 * sampling starts at position 0 so that at least one pattern is always evaluated.
 *
 * @param read_len Length of query. Must be at least `pattern_size`.
 * @param pattern_size Length of pattern.
 * @param query_every Every nth potential match is considered.
 * @param callback Functor that is invoked with a pattern begin position and returns the threshold
 *                 correction for that pattern. Corrections that are not positive are ignored.
 * @return 1 plus the mean of all positive corrections; exactly 1 if no sampled pattern needed a correction.
 */
template <typename functor_t>
constexpr double sample_begin_positions(size_t const read_len, uint64_t const pattern_size, uint8_t const query_every, functor_t && callback)
{
    assert(read_len >= pattern_size);

    // last position at which a whole pattern still fits into the query
    size_t const last_begin = read_len - pattern_size;

    size_t first_pos{pattern_size};
    // start from beginning when short query; this also covers queries shorter than two patterns,
    // for which the default first position would lie past the last valid begin and nothing would be sampled
    if (read_len < query_every + pattern_size || first_pos > last_begin)
        first_pos = 0;

    // a zero step would never terminate, so always advance by at least one position
    size_t const step = std::max(static_cast<size_t>(query_every * pattern_size), static_cast<size_t>(1));

    size_t corrected_pattern_count{0u};
    double total_correction{0};
    for (size_t pos = first_pos; pos <= last_begin; pos += step)
    {
        auto const correction = callback(pos);
        if (correction > 0)
        {
            ++corrected_pattern_count;
            total_correction += correction;
        }
    }

    // average only over the patterns that needed a correction; the max() avoids dividing by zero
    return 1 + total_correction / static_cast<double>(std::max(corrected_pattern_count, static_cast<size_t>(1)));
}


/**
* @brief Function that finds the begin positions of all patterns of a query.
*
Expand All @@ -33,10 +67,10 @@ constexpr void pattern_begin_positions(size_t const read_len, uint64_t const pat
assert(read_len >= pattern_size);

size_t last_begin{0u};
for (size_t i = 0; i <= read_len - pattern_size; i = i + query_every)
for (size_t pos = 0; pos <= read_len - pattern_size; pos = pos + query_every)
{
callback(i);
last_begin = i;
callback(pos);
last_begin = pos;
}

if (last_begin < read_len - pattern_size)
Expand All @@ -54,6 +88,11 @@ struct pattern_bounds
size_t begin_position;
size_t end_position;
size_t threshold;

size_t minimiser_count() const
{
return end_position - begin_position;
}
};

/**
Expand Down Expand Up @@ -89,35 +128,96 @@ pattern_bounds make_pattern_bounds(size_t const & begin,
assert(end_it != window_span_begin.begin());
pattern.end_position = end_it - window_span_begin.begin();

size_t const minimiser_count = pattern.end_position - pattern.begin_position;

pattern.threshold = thresholder.get(minimiser_count);
pattern.threshold = thresholder.get(pattern.minimiser_count());
return pattern;
}

/**
 * @brief Function that for a single pattern counts matching k-mers and corrects the threshold to avoid too many spuriously matching bins.
 *
 * The threshold is raised step by step until fewer than max(4, round(bin_count / 4)) bins reach it,
 * or until the raised threshold is at least the number of minimisers in the pattern.
 *
 * @param pattern Slice of a query record that is being considered.
 * @param bin_count Number of bins in the IBF.
 * @param counting_table Rows: minimisers of the query. Columns: bins of the IBF.
 * @return Threshold correction relative to the pattern threshold (added count / threshold);
 *         0 if no correction was needed or if the pattern threshold is 0.
 */
template <typename binning_bitvector_t>
double find_dynamic_threshold_correction(pattern_bounds const & pattern,
                                         size_t const & bin_count,
                                         binning_bitvector_t const & counting_table)
{
    // a relative correction is undefined for a zero threshold (the division below would yield NaN or inf)
    if (pattern.threshold == 0)
        return 0.0;

    // counting vector for the current pattern
    seqan3::counting_vector<uint8_t> total_counts(bin_count, 0);

    for (size_t i = pattern.begin_position; i < pattern.end_position; i++)
        total_counts += counting_table[i];

    // a threshold is considered ubiquitous if at least this many bins reach it
    size_t const ubiquity_limit = std::max(static_cast<size_t>(4), static_cast<size_t>(std::round(bin_count / 4.0)));

    // size_t instead of uint8_t: the increment grows with the accumulated correction and must not wrap around
    size_t correction_count{0};
    while (true)
    {
        size_t const corrected_threshold = pattern.threshold + correction_count;

        // only the number of hits matters here, so count them instead of collecting bin ids
        size_t hit_count{0};
        for (size_t current_bin = 0; current_bin < total_counts.size(); current_bin++)
        {
            if (total_counts[current_bin] >= corrected_threshold)
                ++hit_count;
        }

        // stop once the threshold cannot be raised any further
        bool const max_threshold = corrected_threshold >= pattern.minimiser_count();
        if (hit_count < ubiquity_limit || max_threshold)
            break;

        // increase threshold by at least 1 to find lowest threshold that is not ubiquitous;
        // the step is round(10% of the threshold * correction accumulated so far)
        correction_count += std::max(static_cast<size_t>(1),
                                     static_cast<size_t>(std::round(pattern.threshold * 0.1 * correction_count)));
    }

    return static_cast<double>(correction_count) / static_cast<double>(pattern.threshold);
}


/**
* @brief Function that for a single pattern counts matching k-mers and returns bins that exceed the threshold.
*
* @param pattern Slice of a query record that is being considered.
* @param correction_coef Threshold correction coefficient determined from a sample of patterns; the pattern threshold is multiplied by it.
* @param bin_count Number of bins in the IBF.
* @param counting_table Rows: minimisers of the query. Columns: bins of the IBF.
* @param sequence_hits Bins that likely contain a match for the pattern (IN-OUT parameter).
*/
template <typename binning_bitvector_t>
void find_pattern_bins(pattern_bounds const & pattern,
size_t const & bin_count,
binning_bitvector_t const & counting_table,
std::unordered_set<size_t> & sequence_hits)
double const & correction_coef,
size_t const & bin_count,
binning_bitvector_t const & counting_table,
std::unordered_set<size_t> & sequence_hits)
{
// counting vector for the current pattern
seqan3::counting_vector<uint8_t> total_counts(bin_count, 0);

for (size_t i = pattern.begin_position; i < pattern.end_position; i++)
total_counts += counting_table[i];

for (size_t current_bin = 0; current_bin < total_counts.size(); current_bin++)
{
auto &&count = total_counts[current_bin];
if (count >= pattern.threshold)
/*
if (current_bin == 0)
{
if (std::round(pattern.threshold * correction_coef) > pattern.threshold)
{
seqan3::debug_stream << "Threshold was " << pattern.threshold << '\n';
seqan3::debug_stream << "New threshold " << std::to_string((size_t) std::round(pattern.threshold * correction_coef)) << '\n';
}
}
*/
if (count >= (pattern.threshold * correction_coef))
{
// the result is a union of results from all patterns of a read
sequence_hits.insert(current_bin);
Expand Down Expand Up @@ -199,10 +299,22 @@ void local_prefilter(
minimiser.clear();

std::unordered_set<size_t> sequence_hits{};
double threshold_correction{1};
if (!arguments.static_threshold)
{
threshold_correction = sample_begin_positions(seq.size(), arguments.pattern_size, arguments.query_every, [&](size_t const begin) -> double
{
pattern_bounds const pattern = make_pattern_bounds(begin, arguments, window_span_begin, thresholder);
return find_dynamic_threshold_correction(pattern, bin_count, counting_table);
});
}

if (threshold_correction > 1.0000001)
seqan3::debug_stream << "Correct threshold by " << threshold_correction << '\n';
pattern_begin_positions(seq.size(), arguments.pattern_size, arguments.query_every, [&](size_t const begin)
{
pattern_bounds const pattern = make_pattern_bounds(begin, arguments, window_span_begin, thresholder);
find_pattern_bins(pattern, bin_count, counting_table, sequence_hits);
find_pattern_bins(pattern, threshold_correction, bin_count, counting_table, sequence_hits);
});

result_cb(record, sequence_hits);
Expand Down
5 changes: 4 additions & 1 deletion include/valik/search/search_local.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ bool search_local(search_arguments & arguments, search_time_statistics & time_st
std::cout.precision(3);

std::cout << "\n-----------Search parameters-----------\n";
std::cout << "kmer size " << std::to_string(arguments.shape_size) << '\n';
if (arguments.shape_size == arguments.shape_weight)
std::cout << "kmer size " << std::to_string(arguments.shape_size) << '\n';
else
std::cout << "kmer shape " << arguments.shape.to_string() << '\n';
std::cout << "window size " << std::to_string(arguments.window_size) << '\n';
switch (arguments.search_type)
{
Expand Down
1 change: 1 addition & 0 deletions include/valik/shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ struct search_arguments final : public minimiser_threshold_arguments, search_pro
bool keep_best_repeats{false};
double best_bin_entropy_cutoff{0.25};
bool keep_all_repeats{false};
bool static_threshold{false};
bool stellar_only{false};

size_t cart_max_capacity{1000};
Expand Down
7 changes: 6 additions & 1 deletion src/argument_parsing/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ void init_search_parser(sharg::parser & parser, search_arguments & arguments)
.long_id = "keep-all-repeats",
.description = "Do not filter out query matches from repeat regions. This may significantly increase the runtime.",
.advanced = true});
parser.add_flag(arguments.static_threshold,
sharg::config{.short_id = '\0',
.long_id = "static-threshold",
.description = "Do not correct threshold to avoid many spuriously matching bins.",
.advanced = true});
parser.add_option(arguments.seg_count_in,
sharg::config{.short_id = 'n',
.long_id = "seg-count",
Expand Down Expand Up @@ -346,7 +351,7 @@ void run_search(sharg::parser & parser)
{
arguments.search_type = search_kind::LEMMA;
if (arguments.threshold < lemma_thresh)
std::cerr << "[Warning] chosen threshold is less than the k-mer lemma threshold. Ignore this warning if this was deliberate.";
std::cerr << "[Warning] The chosen threshold is less than the k-mer lemma threshold. Ignore this warning if this was deliberate.";
}
}
if (arguments.stellar_only)
Expand Down
4 changes: 2 additions & 2 deletions test/cli/valik_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ TEST_P(valik_search_clusters, search)
"--error-rate ", std::to_string(error_rate),
"--index ", ibf_path(number_of_bins, window_size),
"--query ", data("query.fq"),
"--threads 1", "--very-verbose",
"--threads 1", "--very-verbose", "--static-threshold",
"--cart-max-capacity 3",
"--max-queued-carts 10",
"--without-parameter-tuning");
Expand Down Expand Up @@ -399,7 +399,7 @@ TEST_P(valik_search_segments, search)
"--error-rate ", std::to_string(error_rate),
"--index ", ibf_path(segment_overlap, number_of_bins, window_size),
"--query ", data("single_query.fasta"),
"--threads 1", "--very-verbose",
"--threads 1", "--very-verbose", "--static-threshold",
"--ref-meta", segment_metadata_path(segment_overlap, number_of_bins),
"--cart-max-capacity 3",
"--max-queued-carts 10",
Expand Down
Loading