adds distributed row-gatherer
MarcelKoch committed Apr 4, 2024
1 parent 1b01d57 commit cf8f2db
Showing 4 changed files with 334 additions and 0 deletions.
1 change: 1 addition & 0 deletions core/CMakeLists.txt
@@ -21,6 +21,7 @@ target_sources(ginkgo
distributed/index_map.cpp
distributed/neighborhood_communicator.cpp
distributed/partition.cpp
distributed/row_gatherer.cpp
factorization/cholesky.cpp
factorization/elimination_forest.cpp
factorization/factorization.cpp
236 changes: 236 additions & 0 deletions core/distributed/row_gatherer.cpp
@@ -0,0 +1,236 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/distributed/row_gatherer.hpp>


#include <future>
#include <mutex>
#include <numeric>
#include <utility>
#include <vector>


#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/distributed/neighborhood_communicator.hpp>
#include <ginkgo/core/matrix/dense.hpp>


#include "core/base/dispatch_helper.hpp"

namespace gko {
namespace experimental {
namespace distributed {


/**
* \brief Exchanges the gather indices, so that each rank learns which of its
*        local indices the other ranks will gather from it.
*
* \tparam LocalIndexType  the local index type
*
* \param comm  the communicator used to create the temporary inverse
*              sparse communicator
* \param remote_local_idxs  the remote indices in their local indexing
* \param recv_ids  the ranks that own the remote indices
* \param recv_sizes  the number of indices per rank in remote_local_idxs
* \param send_ids  the ranks that request indices owned by this rank
* \param send_sizes  the number of local indices per rank that are part of
*                    remote_local_idxs on those ranks
*
* \return the local indices that are part of remote_local_idxs on other
*         ranks, ordered by the rank ordering of the communicator
*/
template <typename LocalIndexType>
array<LocalIndexType> communicate_send_gather_idxs(
mpi::communicator comm, const array<LocalIndexType>& remote_local_idxs,
const array<comm_index_type>& recv_ids,
const std::vector<comm_index_type>& recv_sizes,
const array<comm_index_type>& send_ids,
const std::vector<comm_index_type>& send_sizes)
{
// create temporary inverse sparse communicator
MPI_Comm sparse_comm;
MPI_Info info;
GKO_ASSERT_NO_MPI_ERRORS(MPI_Info_create(&info));
GKO_ASSERT_NO_MPI_ERRORS(MPI_Dist_graph_create_adjacent(
comm.get(), send_ids.get_size(), send_ids.get_const_data(),
MPI_UNWEIGHTED, recv_ids.get_size(), recv_ids.get_const_data(),
MPI_UNWEIGHTED, info, false, &sparse_comm));
GKO_ASSERT_NO_MPI_ERRORS(MPI_Info_free(&info));

std::vector<comm_index_type> recv_offsets(recv_sizes.size() + 1);
std::vector<comm_index_type> send_offsets(send_sizes.size() + 1);
std::partial_sum(recv_sizes.data(), recv_sizes.data() + recv_sizes.size(),
recv_offsets.begin() + 1);
std::partial_sum(send_sizes.data(), send_sizes.data() + send_sizes.size(),
send_offsets.begin() + 1);
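// illustrative example (not part of the algorithm): with
// recv_sizes = {2, 1} the prefix sum gives recv_offsets = {0, 2, 3}, so the
// indices destined for recv_ids[0] occupy remote_local_idxs[0..2) and those
// for recv_ids[1] occupy remote_local_idxs[2..3)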

array<LocalIndexType> send_gather_idxs(remote_local_idxs.get_executor(),
send_offsets.back());

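// note: the neighbor all-to-all runs on the inverse communicator, so the
// recv_* arguments describe what this rank sends (its remote indices, going
// back to their owners), and the send_* arguments describe what it receives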
GKO_ASSERT_NO_MPI_ERRORS(MPI_Neighbor_alltoallv(
remote_local_idxs.get_const_data(), recv_sizes.data(),
recv_offsets.data(), mpi::type_impl<LocalIndexType>::get_type(),
send_gather_idxs.get_data(), send_sizes.data(), send_offsets.data(),
mpi::type_impl<LocalIndexType>::get_type(), sparse_comm));
GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(&sparse_comm));

return send_gather_idxs;
}


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* b, LinOp* x) const
{
apply_async(b, x).wait();
}


template <typename LocalIndexType>
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
const LinOp* beta, LinOp* x) const
GKO_NOT_IMPLEMENTED;


template <typename LocalIndexType>
std::future<void> RowGatherer<LocalIndexType>::apply_async(
ptr_param<const LinOp> b, ptr_param<LinOp> x) const
{
auto op = [b = b.get(), x = x.get(), rg = this->shared_from_this()] {
// keep a lock while the send buffer might still be in use
std::lock_guard<std::mutex> lock(rg->mutex);
// dispatch global vector
run<Vector, double, float, std::complex<double>, std::complex<float>>(
b, [&](const auto* b_global) {
using InValueType =
typename std::decay_t<decltype(*b_global)>::value_type;
// dispatch local vector
run<matrix::Dense, double, float, std::complex<double>,
std::complex<float>>(x, [&](auto* x_local) {
using OutValueType =
typename std::decay_t<decltype(*x_local)>::value_type;
auto exec = rg->get_executor();

auto b_local = b_global->get_local_vector();
rg->send_buffer.template init<InValueType>(
b_local->get_executor(),
dim<2>(rg->coll_comm_->get_send_size(),
b_local->get_size()[1]));
b_local->row_gather(
&rg->send_idxs_,
rg->send_buffer.template get<InValueType>());

auto recv_ptr = x_local->get_values();
auto send_ptr = rg->send_buffer.template get<InValueType>()
->get_values();

exec->synchronize();
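// pack all columns of a row into a single MPI element, so the
// all-to-all exchanges whole rows at once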
mpi::contiguous_type in_type(
b_local->get_size()[1],
mpi::type_impl<InValueType>::get_type());
mpi::contiguous_type out_type(
b_local->get_size()[1],
mpi::type_impl<OutValueType>::get_type());
auto g = exec->get_scoped_device_id_guard();
auto req = rg->coll_comm_->i_all_to_all_v(
exec, send_ptr, in_type.get(), recv_ptr,
out_type.get());
req.wait();
});
});
};
return std::async(std::launch::async, op);
}

template <typename LocalIndexType>
template <typename GlobalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(
std::shared_ptr<const Executor> exec,
std::shared_ptr<const mpi::collective_communicator> coll_comm,
const index_map<LocalIndexType, GlobalIndexType>& imap)
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(
exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}),
DistributedBase(coll_comm->get_base_communicator()),
coll_comm_(std::move(coll_comm)),
send_idxs_(exec)
{
auto comm = coll_comm_->get_base_communicator();
auto inverse_comm = coll_comm_->create_inverse();

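// communicate which of this rank's local rows the other ranks will
// gather, using the inverse of the communication pattern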
send_idxs_.resize_and_reset(coll_comm_->get_send_size());
inverse_comm
->i_all_to_all_v(
exec, imap.get_remote_local_idxs().get_flat().get_const_data(),
send_idxs_.get_data())
.wait();
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec,
mpi::communicator comm)
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(exec),
DistributedBase(comm),
coll_comm_(std::make_shared<mpi::neighborhood_communicator>(comm)),
send_idxs_(exec)
{}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
DistributedBase(o.get_communicator())
{
*this = std::move(o);
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
const RowGatherer& o)
{
if (this != &o) {
this->set_size(o.get_size());
coll_comm_ = o.coll_comm_;
send_idxs_ = o.send_idxs_;
}
return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
RowGatherer&& o)
{
if (this != &o) {
this->set_size(o.get_size());
o.set_size({});
coll_comm_ = std::exchange(
o.coll_comm_, std::make_shared<mpi::neighborhood_communicator>(
o.get_communicator()));
send_idxs_ = std::move(o.send_idxs_);
}
return *this;
}


template <typename LocalIndexType>
RowGatherer<LocalIndexType>::RowGatherer(const RowGatherer& o)
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()),
DistributedBase(o.get_communicator())
{
*this = o;
}


#define GKO_DECLARE_ROW_GATHERER(_itype) class RowGatherer<_itype>

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_GATHERER);

#undef GKO_DECLARE_ROW_GATHERER


#define GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR(_ltype, _gtype) \
RowGatherer<_ltype>::RowGatherer( \
std::shared_ptr<const Executor> exec, \
std::shared_ptr<const mpi::collective_communicator> coll_comm, \
const index_map<_ltype, _gtype>& imap)

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR);

#undef GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR


} // namespace distributed
} // namespace experimental
} // namespace gko
96 changes: 96 additions & 0 deletions include/ginkgo/core/distributed/row_gatherer.hpp
@@ -0,0 +1,96 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
#define GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_


#include <ginkgo/config.hpp>


#if GINKGO_BUILD_MPI


#include <future>
#include <mutex>


#include <ginkgo/core/base/dense_cache.hpp>
#include <ginkgo/core/base/mpi.hpp>
#include <ginkgo/core/distributed/collective_communicator.hpp>
#include <ginkgo/core/distributed/index_map.hpp>
#include <ginkgo/core/distributed/lin_op.hpp>


namespace gko {
namespace experimental {
namespace distributed {


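/**
* \brief Gathers the rows of a distributed Vector that are referenced by the
*        non-local indices of an index_map into a local Dense matrix.
*
* The rows are exchanged through a collective_communicator, and the
* communication can be overlapped with computation via the non-blocking
* apply_async interface.
*/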
template <typename LocalIndexType>
class RowGatherer final
: public EnableDistributedLinOp<RowGatherer<LocalIndexType>>,
public DistributedBase,
public std::enable_shared_from_this<RowGatherer<LocalIndexType>> {
friend class EnableDistributedPolymorphicObject<RowGatherer, LinOp>;

public:
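/**
* Asynchronously gathers the rows of the distributed vector b into the
* local matrix x.
*
* \param b  the distributed vector to gather rows from
* \param x  the local matrix receiving the gathered rows
*
* \return a future that becomes ready once x contains the gathered rows.
*         Concurrent calls are serialized by an internal mutex, since they
*         share the send buffer.
*/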
std::future<void> apply_async(ptr_param<const LinOp> b,
ptr_param<LinOp> x) const;

template <typename GlobalIndexType = int64,
typename = std::enable_if_t<sizeof(GlobalIndexType) >=
sizeof(LocalIndexType)>>
static std::shared_ptr<RowGatherer> create(
std::shared_ptr<const Executor> exec,
std::shared_ptr<const mpi::collective_communicator> coll_comm,
const index_map<LocalIndexType, GlobalIndexType>& imap)
{
return std::shared_ptr<RowGatherer>(
new RowGatherer(std::move(exec), std::move(coll_comm), imap));
}

static std::shared_ptr<RowGatherer> create(
std::shared_ptr<const Executor> exec, mpi::communicator comm)
{
return std::shared_ptr<RowGatherer>(
new RowGatherer(std::move(exec), std::move(comm)));
}

RowGatherer(const RowGatherer& o);

RowGatherer(RowGatherer&& o) noexcept;

RowGatherer& operator=(const RowGatherer& o);

RowGatherer& operator=(RowGatherer&& o);

protected:
void apply_impl(const LinOp* b, LinOp* x) const override;
void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
LinOp* x) const override;

private:
template <typename GlobalIndexType>
RowGatherer(std::shared_ptr<const Executor> exec,
std::shared_ptr<const mpi::collective_communicator> coll_comm,
const index_map<LocalIndexType, GlobalIndexType>& imap);

RowGatherer(std::shared_ptr<const Executor> exec, mpi::communicator comm);

// the collective communicator defining the communication pattern
std::shared_ptr<const mpi::collective_communicator> coll_comm_;

// the local rows to gather into the send buffer
array<LocalIndexType> send_idxs_;

// cached send buffer, initialized lazily with the value type of the input
detail::AnyDenseCache send_buffer;

// serializes asynchronous applies, which share the send buffer
mutable std::mutex mutex;
};


} // namespace distributed
} // namespace experimental
} // namespace gko

#endif  // GINKGO_BUILD_MPI
#endif // GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_
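
For orientation, a minimal usage sketch (not part of this commit): it assumes an existing index_map `imap`, a distributed Vector `b`, a suitably sized local Dense matrix `x`, and that a neighborhood_communicator can be constructed from a communicator and an index map, as in the surrounding Ginkgo API.

// sketch: asynchronously gather the remote rows of a distributed vector
auto exec = gko::ReferenceExecutor::create();
gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
// imap: an index_map<gko::int32, gko::int64> describing the non-local rows
auto coll_comm = std::make_shared<
    gko::experimental::mpi::neighborhood_communicator>(comm, imap);
auto rg = gko::experimental::distributed::RowGatherer<gko::int32>::create(
    exec, coll_comm, imap);
// b: distributed Vector; x: local Dense with imap.get_non_local_size() rows
auto fut = rg->apply_async(b, x);
// ... overlap local computation with the communication here ...
fut.wait();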
1 change: 1 addition & 0 deletions include/ginkgo/ginkgo.hpp
@@ -65,6 +65,7 @@

#include <ginkgo/core/distributed/preconditioner/schwarz.hpp>

#include <ginkgo/core/distributed/row_gatherer.hpp>
#include <ginkgo/core/distributed/vector.hpp>

#include <ginkgo/core/factorization/cholesky.hpp>
