diff --git a/include/quda.h b/include/quda.h index da78a50a3b..e8242a2d0d 100644 --- a/include/quda.h +++ b/include/quda.h @@ -464,6 +464,10 @@ extern "C" { /** The t0 parameter for distance preconditioning, the timeslice where the source is located */ int distance_pc_t0; + /** Whether to use the smeared gauge field for the Dirac operator, usually + when defined as a spatial Laplacian: mainly used in computing Laplacian eigenvectors */ + QudaBoolean use_smeared_gauge; + } QudaInvertParam; // Parameter set for solving eigenvalue problems. @@ -505,10 +509,6 @@ extern "C" { false, but preserve_deflation would be true */ QudaBoolean preserve_evals; - /** Whether to use the smeared gauge field for the Dirac operator - for whose eigenvalues are are computing. */ - bool use_smeared_gauge; - /** What type of Dirac operator we are using **/ /** If !(use_norm_op) && !(use_dagger) use M. **/ /** If use_dagger, use Mdag **/ diff --git a/lib/check_params.h b/lib/check_params.h index cdbe36169b..19518ee24e 100644 --- a/lib/check_params.h +++ b/lib/check_params.h @@ -175,7 +175,6 @@ void printQudaEigParam(QudaEigParam *param) { P(preserve_deflation, QUDA_BOOLEAN_FALSE); P(preserve_deflation_space, 0); P(preserve_evals, QUDA_BOOLEAN_TRUE); - P(use_smeared_gauge, false); P(use_dagger, QUDA_BOOLEAN_FALSE); P(use_norm_op, QUDA_BOOLEAN_FALSE); P(compute_svd, QUDA_BOOLEAN_FALSE); @@ -373,6 +372,7 @@ void printQudaInvertParam(QudaInvertParam *param) { P(twist_flavor, QUDA_TWIST_INVALID); P(laplace3D, INVALID_INT); P(covdev_mu, INVALID_INT); + P(use_smeared_gauge, QUDA_BOOLEAN_FALSE); #else // asqtad and domain wall use mass parameterization if (param->dslash_type == QUDA_STAGGERED_DSLASH || param->dslash_type == QUDA_ASQTAD_DSLASH diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 6fd6382488..386e65c4d7 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -1436,9 +1436,10 @@ namespace quda { void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc) { + GaugeField *gaugePtr = (!inv_param->use_smeared_gauge) ? gaugePrecise : gaugeSmeared; double kappa = inv_param->kappa; if (inv_param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) { - kappa *= gaugePrecise->Anisotropy(); + kappa *= gaugePtr->Anisotropy(); } switch (inv_param->dslash_type) { @@ -1528,7 +1529,7 @@ namespace quda { diracParam.matpcType = inv_param->matpc_type; diracParam.dagger = inv_param->dagger; - diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePrecise; + diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePtr; diracParam.fatGauge = gaugeFatPrecise; diracParam.longGauge = gaugeLongPrecise; diracParam.clover = cloverPrecise; @@ -1562,7 +1563,7 @@ namespace quda { diracParam.commDim[i] = 1; // comms are always on } - if (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy) + if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy)) errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(), inv_param->cuda_prec_sloppy); } @@ -1580,7 +1581,7 @@ namespace quda { diracParam.commDim[i] = 1; // comms are always on } - if (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy) + if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy)) errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(), inv_param->cuda_prec_refinement_sloppy); } @@ -1612,12 +1613,12 @@ namespace quda { diracParam.gauge = gaugeFatPrecondition; } - if (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition) + if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition)) errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(), inv_param->cuda_prec_precondition); } - void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc, bool use_smeared_gauge) + void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc) { setDiracParam(diracParam, inv_param, pc); @@ -1625,11 +1626,24 @@ namespace quda { diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatExtended : gaugeExtended; diracParam.fatGauge = gaugeFatExtended; diracParam.longGauge = gaugeLongExtended; - } else if (use_smeared_gauge) { + } else if (inv_param->use_smeared_gauge) { if (!gaugeSmeared) errorQuda("No smeared gauge field present"); if (inv_param->dslash_type == QUDA_LAPLACE_DSLASH) { if (gaugeSmeared->GhostExchange() == QUDA_GHOST_EXCHANGE_EXTENDED) { - GaugeFieldParam gauge_param(*gaugePrecise); + GaugeFieldParam gauge_param((gaugePrecise)? *gaugePrecise : *gaugeSmeared); + if (!gaugePrecise){ + for (int k=0;kdslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : + param->use_smeared_gauge ? gaugeSmeared : gaugePrecise; if (U == nullptr) @@ -2415,7 +2430,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param) errorQuda("Solve precision %d doesn't match gauge precision %d", param->cuda_prec, U->Precision()); } - if (param->dslash_type != QUDA_ASQTAD_DSLASH) { + if (param->dslash_type != QUDA_ASQTAD_DSLASH && !param->use_smeared_gauge) { if (param->cuda_prec_sloppy != gaugeSloppy->Precision() || param->cuda_prec_precondition != gaugePrecondition->Precision() || param->cuda_prec_refinement_sloppy != gaugeRefinement->Precision() @@ -2433,7 +2448,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param) if (gaugeRefinement == nullptr) errorQuda("Refinement gauge field doesn't exist"); if (gaugeEigensolver == nullptr) errorQuda("Refinement gauge field doesn't exist"); if (param->overlap && gaugeExtended == nullptr) errorQuda("Extended gauge field doesn't exist"); - } else { + } else if (!param->use_smeared_gauge) { if (gaugeLongPrecise == nullptr) errorQuda("Precise gauge long field doesn't exist"); if (param->cuda_prec_sloppy != gaugeFatSloppy->Precision() @@ -2585,10 +2600,9 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam // Create the dirac operator with a sloppy and a precon. bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) || (inv_param->solve_type == QUDA_NORMOP_PC_SOLVE); - createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve, eig_param->use_smeared_gauge); + createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve); Dirac &dirac = *dEig; //------------------------------------------------------ - // Construct vectors //------------------------------------------------------ // Create host wrappers around application vector set diff --git a/lib/solve.cpp b/lib/solve.cpp index 79fa868633..e999c3f28c 100644 --- a/lib/solve.cpp +++ b/lib/solve.cpp @@ -317,8 +317,7 @@ namespace quda getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE); } - void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam ¶m, bool pc_solve, - bool use_smeared_gauge); + void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam ¶m, bool pc_solve); extern std::vector solutionResident; @@ -349,8 +348,7 @@ namespace quda // Create the dirac operator and operators for sloppy, precondition, // and an eigensolver - createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve, - param.eig_param ? static_cast(param.eig_param)->use_smeared_gauge : false); + createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve); // wrap CPU host side pointers ColorSpinorParam cpuParam(hp_b[0], param, u.X(), pc_solution, param.input_location); diff --git a/tests/staggered_eigensolve_test.cpp b/tests/staggered_eigensolve_test.cpp index 95ebff044b..8464eb5ef6 100644 --- a/tests/staggered_eigensolve_test.cpp +++ b/tests/staggered_eigensolve_test.cpp @@ -163,7 +163,7 @@ std::vector eigensolve(test_t test_param) eig_inv_param.solution_type = eig_param.use_pc ? QUDA_MATPC_SOLUTION : QUDA_MAT_SOLUTION; // whether we are using the resident smeared gauge or not - eig_param.use_smeared_gauge = gauge_smear; + eig_param.invert_param->use_smeared_gauge = (gauge_smear ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE); if (dslash_type == QUDA_LAPLACE_DSLASH) { int dimension = laplace3D < 4 ? 3 : 4;