Skip to content

Commit

Permalink
Optimize memory reallocations (#11112)
Browse files Browse the repository at this point in the history
  • Loading branch information
razdoburdin authored Dec 19, 2024
1 parent 24e19e7 commit 2d1c26b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
12 changes: 10 additions & 2 deletions src/tree/common_row_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,20 @@ class CommonRowPartitioner {
CommonRowPartitioner(Context const* ctx, bst_idx_t num_row, bst_idx_t _base_rowid,
bool is_col_split)
: base_rowid{_base_rowid}, is_col_split_{is_col_split} {
row_set_collection_.Clear();
Reset(ctx, num_row, _base_rowid, is_col_split);
}

void Reset(Context const* ctx, bst_idx_t num_row, bst_idx_t _base_rowid, bool is_col_split) {
base_rowid = _base_rowid;
is_col_split_ = is_col_split;

std::vector<bst_idx_t>& row_indices = *row_set_collection_.Data();
row_indices.resize(num_row);

bst_idx_t* p_row_indices = row_indices.data();
common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid);
common::Iota(ctx, p_row_indices, p_row_indices + num_row, base_rowid);

row_set_collection_.Clear();
row_set_collection_.Init();

if (is_col_split_) {
Expand Down
12 changes: 9 additions & 3 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,15 +346,21 @@ class HistUpdater {
void InitData(DMatrix *fmat, RegTree const *p_tree) {
monitor_->Start(__func__);
bst_bin_t n_total_bins{0};
partitioner_.clear();
size_t page_idx = 0;
for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(ctx_, HistBatch(param_))) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
if (page_idx < partitioner_.size()) {
partitioner_[page_idx].Reset(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
} else {
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid,
fmat->Info().IsColumnSplit());
}
page_idx++;
}
histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
fmat->Info().IsColumnSplit(), hist_param_);
Expand Down

0 comments on commit 2d1c26b

Please sign in to comment.