-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcomm-nccl.hpp
27 lines (20 loc) · 1.09 KB
/
comm-nccl.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <comm-mpi.hpp>
#include <cuda_runtime_api.h>
#include <nccl.h>
// NCCL-backed column communicator layered on top of the MPI-based ColCommMPI.
// Adds device-side (CUDA-stream-ordered) NCCL communicators that presumably
// mirror the MPI communicator topology of the base class — confirm against
// comm-mpi.hpp. Cleanup is manual via free_all_comms(); there is no destructor
// and copy/move operations are not deleted. NOTE(review): copying an instance
// and calling free_all_comms() on both copies would double-free the NCCL
// communicators — consider RAII or deleting the copy operations.
class ColCommNCCL : public ColCommMPI {
private:
// Communicator for level-merge collectives (see level_merge below).
ncclComm_t MergeNCCL;
// One communicator per neighbor group (used by the neighbor_* methods below).
std::vector<ncclComm_t> NeighborNCCL;
// Communicator for the all-reduce style collective (see level_sum below).
ncclComm_t AllReduceNCCL;
// Duplicated communicator; purpose not visible in this header — presumably a
// duplicate of the world communicator, TODO confirm in the .cpp.
ncclComm_t DupNCCL;
// Bookkeeping list of every NCCL communicator this object created; presumably
// the set released by free_all_comms() — verify against the implementation.
std::vector<ncclComm_t> allocedNCCL;
// Merge `len` elements of `data` across the merge communicator on `stream`.
// (Template bodies defined elsewhere; element types supported not visible here.)
template<typename T> inline void level_merge(T* data, long long len, cudaStream_t stream) const;
// Sum-reduce `len` elements of `data` across the all-reduce communicator on `stream`.
template<typename T> inline void level_sum(T* data, long long len, cudaStream_t stream) const;
// Broadcast `data` to neighbors; `box_dims` presumably gives per-neighbor
// segment lengths — TODO confirm expected array length against NeighborNCCL.
template<typename T> inline void neighbor_bcast(T* data, const long long box_dims[], cudaStream_t stream) const;
// Reduce `data` from neighbors; same `box_dims` convention as neighbor_bcast.
template<typename T> inline void neighbor_reduce(T* data, const long long box_dims[], cudaStream_t stream) const;
public:
// Default-construct with all NCCL handles null and all vectors empty; the
// base ColCommMPI is default-constructed as well.
ColCommNCCL() : ColCommMPI(), MergeNCCL(nullptr), NeighborNCCL(), AllReduceNCCL(nullptr), DupNCCL(nullptr), allocedNCCL() {};
// Build communicators from the process-tree description (same parameter
// layout as the ColCommMPI constructor, presumably — defined in the .cpp).
ColCommNCCL(const std::pair<long long, long long> Tree[], std::pair<long long, long long> Mapping[], const long long Rows[], const long long Cols[], MPI_Comm world = MPI_COMM_WORLD);
// Release every NCCL communicator this object allocated. Must be called
// explicitly before destruction (no destructor performs this).
void free_all_comms();
// Select the CUDA device for the calling rank; returns the chosen device
// ordinal, presumably derived from the rank within `world` — TODO confirm.
static int set_device(MPI_Comm world = MPI_COMM_WORLD);
};