-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1b01d57
commit cf8f2db
Showing
4 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
// | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
|
||
#include <ginkgo/core/distributed/row_gatherer.hpp> | ||
|
||
|
||
#include <ginkgo/core/base/dense_cache.hpp> | ||
#include <ginkgo/core/base/precision_dispatch.hpp> | ||
#include <ginkgo/core/distributed/neighborhood_communicator.hpp> | ||
#include <ginkgo/core/matrix/dense.hpp> | ||
|
||
|
||
#include "core/base/dispatch_helper.hpp" | ||
|
||
namespace gko { | ||
namespace experimental { | ||
namespace distributed { | ||
|
||
|
||
/** | ||
* \brief | ||
* \tparam LocalIndexType index type | ||
* \param comm neighborhood communicator | ||
* \param remote_local_idxs the remote indices in their local indexing | ||
* \param recv_sizes the sizes that segregate remote_local_idxs | ||
* \param send_sizes the number of local indices per rank that are part of | ||
* remote_local_idxs on that ranks | ||
* \return the local indices that are part of remote_local_idxs on other ranks, | ||
* ordered by the rank ordering of the communicator | ||
*/ | ||
template <typename LocalIndexType> | ||
array<LocalIndexType> communicate_send_gather_idxs( | ||
mpi::communicator comm, const array<LocalIndexType>& remote_local_idxs, | ||
const array<comm_index_type>& recv_ids, | ||
const std::vector<comm_index_type>& recv_sizes, | ||
const array<comm_index_type>& send_ids, | ||
const std::vector<comm_index_type>& send_sizes) | ||
{ | ||
// create temporary inverse sparse communicator | ||
MPI_Comm sparse_comm; | ||
MPI_Info info; | ||
GKO_ASSERT_NO_MPI_ERRORS(MPI_Info_create(&info)); | ||
GKO_ASSERT_NO_MPI_ERRORS(MPI_Dist_graph_create_adjacent( | ||
comm.get(), send_ids.get_size(), send_ids.get_const_data(), | ||
MPI_UNWEIGHTED, recv_ids.get_size(), recv_ids.get_const_data(), | ||
MPI_UNWEIGHTED, info, false, &sparse_comm)); | ||
GKO_ASSERT_NO_MPI_ERRORS(MPI_Info_free(&info)); | ||
|
||
std::vector<comm_index_type> recv_offsets(recv_sizes.size() + 1); | ||
std::vector<comm_index_type> send_offsets(send_sizes.size() + 1); | ||
std::partial_sum(recv_sizes.data(), recv_sizes.data() + recv_sizes.size(), | ||
recv_offsets.begin() + 1); | ||
std::partial_sum(send_sizes.data(), send_sizes.data() + send_sizes.size(), | ||
send_offsets.begin() + 1); | ||
|
||
array<LocalIndexType> send_gather_idxs(remote_local_idxs.get_executor(), | ||
send_offsets.back()); | ||
|
||
GKO_ASSERT_NO_MPI_ERRORS(MPI_Neighbor_alltoallv( | ||
remote_local_idxs.get_const_data(), recv_sizes.data(), | ||
recv_offsets.data(), mpi::type_impl<LocalIndexType>::get_type(), | ||
send_gather_idxs.get_data(), send_sizes.data(), send_offsets.data(), | ||
mpi::type_impl<LocalIndexType>::get_type(), sparse_comm)); | ||
GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(&sparse_comm)); | ||
|
||
return send_gather_idxs; | ||
} | ||
|
||
|
||
template <typename LocalIndexType> | ||
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* b, LinOp* x) const | ||
{ | ||
apply_async(b, x).wait(); | ||
} | ||
|
||
|
||
template <typename LocalIndexType> | ||
void RowGatherer<LocalIndexType>::apply_impl(const LinOp* alpha, const LinOp* b, | ||
const LinOp* beta, LinOp* x) const | ||
GKO_NOT_IMPLEMENTED; | ||
|
||
|
||
template <typename LocalIndexType> | ||
std::future<void> RowGatherer<LocalIndexType>::apply_async( | ||
ptr_param<const LinOp> b, ptr_param<LinOp> x) const | ||
{ | ||
auto op = [b = b.get(), x = x.get(), rg = this->shared_from_this()] { | ||
// keep a lock while the send buffer might still be in use | ||
std::lock_guard<std::mutex> lock(rg->mutex); | ||
// dispatch global vector | ||
run<Vector, double, float, std::complex<double>, std::complex<float>>( | ||
b, [&](const auto* b_global) { | ||
using InValueType = | ||
typename std::decay_t<decltype(*b_global)>::value_type; | ||
// dispatch local vector | ||
run<matrix::Dense, double, float, std::complex<double>, | ||
std::complex<float>>(x, [&](auto* x_local) { | ||
using OutValueType = | ||
typename std::decay_t<decltype(*x_local)>::value_type; | ||
auto exec = rg->get_executor(); | ||
|
||
auto b_local = b_global->get_local_vector(); | ||
rg->send_buffer.template init<InValueType>( | ||
b_local->get_executor(), | ||
dim<2>(rg->coll_comm_->get_send_size(), | ||
b_local->get_size()[1])); | ||
b_local->row_gather( | ||
&rg->send_idxs_, | ||
rg->send_buffer.template get<InValueType>()); | ||
|
||
auto recv_ptr = x_local->get_values(); | ||
auto send_ptr = rg->send_buffer.template get<InValueType>() | ||
->get_values(); | ||
|
||
exec->synchronize(); | ||
mpi::contiguous_type in_type( | ||
b_local->get_size()[1], | ||
mpi::type_impl<InValueType>::get_type()); | ||
mpi::contiguous_type out_type( | ||
b_local->get_size()[1], | ||
mpi::type_impl<OutValueType>::get_type()); | ||
auto g = exec->get_scoped_device_id_guard(); | ||
auto req = rg->coll_comm_->i_all_to_all_v( | ||
exec, send_ptr, in_type.get(), recv_ptr, | ||
out_type.get()); | ||
req.wait(); | ||
}); | ||
}); | ||
}; | ||
return std::async(std::launch::async, op); | ||
} | ||
|
||
template <typename LocalIndexType> | ||
template <typename GlobalIndexType> | ||
RowGatherer<LocalIndexType>::RowGatherer( | ||
std::shared_ptr<const Executor> exec, | ||
std::shared_ptr<const mpi::collective_communicator> coll_comm, | ||
const index_map<LocalIndexType, GlobalIndexType>& imap) | ||
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>( | ||
exec, dim<2>{imap.get_non_local_size(), imap.get_global_size()}), | ||
DistributedBase(coll_comm->get_base_communicator()), | ||
coll_comm_(std::move(coll_comm)), | ||
send_idxs_(exec) | ||
{ | ||
auto comm = coll_comm_->get_base_communicator(); | ||
auto inverse_comm = coll_comm_->create_inverse(); | ||
|
||
send_idxs_.resize_and_reset(coll_comm_->get_send_size()); | ||
inverse_comm | ||
->i_all_to_all_v( | ||
exec, imap.get_remote_local_idxs().get_flat().get_const_data(), | ||
send_idxs_.get_data()) | ||
.wait(); | ||
} | ||
|
||
|
||
template <typename LocalIndexType> | ||
RowGatherer<LocalIndexType>::RowGatherer(std::shared_ptr<const Executor> exec, | ||
mpi::communicator comm) | ||
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(exec), | ||
DistributedBase(comm), | ||
coll_comm_(std::make_shared<mpi::neighborhood_communicator>(comm)), | ||
send_idxs_(exec) | ||
{} | ||
|
||
|
||
template <typename LocalIndexType> | ||
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept | ||
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()), | ||
DistributedBase(o.get_communicator()) | ||
{ | ||
*this = std::move(o); | ||
} | ||
|
||
|
||
template <typename LocalIndexType> | ||
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=( | ||
const RowGatherer& o) | ||
{ | ||
if (this != &o) { | ||
this->set_size(o.get_size()); | ||
coll_comm_ = o.coll_comm_; | ||
send_idxs_ = o.send_idxs_; | ||
} | ||
return *this; | ||
} | ||
|
||
|
||
template <typename LocalIndexType> | ||
RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=( | ||
RowGatherer&& o) | ||
{ | ||
if (this != &o) { | ||
this->set_size(o.get_size()); | ||
o.set_size({}); | ||
coll_comm_ = std::exchange( | ||
o.coll_comm_, std::make_shared<mpi::neighborhood_communicator>( | ||
o.get_communicator())); | ||
send_idxs_ = std::move(o.send_idxs_); | ||
} | ||
return *this; | ||
} | ||
|
||
|
||
template <typename LocalIndexType> | ||
RowGatherer<LocalIndexType>::RowGatherer(const RowGatherer& o) | ||
: EnableDistributedLinOp<RowGatherer<LocalIndexType>>(o.get_executor()), | ||
DistributedBase(o.get_communicator()) | ||
{ | ||
*this = o; | ||
} | ||
|
||
|
||
#define GKO_DECLARE_ROW_GATHERER(_itype) class RowGatherer<_itype> | ||
|
||
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_GATHERER); | ||
|
||
#undef GKO_DECLARE_ROW_GATHERER | ||
|
||
|
||
#define GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR(_ltype, _gtype) \ | ||
RowGatherer<_ltype>::RowGatherer( \ | ||
std::shared_ptr<const Executor> exec, \ | ||
std::shared_ptr<const mpi::collective_communicator> coll_comm, \ | ||
const index_map<_ltype, _gtype>& imap) | ||
|
||
GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE( | ||
GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR); | ||
|
||
#undef GKO_DECLARE_ROW_GATHERER_CONSTRUCTOR | ||
|
||
|
||
} // namespace distributed | ||
} // namespace experimental | ||
} // namespace gko |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors | ||
// | ||
// SPDX-License-Identifier: BSD-3-Clause | ||
|
||
#ifndef GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_ | ||
#define GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_ | ||
|
||
|
||
#include <ginkgo/config.hpp> | ||
|
||
|
||
#if GINKGO_BUILD_MPI | ||
|
||
|
||
#include <future> | ||
|
||
|
||
#include <ginkgo/core/base/dense_cache.hpp> | ||
#include <ginkgo/core/base/mpi.hpp> | ||
#include <ginkgo/core/distributed/collective_communicator.hpp> | ||
#include <ginkgo/core/distributed/index_map.hpp> | ||
#include <ginkgo/core/distributed/lin_op.hpp> | ||
|
||
|
||
namespace gko { | ||
namespace experimental { | ||
namespace distributed { | ||
|
||
|
||
template <typename LocalIndexType> | ||
class RowGatherer final | ||
: public EnableDistributedLinOp<RowGatherer<LocalIndexType>>, | ||
public DistributedBase, | ||
public std::enable_shared_from_this<RowGatherer<LocalIndexType>> { | ||
friend class EnableDistributedPolymorphicObject<RowGatherer, LinOp>; | ||
|
||
public: | ||
std::future<void> apply_async(ptr_param<const LinOp> b, | ||
ptr_param<LinOp> x) const; | ||
|
||
template <typename GlobalIndexType = int64, | ||
typename = std::enable_if_t<sizeof(GlobalIndexType) >= | ||
sizeof(LocalIndexType)>> | ||
static std::shared_ptr<RowGatherer> create( | ||
std::shared_ptr<const Executor> exec, | ||
std::shared_ptr<const mpi::collective_communicator> coll_comm, | ||
const index_map<LocalIndexType, GlobalIndexType>& imap) | ||
{ | ||
return std::shared_ptr<RowGatherer>( | ||
new RowGatherer(std::move(exec), std::move(coll_comm), imap)); | ||
} | ||
|
||
static std::shared_ptr<RowGatherer> create( | ||
std::shared_ptr<const Executor> exec, mpi::communicator comm) | ||
{ | ||
return std::shared_ptr<RowGatherer>( | ||
new RowGatherer(std::move(exec), std::move(comm))); | ||
} | ||
|
||
RowGatherer(const RowGatherer& o); | ||
|
||
RowGatherer(RowGatherer&& o) noexcept; | ||
|
||
RowGatherer& operator=(const RowGatherer& o); | ||
|
||
RowGatherer& operator=(RowGatherer&& o); | ||
|
||
protected: | ||
void apply_impl(const LinOp* b, LinOp* x) const override; | ||
void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta, | ||
LinOp* x) const override; | ||
|
||
private: | ||
template <typename GlobalIndexType> | ||
RowGatherer(std::shared_ptr<const Executor> exec, | ||
std::shared_ptr<const mpi::collective_communicator> coll_comm, | ||
const index_map<LocalIndexType, GlobalIndexType>& imap); | ||
|
||
RowGatherer(std::shared_ptr<const Executor> exec, mpi::communicator comm); | ||
|
||
std::shared_ptr<const mpi::collective_communicator> coll_comm_; | ||
|
||
array<LocalIndexType> send_idxs_; | ||
|
||
detail::AnyDenseCache send_buffer; | ||
|
||
mutable std::mutex mutex; | ||
}; | ||
|
||
|
||
} // namespace distributed | ||
} // namespace experimental | ||
} // namespace gko | ||
|
||
#endif | ||
#endif // GKO_PUBLIC_CORE_DISTRIBUTED_ROW_GATHERER_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters