From ee58d9a45968b9dd79393cf3ce4d67a22ac684cb Mon Sep 17 00:00:00 2001 From: Mikhail Karasikov Date: Wed, 25 Oct 2023 13:28:41 +0200 Subject: [PATCH] slice_rows without reallocations (#468) --- .../binary_matrix/multi_brwt/brwt.cpp | 160 +++++++----------- .../binary_matrix/multi_brwt/brwt.hpp | 9 +- 2 files changed, 63 insertions(+), 106 deletions(-) diff --git a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp index 913bc9e02a..aded390179 100644 --- a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp +++ b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp @@ -16,7 +16,7 @@ bool BRWT::get(Row row, Column column) const { assert(row < num_rows()); assert(column < num_columns()); - // if this is a leaf + // if leaf if (!child_nodes_.size()) return (*nonzero_rows_)[row]; @@ -29,40 +29,40 @@ bool BRWT::get(Row row, Column column) const { return child_nodes_[child_node]->get(rank - 1, assignments_.rank(column)); } -Vector> BRWT::get_column_ranks(Row i) const { - assert(i < num_rows()); +std::vector +BRWT::get_rows(const std::vector &row_ids) const { + std::vector rows(row_ids.size()); - // check if the row is empty - uint64_t rank = nonzero_rows_->conditional_rank1(i); - if (!rank) - return {}; + Vector slice; + // expect at least 3 relations per row + slice.reserve(row_ids.size() * 4); - // check whether it is a leaf - if (!child_nodes_.size()) { - assert(assignments_.size() == 1); - // the bit is set - return {{ 0, rank }}; - } + slice_rows(row_ids, &slice); - // check all child nodes - Vector> row; - uint64_t index_in_child = rank - 1; + assert(slice.size() >= row_ids.size()); - for (size_t k = 0; k < child_nodes_.size(); ++k) { - const auto &child = *child_nodes_[k]; + auto row_begin = slice.begin(); - for (auto [col_id, rank] : child.get_column_ranks(index_in_child)) { - row.emplace_back(assignments_.get(k, col_id), rank); - } + for (size_t i = 0; i < rows.size(); ++i) { + // every row in `slice` ends with `-1` + auto row_end = std::find(row_begin, slice.end(), + std::numeric_limits::max()); + rows[i].assign(row_begin, row_end); + row_begin = row_end + 1; } - return row; + + return rows; } -std::vector -BRWT::get_rows(const std::vector &row_ids) const { - std::vector rows(row_ids.size()); +std::vector>> +BRWT::get_column_ranks(const std::vector &row_ids) const { + std::vector>> rows(row_ids.size()); + + Vector> slice; + // expect at least 3 relations per row + slice.reserve(row_ids.size() * 4); - auto slice = slice_rows(row_ids); + slice_rows(row_ids, &slice); assert(slice.size() >= row_ids.size()); @@ -70,8 +70,11 @@ BRWT::get_rows(const std::vector &row_ids) const { for (size_t i = 0; i < rows.size(); ++i) { // every row in `slice` ends with `-1` - auto row_end = std::find(row_begin, slice.end(), - std::numeric_limits::max()); + auto row_end = row_begin; + while (row_end->first != std::numeric_limits::max()) { + ++row_end; + assert(row_end != slice.end()); + } rows[i].assign(row_begin, row_end); row_begin = row_end + 1; } @@ -79,20 +82,13 @@ BRWT::get_rows(const std::vector &row_ids) const { return rows; } -BRWT::SetBitPositions BRWT::slice_rows(const std::vector &row_ids) const { - return slice_rows(row_ids); -} - // If T = Column // return positions of set bits. // If T = std::pair // return positions of set bits with their column ranks. +// Appends to `slice` template -Vector BRWT::slice_rows(const std::vector &row_ids) const { - Vector slice; - // expect at least one relation per row - slice.reserve(row_ids.size() * 2); - +void BRWT::slice_rows(const std::vector &row_ids, Vector *slice) const { T delim; if constexpr(utils::is_pair_v) { delim = std::make_pair(std::numeric_limits::max(), 0); @@ -110,18 +106,18 @@ Vector BRWT::slice_rows(const std::vector &row_ids) const { if constexpr(utils::is_pair_v) { if (uint64_t rank = nonzero_rows_->conditional_rank1(i)) { // only a single column is stored in leaves - slice.emplace_back(0, rank); + slice->emplace_back(0, rank); } } else { if ((*nonzero_rows_)[i]) { // only a single column is stored in leaves - slice.push_back(0); + slice->push_back(0); } } - slice.push_back(delim); + slice->push_back(delim); } - return slice; + return; } // construct indexing for children and the inverse mapping @@ -173,69 +169,53 @@ Vector BRWT::slice_rows(const std::vector &row_ids) const { } } - if (!child_row_ids.size()) - return Vector(row_ids.size(), delim); + if (!child_row_ids.size()) { + for (size_t i = 0; i < row_ids.size(); ++i) { + slice->push_back(delim); + } + return; + } // TODO: query by columns and merge them in the very end to avoid remapping // the same column indexes many times when propagating to the root. // TODO: implement a cache efficient method for merging the columns. // query all children subtrees and get relations from them - std::vector> child_slices(child_nodes_.size()); - std::vector pos(child_nodes_.size()); + size_t slice_start = slice->size(); + + std::vector pos(child_nodes_.size()); for (size_t j = 0; j < child_nodes_.size(); ++j) { - child_slices[j] = child_nodes_[j]->slice_rows(child_row_ids); - // transform column indexes + pos[j] = slice->size(); + child_nodes_[j]->slice_rows(child_row_ids, slice); + + assert(slice->size() >= pos[j] + child_row_ids.size()); - for (auto &v : child_slices[j]) { + // transform column indexes + for (size_t i = pos[j]; i < slice->size(); ++i) { + auto &v = (*slice)[i]; if (v != delim) { auto &col = utils::get_first(v); col = assignments_.get(j, col); } } - assert(child_slices[j].size() >= child_row_ids.size()); - pos[j] = &child_slices[j].front() - 1; } + size_t slice_offset = slice->size(); + for (size_t i = 0; i < row_ids.size(); ++i) { if (!skip_row[i]) { // merge rows from child submatrices - for (auto &p : pos) { - while (*(++p) != delim) { - slice.push_back(*p); + for (size_t &p : pos) { + while ((*slice)[p++] != delim) { + slice->push_back((*slice)[p - 1]); } } } - slice.push_back(delim); + slice->push_back(delim); } - return slice; -} - -std::vector>> -BRWT::get_column_ranks(const std::vector &row_ids) const { - std::vector>> rows(row_ids.size()); - - Vector> slice - = slice_rows>(row_ids); - - assert(slice.size() >= row_ids.size()); - - auto row_begin = slice.begin(); - - for (size_t i = 0; i < rows.size(); ++i) { - // every row in `slice` ends with `-1` - auto row_end = row_begin; - while (row_end->first != std::numeric_limits::max()) { - ++row_end; - assert(row_end != slice.end()); - } - rows[i].assign(row_begin, row_end); - row_begin = row_end + 1; - } - - return rows; + slice->erase(slice->begin() + slice_start, slice->begin() + slice_offset); } std::vector BRWT::get_column(Column column) const { @@ -367,26 +347,6 @@ double BRWT::shrinking_rate() const { return rate_sum / num_nodes; } -uint64_t BRWT::total_column_size() const { - uint64_t total_size = 0; - - BFT([&](const BRWT &node) { - total_size += node.nonzero_rows_->size(); - }); - - return total_size; -} - -uint64_t BRWT::total_num_set_bits() const { - uint64_t total_num_set_bits = 0; - - BFT([&](const BRWT &node) { - total_num_set_bits += node.nonzero_rows_->num_set_bits(); - }); - - return total_num_set_bits; -} - void BRWT::print_tree_structure(std::ostream &os) const { BFT([&os](const BRWT &node) { // print node and its stats diff --git a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp index 26c196224d..a488d586e5 100644 --- a/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp +++ b/metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp @@ -31,10 +31,7 @@ class BRWT : public BinaryMatrix, public GetEntrySupport { bool get(Row row, Column column) const override; std::vector get_column(Column column) const override; std::vector get_rows(const std::vector &rows) const override; - // get all selected rows appended with -1 and concatenated - SetBitPositions slice_rows(const std::vector &rows) const; // query row and get ranks of each set bit in its column - Vector> get_column_ranks(Row row) const; std::vector>> get_column_ranks(const std::vector &rows) const; @@ -48,17 +45,17 @@ class BRWT : public BinaryMatrix, public GetEntrySupport { double avg_arity() const; uint64_t num_nodes() const; double shrinking_rate() const; - uint64_t total_column_size() const; - uint64_t total_num_set_bits() const; void print_tree_structure(std::ostream &os) const; private: // breadth-first traversal void BFT(std::function callback) const; + // get all selected rows appended with -1 and concatenated // helper function for querying rows in batches + // appends to `slice` template - Vector slice_rows(const std::vector &rows) const; + void slice_rows(const std::vector &rows, Vector *slice) const; // assigns columns to the child nodes RangePartition assignments_;