Skip to content

Commit

Permalink
Fix for queries with invalid characters (#502)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Oleksandr Kulkov <[email protected]>
  • Loading branch information
hmusta and adamant-pwn authored Oct 8, 2024
1 parent f2f320f commit 5af4aca
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 27 deletions.
3 changes: 1 addition & 2 deletions metagraph/src/common/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ namespace utils {
size_t segment_length) {
std::vector<bool> mask(array.size(), false);
size_t last_occurrence
= std::find(array.data(), array.data() + array.size(), label)
- array.data();
= std::find(array.begin(), array.end(), label) - array.begin();

for (size_t i = last_occurrence; i < array.size(); ++i) {
if (array[i] == label)
Expand Down
6 changes: 3 additions & 3 deletions metagraph/src/graph/representation/canonical_dbg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ ::map_to_nodes_sequentially(std::string_view sequence,
path.reserve(sequence.size() - get_k() + 1);

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->map_to_nodes_with_rc<>(sequence, [&](node_index node, bool orientation) {
sshash->map_to_nodes_with_rc<true>(sequence, [&](node_index node, bool orientation) {
callback(node && orientation ? reverse_complement(node) : node);
}, terminate);
return;
Expand Down Expand Up @@ -180,7 +180,7 @@ void CanonicalDBG::call_outgoing_kmers(node_index node,
}

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->call_outgoing_kmers_with_rc<>(node, [&](node_index next, char c, bool orientation) {
sshash->call_outgoing_kmers_with_rc<true>(node, [&](node_index next, char c, bool orientation) {
callback(orientation ? reverse_complement(next) : next, c);
});
return;
Expand Down Expand Up @@ -273,7 +273,7 @@ void CanonicalDBG::call_incoming_kmers(node_index node,
}

if (const auto sshash = std::dynamic_pointer_cast<const DBGSSHash>(graph_)) {
sshash->call_incoming_kmers_with_rc<>(node, [&](node_index prev, char c, bool orientation) {
sshash->call_incoming_kmers_with_rc<true>(node, [&](node_index prev, char c, bool orientation) {
callback(orientation ? reverse_complement(prev) : prev, c);
});
return;
Expand Down
54 changes: 38 additions & 16 deletions metagraph/src/graph/representation/hash/dbg_sshash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "common/seq_tools/reverse_complement.hpp"
#include "common/threads/threading.hpp"
#include "common/logger.hpp"
#include "common/algorithms.hpp"
#include "kmer/kmer_extractor.hpp"


Expand Down Expand Up @@ -99,32 +100,53 @@ void DBGSSHash::add_sequence(std::string_view sequence,
throw std::logic_error("adding sequences not supported");
}

template <bool with_rc>
void DBGSSHash::map_to_nodes_with_rc(std::string_view sequence,
const std::function<void(node_index, bool)>& callback,
const std::function<bool()>& terminate) const {
if (terminate() || sequence.size() < k_)
template <bool with_rc, class Dict>
void map_to_nodes_with_rc_impl(size_t k,
const Dict &dict,
std::string_view sequence,
const std::function<void(sshash::lookup_result)>& callback,
const std::function<bool()>& terminate) {
size_t n = sequence.size();
if (terminate() || n < k)
return;

if (!num_nodes()) {
for (size_t i = 0; i < sequence.size() - k_ + 1 && !terminate(); ++i) {
callback(npos, false);
if (!dict.size()) {
for (size_t i = 0; i + k <= sequence.size() && !terminate(); ++i) {
callback(sshash::lookup_result());
}
return;
}

using kmer_t = get_kmer_t<Dict>;

std::vector<bool> invalid_char(n);
for (size_t i = 0; i < n; ++i) {
invalid_char[i] = !kmer_t::is_valid(sequence[i]);
}

auto invalid_kmer = utils::drag_and_mark_segments(invalid_char, true, k);

kmer_t uint_kmer = sshash::util::string_to_uint_kmer<kmer_t>(sequence.data(), k - 1);
uint_kmer.pad_char();
for (size_t i = k - 1; i < n && !terminate(); ++i) {
uint_kmer.drop_char();
uint_kmer.kth_char_or(k - 1, kmer_t::char_to_uint(sequence[i]));
callback(invalid_kmer[i] ? sshash::lookup_result()
: dict.lookup_advanced_uint(uint_kmer, with_rc));
}
}

template <bool with_rc>
void DBGSSHash::map_to_nodes_with_rc(std::string_view sequence,
const std::function<void(node_index, bool)>& callback,
const std::function<bool()>& terminate) const {
std::visit([&](const auto &dict) {
using kmer_t = get_kmer_t<decltype(dict)>;
kmer_t uint_kmer = sshash::util::string_to_uint_kmer<kmer_t>(sequence.data(), k_ - 1);
uint_kmer.pad_char();
for (size_t i = k_ - 1; i < sequence.size() && !terminate(); ++i) {
uint_kmer.drop_char();
uint_kmer.kth_char_or(k_ - 1, kmer_t::char_to_uint(sequence[i]));
auto res = dict.lookup_advanced_uint(uint_kmer, with_rc);
map_to_nodes_with_rc_impl<with_rc>(k_, dict, sequence, [&](sshash::lookup_result res) {
callback(sshash_to_graph_index(res.kmer_id), res.kmer_orientation);
}
}, terminate);
}, dict_);
}

template
void DBGSSHash::map_to_nodes_with_rc<true>(std::string_view,
const std::function<void(node_index, bool)>&,
Expand Down
1 change: 1 addition & 0 deletions metagraph/tests/annotation/test_aligner_labeled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class LabeledAlignerTest : public ::testing::Test {};

typedef ::testing::Types<std::pair<DBGHashFast, annot::ColumnCompressed<>>,
std::pair<DBGSuccinct, annot::ColumnCompressed<>>,
std::pair<DBGSSHash, annot::ColumnCompressed<>>,
std::pair<DBGHashFast, annot::RowFlatAnnotator>,
std::pair<DBGSuccinct, annot::RowFlatAnnotator>,
std::pair<DBGSuccinct, annot::RowDiffColumnAnnotator>> FewGraphAnnotationPairTypes;
Expand Down
7 changes: 3 additions & 4 deletions metagraph/tests/annotation/test_annotated_dbg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
#include "gtest/gtest.h"

#include "../test_helpers.hpp"
#include "../graph/all/test_dbg_helpers.hpp"

#include "common/threads/threading.hpp"
#include "common/vectors/bit_vector_dyn.hpp"
#include "common/vectors/vector_algorithm.hpp"
#include "annotation/representation/column_compressed/annotate_column_compressed.hpp"
#include "graph/representation/bitmap/dbg_bitmap.hpp"
#include "graph/representation/hash/dbg_hash_string.hpp"
#include "graph/representation/hash/dbg_hash_ordered.hpp"
#include "graph/representation/hash/dbg_hash_fast.hpp"

#define protected public
#define private public
Expand Down Expand Up @@ -987,6 +984,7 @@ typedef ::testing::Types<std::pair<DBGBitmap, annot::ColumnCompressed<>>,
std::pair<DBGHashOrdered, annot::ColumnCompressed<>>,
std::pair<DBGHashFast, annot::ColumnCompressed<>>,
std::pair<DBGSuccinct, annot::ColumnCompressed<>>,
std::pair<DBGSSHash, annot::ColumnCompressed<>>,
std::pair<DBGBitmap, annot::RowFlatAnnotator>,
std::pair<DBGHashString, annot::RowFlatAnnotator>,
std::pair<DBGHashOrdered, annot::RowFlatAnnotator>,
Expand Down Expand Up @@ -1016,6 +1014,7 @@ class AnnotatedDBGNoNTest : public ::testing::Test {};
typedef ::testing::Types<std::pair<DBGBitmap, annot::ColumnCompressed<>>,
std::pair<DBGHashOrdered, annot::ColumnCompressed<>>,
std::pair<DBGHashFast, annot::ColumnCompressed<>>,
std::pair<DBGSSHash, annot::ColumnCompressed<>>,
std::pair<DBGBitmap, annot::RowFlatAnnotator>,
std::pair<DBGHashOrdered, annot::RowFlatAnnotator>,
std::pair<DBGHashFast, annot::RowFlatAnnotator>,
Expand Down
1 change: 1 addition & 0 deletions metagraph/tests/annotation/test_annotated_dbg_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGBitmap, ColumnCompres
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGHashOrdered, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGHashFast, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGHashString, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGSSHash, ColumnCompressed<>>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);

template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGSuccinct, RowFlatAnnotator>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
template std::unique_ptr<AnnotatedDBG> build_anno_graph<DBGBitmap, RowFlatAnnotator>(uint64_t, const std::vector<std::string> &, const std::vector<std::string>&, DeBruijnGraph::Mode, bool);
Expand Down
5 changes: 3 additions & 2 deletions metagraph/tests/graph/all/test_dbg_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ void writeFastaFile(const std::vector<std::string>& sequences, const std::string

fastaFile.close();
}

template <>
std::shared_ptr<DeBruijnGraph>
build_graph<DBGSSHash>(uint64_t k,
Expand All @@ -154,8 +155,8 @@ build_graph<DBGSSHash>(uint64_t k,
if (sequences.empty())
return std::make_shared<DBGSSHash>(k, mode);

// use DBGHashString to get contigs for SSHash
auto string_graph = build_graph<DBGHashString>(k, sequences, mode);
// use DBGHashFast to get contigs for SSHash
auto string_graph = build_graph<DBGHashFast>(k, sequences, mode);

std::vector<std::string> contigs;
size_t num_kmers = 0;
Expand Down

0 comments on commit 5af4aca

Please sign in to comment.