Skip to content

Commit

Permalink
slice_rows without reallocations (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
karasikov authored Oct 25, 2023
1 parent 1917eb8 commit ee58d9a
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 106 deletions.
160 changes: 60 additions & 100 deletions metagraph/src/annotation/binary_matrix/multi_brwt/brwt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ bool BRWT::get(Row row, Column column) const {
assert(row < num_rows());
assert(column < num_columns());

// if this is a leaf
// if leaf
if (!child_nodes_.size())
return (*nonzero_rows_)[row];

Expand All @@ -29,70 +29,66 @@ bool BRWT::get(Row row, Column column) const {
return child_nodes_[child_node]->get(rank - 1, assignments_.rank(column));
}

Vector<std::pair<BRWT::Column, uint64_t>> BRWT::get_column_ranks(Row i) const {
assert(i < num_rows());
std::vector<BRWT::SetBitPositions>
BRWT::get_rows(const std::vector<Row> &row_ids) const {
std::vector<SetBitPositions> rows(row_ids.size());

// check if the row is empty
uint64_t rank = nonzero_rows_->conditional_rank1(i);
if (!rank)
return {};
Vector<Column> slice;
// expect at least 3 relations per row
slice.reserve(row_ids.size() * 4);

// check whether it is a leaf
if (!child_nodes_.size()) {
assert(assignments_.size() == 1);
// the bit is set
return {{ 0, rank }};
}
slice_rows(row_ids, &slice);

// check all child nodes
Vector<std::pair<BRWT::Column, uint64_t>> row;
uint64_t index_in_child = rank - 1;
assert(slice.size() >= row_ids.size());

for (size_t k = 0; k < child_nodes_.size(); ++k) {
const auto &child = *child_nodes_[k];
auto row_begin = slice.begin();

for (auto [col_id, rank] : child.get_column_ranks(index_in_child)) {
row.emplace_back(assignments_.get(k, col_id), rank);
}
for (size_t i = 0; i < rows.size(); ++i) {
// every row in `slice` ends with `-1`
auto row_end = std::find(row_begin, slice.end(),
std::numeric_limits<Column>::max());
rows[i].assign(row_begin, row_end);
row_begin = row_end + 1;
}
return row;

return rows;
}

std::vector<BRWT::SetBitPositions>
BRWT::get_rows(const std::vector<Row> &row_ids) const {
std::vector<SetBitPositions> rows(row_ids.size());
std::vector<Vector<std::pair<BRWT::Column, uint64_t>>>
BRWT::get_column_ranks(const std::vector<Row> &row_ids) const {
std::vector<Vector<std::pair<Column, uint64_t>>> rows(row_ids.size());

Vector<std::pair<Column, uint64_t>> slice;
// expect at least 3 relations per row
slice.reserve(row_ids.size() * 4);

auto slice = slice_rows(row_ids);
slice_rows(row_ids, &slice);

assert(slice.size() >= row_ids.size());

auto row_begin = slice.begin();

for (size_t i = 0; i < rows.size(); ++i) {
// every row in `slice` ends with `-1`
auto row_end = std::find(row_begin, slice.end(),
std::numeric_limits<Column>::max());
auto row_end = row_begin;
while (row_end->first != std::numeric_limits<Column>::max()) {
++row_end;
assert(row_end != slice.end());
}
rows[i].assign(row_begin, row_end);
row_begin = row_end + 1;
}

return rows;
}

BRWT::SetBitPositions BRWT::slice_rows(const std::vector<Row> &row_ids) const {
return slice_rows<Column>(row_ids);
}

// If T = Column
// return positions of set bits.
// If T = std::pair<Column, uint64_t>
// return positions of set bits with their column ranks.
// Appends to `slice`
template <typename T>
Vector<T> BRWT::slice_rows(const std::vector<Row> &row_ids) const {
Vector<T> slice;
// expect at least one relation per row
slice.reserve(row_ids.size() * 2);

void BRWT::slice_rows(const std::vector<Row> &row_ids, Vector<T> *slice) const {
T delim;
if constexpr(utils::is_pair_v<T>) {
delim = std::make_pair(std::numeric_limits<Column>::max(), 0);
Expand All @@ -110,18 +106,18 @@ Vector<T> BRWT::slice_rows(const std::vector<Row> &row_ids) const {
if constexpr(utils::is_pair_v<T>) {
if (uint64_t rank = nonzero_rows_->conditional_rank1(i)) {
// only a single column is stored in leaves
slice.emplace_back(0, rank);
slice->emplace_back(0, rank);
}
} else {
if ((*nonzero_rows_)[i]) {
// only a single column is stored in leaves
slice.push_back(0);
slice->push_back(0);
}
}
slice.push_back(delim);
slice->push_back(delim);
}

return slice;
return;
}

// construct indexing for children and the inverse mapping
Expand Down Expand Up @@ -173,69 +169,53 @@ Vector<T> BRWT::slice_rows(const std::vector<Row> &row_ids) const {
}
}

if (!child_row_ids.size())
return Vector<T>(row_ids.size(), delim);
if (!child_row_ids.size()) {
for (size_t i = 0; i < row_ids.size(); ++i) {
slice->push_back(delim);
}
return;
}

// TODO: query by columns and merge them in the very end to avoid remapping
// the same column indexes many times when propagating to the root.
// TODO: implement a cache efficient method for merging the columns.

// query all children subtrees and get relations from them
std::vector<Vector<T>> child_slices(child_nodes_.size());
std::vector<const T *> pos(child_nodes_.size());
size_t slice_start = slice->size();

std::vector<size_t> pos(child_nodes_.size());

for (size_t j = 0; j < child_nodes_.size(); ++j) {
child_slices[j] = child_nodes_[j]->slice_rows<T>(child_row_ids);
// transform column indexes
pos[j] = slice->size();
child_nodes_[j]->slice_rows<T>(child_row_ids, slice);

assert(slice->size() >= pos[j] + child_row_ids.size());

for (auto &v : child_slices[j]) {
// transform column indexes
for (size_t i = pos[j]; i < slice->size(); ++i) {
auto &v = (*slice)[i];
if (v != delim) {
auto &col = utils::get_first(v);
col = assignments_.get(j, col);
}
}
assert(child_slices[j].size() >= child_row_ids.size());
pos[j] = &child_slices[j].front() - 1;
}

size_t slice_offset = slice->size();

for (size_t i = 0; i < row_ids.size(); ++i) {
if (!skip_row[i]) {
// merge rows from child submatrices
for (auto &p : pos) {
while (*(++p) != delim) {
slice.push_back(*p);
for (size_t &p : pos) {
while ((*slice)[p++] != delim) {
slice->push_back((*slice)[p - 1]);
}
}
}
slice.push_back(delim);
slice->push_back(delim);
}

return slice;
}

std::vector<Vector<std::pair<BRWT::Column, uint64_t>>>
BRWT::get_column_ranks(const std::vector<Row> &row_ids) const {
std::vector<Vector<std::pair<Column, uint64_t>>> rows(row_ids.size());

Vector<std::pair<Column, uint64_t>> slice
= slice_rows<std::pair<Column, uint64_t>>(row_ids);

assert(slice.size() >= row_ids.size());

auto row_begin = slice.begin();

for (size_t i = 0; i < rows.size(); ++i) {
// every row in `slice` ends with `-1`
auto row_end = row_begin;
while (row_end->first != std::numeric_limits<Column>::max()) {
++row_end;
assert(row_end != slice.end());
}
rows[i].assign(row_begin, row_end);
row_begin = row_end + 1;
}

return rows;
slice->erase(slice->begin() + slice_start, slice->begin() + slice_offset);
}

std::vector<BRWT::Row> BRWT::get_column(Column column) const {
Expand Down Expand Up @@ -367,26 +347,6 @@ double BRWT::shrinking_rate() const {
return rate_sum / num_nodes;
}

uint64_t BRWT::total_column_size() const {
uint64_t total_size = 0;

BFT([&](const BRWT &node) {
total_size += node.nonzero_rows_->size();
});

return total_size;
}

uint64_t BRWT::total_num_set_bits() const {
uint64_t total_num_set_bits = 0;

BFT([&](const BRWT &node) {
total_num_set_bits += node.nonzero_rows_->num_set_bits();
});

return total_num_set_bits;
}

void BRWT::print_tree_structure(std::ostream &os) const {
BFT([&os](const BRWT &node) {
// print node and its stats
Expand Down
9 changes: 3 additions & 6 deletions metagraph/src/annotation/binary_matrix/multi_brwt/brwt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ class BRWT : public BinaryMatrix, public GetEntrySupport {
bool get(Row row, Column column) const override;
std::vector<Row> get_column(Column column) const override;
std::vector<SetBitPositions> get_rows(const std::vector<Row> &rows) const override;
// get all selected rows appended with -1 and concatenated
SetBitPositions slice_rows(const std::vector<Row> &rows) const;
// query row and get ranks of each set bit in its column
Vector<std::pair<Column, uint64_t>> get_column_ranks(Row row) const;
std::vector<Vector<std::pair<Column, uint64_t>>>
get_column_ranks(const std::vector<Row> &rows) const;

Expand All @@ -48,17 +45,17 @@ class BRWT : public BinaryMatrix, public GetEntrySupport {
double avg_arity() const;
uint64_t num_nodes() const;
double shrinking_rate() const;
uint64_t total_column_size() const;
uint64_t total_num_set_bits() const;

void print_tree_structure(std::ostream &os) const;

private:
// breadth-first traversal
void BFT(std::function<void(const BRWT &node)> callback) const;
// get all selected rows appended with -1 and concatenated
// helper function for querying rows in batches
// appends to `slice`
template <typename T>
Vector<T> slice_rows(const std::vector<Row> &rows) const;
void slice_rows(const std::vector<Row> &rows, Vector<T> *slice) const;

// assigns columns to the child nodes
RangePartition assignments_;
Expand Down

0 comments on commit ee58d9a

Please sign in to comment.