Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifications to add use_smeared_gauge to InvertParam #1522

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
6 changes: 5 additions & 1 deletion include/quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,10 @@ extern "C" {
/** The t0 parameter for distance preconditioning, the timeslice where the source is located */
int distance_pc_t0;

/** Whether to use the smeared gauge field for the Dirac operator
for whose eigenvalues are are computing. */
QudaBoolean use_smeared_gauge;

} QudaInvertParam;

// Parameter set for solving eigenvalue problems.
Expand Down Expand Up @@ -507,7 +511,7 @@ extern "C" {

/** Whether to use the smeared gauge field for the Dirac operator
for whose eigenvalues are are computing. */
bool use_smeared_gauge;
QudaBoolean use_smeared_gauge;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the change to QudaBoolean from bool? I intentionally used bool with this variable as we intend to deprecate QudaBoolean.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I think with this addition of the use_smeared_gauge parameter to QudaInvertParm the QudaEigParam variant is never used now? Can we just delete it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used QudaBoolean because all the other boolean members did...it looked weird to see a difference. But bool is fine. The use_smeared_gauge is needed in the InvertParam because only that gets sent to the setDiracParam functions, where it is needed. It is ok to keep it in the EigParam too. I think I put in a check that these two bools have to be the same.


/** What type of Dirac operator we are using **/
/** If !(use_norm_op) && !(use_dagger) use M. **/
Expand Down
3 changes: 2 additions & 1 deletion lib/check_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void printQudaEigParam(QudaEigParam *param) {
P(preserve_deflation, QUDA_BOOLEAN_FALSE);
P(preserve_deflation_space, 0);
P(preserve_evals, QUDA_BOOLEAN_TRUE);
P(use_smeared_gauge, false);
P(use_smeared_gauge, QUDA_BOOLEAN_FALSE);
P(use_dagger, QUDA_BOOLEAN_FALSE);
P(use_norm_op, QUDA_BOOLEAN_FALSE);
P(compute_svd, QUDA_BOOLEAN_FALSE);
Expand Down Expand Up @@ -373,6 +373,7 @@ void printQudaInvertParam(QudaInvertParam *param) {
P(twist_flavor, QUDA_TWIST_INVALID);
P(laplace3D, INVALID_INT);
P(covdev_mu, INVALID_INT);
P(use_smeared_gauge, QUDA_BOOLEAN_FALSE);
#else
// asqtad and domain wall use mass parameterization
if (param->dslash_type == QUDA_STAGGERED_DSLASH || param->dslash_type == QUDA_ASQTAD_DSLASH
Expand Down
46 changes: 31 additions & 15 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1436,9 +1436,10 @@ namespace quda {

void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc)
{
GaugeField *gaugePtr = (!inv_param->use_smeared_gauge) ? gaugePrecise : gaugeSmeared;
double kappa = inv_param->kappa;
if (inv_param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) {
kappa *= gaugePrecise->Anisotropy();
kappa *= gaugePtr->Anisotropy();
}

switch (inv_param->dslash_type) {
Expand Down Expand Up @@ -1528,7 +1529,7 @@ namespace quda {

diracParam.matpcType = inv_param->matpc_type;
diracParam.dagger = inv_param->dagger;
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePrecise;
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePtr;
diracParam.fatGauge = gaugeFatPrecise;
diracParam.longGauge = gaugeLongPrecise;
diracParam.clover = cloverPrecise;
Expand Down Expand Up @@ -1562,7 +1563,7 @@ namespace quda {
diracParam.commDim[i] = 1; // comms are always on
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_sloppy);
}
Expand All @@ -1580,7 +1581,7 @@ namespace quda {
diracParam.commDim[i] = 1; // comms are always on
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_refinement_sloppy);
}
Expand Down Expand Up @@ -1612,24 +1613,37 @@ namespace quda {
diracParam.gauge = gaugeFatPrecondition;
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_precondition);
}

void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc, bool use_smeared_gauge)
void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc)
{
setDiracParam(diracParam, inv_param, pc);

if (inv_param->overlap) {
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatExtended : gaugeExtended;
diracParam.fatGauge = gaugeFatExtended;
diracParam.longGauge = gaugeLongExtended;
} else if (use_smeared_gauge) {
} else if (inv_param->use_smeared_gauge) {
if (!gaugeSmeared) errorQuda("No smeared gauge field present");
if (inv_param->dslash_type == QUDA_LAPLACE_DSLASH) {
if (gaugeSmeared->GhostExchange() == QUDA_GHOST_EXCHANGE_EXTENDED) {
GaugeFieldParam gauge_param(*gaugePrecise);
GaugeFieldParam gauge_param((gaugePrecise)? *gaugePrecise : *gaugeSmeared);
if (!gaugePrecise){
for (int k=0;k<gauge_param.nDim;++k){
gauge_param.x[k]-=2*gauge_param.r[k]; gauge_param.r[k]=0;} // smearedGauge is loaded as extended, so remove extensions
#ifdef MULTI_GPU
int x_face_size = gauge_param.x[1] * gauge_param.x[2] * gauge_param.x[3] / 2;
int y_face_size = gauge_param.x[0] * gauge_param.x[2] * gauge_param.x[3] / 2;
int z_face_size = gauge_param.x[0] * gauge_param.x[1] * gauge_param.x[3] / 2;
int t_face_size = gauge_param.x[0] * gauge_param.x[1] * gauge_param.x[2] / 2;
gauge_param.pad = std::max({x_face_size, y_face_size, z_face_size, t_face_size});
#endif
//gauge_param.link_type = QUDA_WILSON_LINKS;
gauge_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;}
gauge_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
GaugeField gaugeEig(gauge_param);
copyExtendedGauge(gaugeEig, *gaugeSmeared, QUDA_CUDA_FIELD_LOCATION);
gaugeEig.exchangeGhost();
Expand All @@ -1644,6 +1658,7 @@ namespace quda {
diracParam.fatGauge = gaugeFatEigensolver;
diracParam.longGauge = gaugeLongEigensolver;
}

diracParam.clover = cloverEigensolver;

for (int i = 0; i < 4; i++) { diracParam.commDim[i] = 1; }
Expand Down Expand Up @@ -1697,8 +1712,7 @@ namespace quda {
dRef = Dirac::create(diracRefParam);
}

void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve,
bool use_smeared_gauge)
void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve)
{
DiracParam diracParam;
DiracParam diracSloppyParam;
Expand All @@ -1709,7 +1723,7 @@ namespace quda {
setDiracSloppyParam(diracSloppyParam, &param, pc_solve);
bool pre_comms_flag = (param.schwarz_type != QUDA_INVALID_SCHWARZ) ? false : true;
setDiracPreParam(diracPreParam, &param, pc_solve, pre_comms_flag);
setDiracEigParam(diracEigParam, &param, pc_solve, use_smeared_gauge);
setDiracEigParam(diracEigParam, &param, pc_solve);

d = Dirac::create(diracParam); // create the Dirac operator
dSloppy = Dirac::create(diracSloppyParam);
Expand Down Expand Up @@ -2406,6 +2420,7 @@ void checkClover(QudaInvertParam *param) {
quda::GaugeField *checkGauge(QudaInvertParam *param)
{
quda::GaugeField *U = param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise :
param->use_smeared_gauge ? gaugeSmeared :
gaugePrecise;

if (U == nullptr)
Expand All @@ -2415,7 +2430,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param)
errorQuda("Solve precision %d doesn't match gauge precision %d", param->cuda_prec, U->Precision());
}

if (param->dslash_type != QUDA_ASQTAD_DSLASH) {
if (param->dslash_type != QUDA_ASQTAD_DSLASH && !param->use_smeared_gauge) {
if (param->cuda_prec_sloppy != gaugeSloppy->Precision()
|| param->cuda_prec_precondition != gaugePrecondition->Precision()
|| param->cuda_prec_refinement_sloppy != gaugeRefinement->Precision()
Expand All @@ -2433,7 +2448,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param)
if (gaugeRefinement == nullptr) errorQuda("Refinement gauge field doesn't exist");
if (gaugeEigensolver == nullptr) errorQuda("Refinement gauge field doesn't exist");
if (param->overlap && gaugeExtended == nullptr) errorQuda("Extended gauge field doesn't exist");
} else {
} else if (!param->use_smeared_gauge) {
if (gaugeLongPrecise == nullptr) errorQuda("Precise gauge long field doesn't exist");

if (param->cuda_prec_sloppy != gaugeFatSloppy->Precision()
Expand Down Expand Up @@ -2562,6 +2577,8 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam
// Ensure that the parameter structures are sound.
checkInvertParam(inv_param);
checkEigParam(eig_param);
if (inv_param->use_smeared_gauge != eig_param->use_smeared_gauge){
errorQuda("Parameter use_smeared_gauge should be same in eig_param and *(eig_param.invert_param)\n");}

// Check that the gauge field is valid
GaugeField *cudaGauge = checkGauge(inv_param);
Expand All @@ -2585,10 +2602,9 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam

// Create the dirac operator with a sloppy and a precon.
bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) || (inv_param->solve_type == QUDA_NORMOP_PC_SOLVE);
createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve, eig_param->use_smeared_gauge);
createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve);
Dirac &dirac = *dEig;
//------------------------------------------------------

// Construct vectors
//------------------------------------------------------
// Create host wrappers around application vector set
Expand Down
6 changes: 2 additions & 4 deletions lib/solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ namespace quda
getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE);
}

void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve,
bool use_smeared_gauge);
void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve);

extern std::vector<ColorSpinorField> solutionResident;

Expand Down Expand Up @@ -349,8 +348,7 @@ namespace quda

// Create the dirac operator and operators for sloppy, precondition,
// and an eigensolver
createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve,
param.eig_param ? static_cast<QudaEigParam *>(param.eig_param)->use_smeared_gauge : false);
createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve);

// wrap CPU host side pointers
ColorSpinorParam cpuParam(hp_b[0], param, u.X(), pc_solution, param.input_location);
Expand Down
3 changes: 2 additions & 1 deletion tests/staggered_eigensolve_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ std::vector<double> eigensolve(test_t test_param)
eig_inv_param.solution_type = eig_param.use_pc ? QUDA_MATPC_SOLUTION : QUDA_MAT_SOLUTION;

// whether we are using the resident smeared gauge or not
eig_param.use_smeared_gauge = gauge_smear;
eig_param.use_smeared_gauge = (gauge_smear ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE);
eig_param.invert_param->use_smeared_gauge = (gauge_smear ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE);

if (dslash_type == QUDA_LAPLACE_DSLASH) {
int dimension = laplace3D < 4 ? 3 : 4;
Expand Down
Loading