Merge pull request #2330 from boutproject/pcr_cyclic_copy+asan
Add Parallel cyclic reduction Laplacian solver
ZedThree authored May 25, 2021
2 parents a05809a + f24aa25 commit 8ad1d8d
Showing 10 changed files with 1,397 additions and 29 deletions.
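For orientation (not part of the diff): parallel cyclic reduction solves a tridiagonal system a_i*x_{i-1} + b_i*x_i + c_i*x_{i+1} = r_i by combining each row with the rows one stride above and below, doubling the stride each level until every equation decouples. The sketch below is a minimal serial illustration of that recurrence only, assuming a[0] == 0, c[n-1] == 0, and a well-conditioned (e.g. diagonally dominant) system; the solver added in src/invert/laplace/impls/pcr/pcr.cxx distributes these levels across X processors and handles one system per kz mode, and the names here (pcr_solve, the identity-row helpers) are illustrative, not taken from the commit.

// Minimal serial sketch of parallel cyclic reduction (PCR) for a
// tridiagonal system a[i]*x[i-1] + b[i]*x[i] + c[i]*x[i+1] = r[i].
// Illustration only; not the solver added by this commit.
#include <vector>

std::vector<double> pcr_solve(std::vector<double> a, std::vector<double> b,
                              std::vector<double> c, std::vector<double> r) {
  const int n = static_cast<int>(b.size());
  // Out-of-range neighbours act as identity rows (a = c = 0, b = 1, r = 0),
  // so the elimination formulas below remain valid at the domain ends.
  auto A = [&](int i) { return (i < 0 || i >= n) ? 0.0 : a[i]; };
  auto B = [&](int i) { return (i < 0 || i >= n) ? 1.0 : b[i]; };
  auto C = [&](int i) { return (i < 0 || i >= n) ? 0.0 : c[i]; };
  auto R = [&](int i) { return (i < 0 || i >= n) ? 0.0 : r[i]; };

  for (int stride = 1; stride < n; stride *= 2) {
    std::vector<double> an(n), bn(n), cn(n), rn(n);
    for (int i = 0; i < n; ++i) {
      // Eliminate x[i - stride] and x[i + stride] from equation i by adding
      // scaled copies of the equations one stride below and above.
      const double alpha = -A(i) / B(i - stride);
      const double gamma = -C(i) / B(i + stride);
      an[i] = alpha * A(i - stride);
      cn[i] = gamma * C(i + stride);
      bn[i] = b[i] + alpha * C(i - stride) + gamma * A(i + stride);
      rn[i] = r[i] + alpha * R(i - stride) + gamma * R(i + stride);
    }
    a = an; b = bn; c = cn; r = rn;
  }

  // After ~log2(n) doubling steps every equation is decoupled: b[i]*x[i] = r[i].
  std::vector<double> x(n);
  for (int i = 0; i < n; ++i) {
    x[i] = r[i] / b[i];
  }
  return x;
}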
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -184,6 +184,8 @@ set(BOUT_SOURCES
./src/invert/laplace/impls/multigrid/multigrid_solver.cxx
./src/invert/laplace/impls/naulin/naulin_laplace.cxx
./src/invert/laplace/impls/naulin/naulin_laplace.hxx
./src/invert/laplace/impls/pcr/pcr.cxx
./src/invert/laplace/impls/pcr/pcr.hxx
./src/invert/laplace/impls/pdd/pdd.cxx
./src/invert/laplace/impls/pdd/pdd.hxx
./src/invert/laplace/impls/petsc/petsc_laplace.cxx
1 change: 1 addition & 0 deletions include/invert_laplace.hxx
@@ -61,6 +61,7 @@ constexpr auto LAPLACE_CYCLIC = "cyclic";
constexpr auto LAPLACE_MULTIGRID = "multigrid";
constexpr auto LAPLACE_NAULIN = "naulin";
constexpr auto LAPLACE_IPT = "ipt";
constexpr auto LAPLACE_PCR = "pcr";

// Inversion flags for each boundary
/// Zero-gradient for DC (constant in Z) component. Default is zero value
83 changes: 83 additions & 0 deletions src/invert/laplace/impls/cyclic/cyclic_laplace.cxx
@@ -442,6 +442,7 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {
// Solve tridiagonal systems
cr->setCoefs(a3D, b3D, c3D);
cr->solve(bcmplx3D, xcmplx3D);
// verify_solution(a3D,b3D,c3D,bcmplx3D,xcmplx3D,nsys);

// FFT back to real space
BOUT_OMP(parallel) {
@@ -476,3 +477,85 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {

return x;
}

void LaplaceCyclic::verify_solution(const Matrix<dcomplex>& a_ver,
const Matrix<dcomplex>& b_ver,
const Matrix<dcomplex>& c_ver,
const Matrix<dcomplex>& r_ver,
const Matrix<dcomplex>& x_sol, const int nsys) {
output.write("Verify solution\n");
const int nx = xe - xs + 1; // Number of X points on this processor,
// including boundaries but not guard cells
const int xproc = localmesh->getXProcIndex();
const int yproc = localmesh->getYProcIndex();
const int nprocs = localmesh->getNXPE();
const int myrank = yproc * nprocs + xproc;
Matrix<dcomplex> y_ver(nsys, nx + 2);
Matrix<dcomplex> error(nsys, nx + 2);

MPI_Status status;
Array<MPI_Request> request(4);
Array<dcomplex> sbufup(nsys);
Array<dcomplex> sbufdown(nsys);
Array<dcomplex> rbufup(nsys);
Array<dcomplex> rbufdown(nsys);

// nsys = nmode * ny; // Number of systems of equations to solve
Matrix<dcomplex> x_ver(nsys, nx + 2);

for (int kz = 0; kz < nsys; kz++) {
for (int ix = 0; ix < nx; ix++) {
x_ver(kz, ix + 1) = x_sol(kz, ix);
}
}

if (xproc > 0) {
MPI_Irecv(&rbufdown[0], nsys, MPI_DOUBLE_COMPLEX, myrank - 1, 901, MPI_COMM_WORLD,
&request[1]);
for (int kz = 0; kz < nsys; kz++) {
sbufdown[kz] = x_ver(kz, 1);
}
MPI_Isend(&sbufdown[0], nsys, MPI_DOUBLE_COMPLEX, myrank - 1, 900, MPI_COMM_WORLD,
&request[0]);
}
if (xproc < nprocs - 1) {
MPI_Irecv(&rbufup[0], nsys, MPI_DOUBLE_COMPLEX, myrank + 1, 900, MPI_COMM_WORLD,
&request[3]);
for (int kz = 0; kz < nsys; kz++) {
sbufup[kz] = x_ver(kz, nx);
}
MPI_Isend(&sbufup[0], nsys, MPI_DOUBLE_COMPLEX, myrank + 1, 901, MPI_COMM_WORLD,
&request[2]);
}

if (xproc > 0) {
MPI_Wait(&request[0], &status);
MPI_Wait(&request[1], &status);
for (int kz = 0; kz < nsys; kz++) {
x_ver(kz, 0) = rbufdown[kz];
}
}
if (xproc < nprocs - 1) {
MPI_Wait(&request[2], &status);
MPI_Wait(&request[3], &status);
for (int kz = 0; kz < nsys; kz++) {
x_ver(kz, nx + 1) = rbufup[kz];
}
}

BoutReal max_error = 0.0;
for (int kz = 0; kz < nsys; kz++) {
for (int i = 0; i < nx; i++) {
y_ver(kz, i) = a_ver(kz, i) * x_ver(kz, i) + b_ver(kz, i) * x_ver(kz, i + 1)
+ c_ver(kz, i) * x_ver(kz, i + 2);
error(kz, i) = y_ver(kz, i) - r_ver(kz, i);
max_error = std::max(max_error, std::abs(error(kz, i)));
output.write("abs error {}, r={}, y={}, kz {}, i {}, a={}, b={}, c={}, x-= {}, "
"x={}, x+ = {}\n",
error(kz, i).real(), r_ver(kz, i).real(), y_ver(kz, i).real(), kz, i,
a_ver(kz, i).real(), b_ver(kz, i).real(), c_ver(kz, i).real(),
x_ver(kz, i).real(), x_ver(kz, i + 1).real(), x_ver(kz, i + 2).real());
}
}
output.write("max abs error {}\n", max_error);
}
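For orientation (not part of the diff): verify_solution exchanges one guard value with each X neighbour so that x_ver holds x_{i-1}, x_i, and x_{i+1} for every local row, then recomputes y_i = a_i*x_{i-1} + b_i*x_i + c_i*x_{i+1} and reports |y_i - r_i|. Below is a minimal serial analogue of that check with illustrative names, where out-of-range neighbours are treated as zero instead of being filled by MPI exchange; it pairs naturally with the pcr_solve sketch shown earlier.

// Serial analogue of the residual check above: recompute each row of the
// tridiagonal operator applied to the solution and compare against the
// right-hand side.  Illustration only; not part of the commit.
#include <algorithm>
#include <cmath>
#include <vector>

double max_residual(const std::vector<double>& a, const std::vector<double>& b,
                    const std::vector<double>& c, const std::vector<double>& r,
                    const std::vector<double>& x) {
  const int n = static_cast<int>(b.size());
  // Out-of-range neighbours contribute zero (no ghost cells in this sketch).
  auto X = [&](int i) { return (i < 0 || i >= n) ? 0.0 : x[i]; };
  double max_error = 0.0;
  for (int i = 0; i < n; ++i) {
    const double y = a[i] * X(i - 1) + b[i] * X(i) + c[i] * X(i + 1);
    max_error = std::max(max_error, std::abs(y - r[i]));
  }
  return max_error;
}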
4 changes: 4 additions & 0 deletions src/invert/laplace/impls/cyclic/cyclic_laplace.hxx
@@ -95,6 +95,10 @@ public:

Field3D solve(const Field3D &b) override {return solve(b,b);}
Field3D solve(const Field3D &b, const Field3D &x0) override;
void verify_solution(const Matrix<dcomplex>& a_ver, const Matrix<dcomplex>& b_ver,
const Matrix<dcomplex>& c_ver, const Matrix<dcomplex>& r_ver,
const Matrix<dcomplex>& x_sol, int nsys);

private:
Field2D Acoef, C1coef, C2coef, Dcoef;

2 changes: 1 addition & 1 deletion src/invert/laplace/impls/makefile
@@ -1,6 +1,6 @@

BOUT_TOP = ../../../..

DIRS = serial_tri serial_band pdd spt petsc cyclic multigrid naulin petsc3damg iterative_parallel_tri
DIRS = serial_tri serial_band pdd spt petsc cyclic multigrid naulin petsc3damg iterative_parallel_tri pcr

include $(BOUT_TOP)/make.config
8 changes: 8 additions & 0 deletions src/invert/laplace/impls/pcr/makefile
@@ -0,0 +1,8 @@

BOUT_TOP = ../../../../..

SOURCEC = pcr.cxx
SOURCEH = pcr.hxx
TARGET = lib

include $(BOUT_TOP)/make.config