Skip to content

Commit

Permalink
Simplify XZInterpolation class
Browse files Browse the repository at this point in the history
Reduce duplication by moving some overloads to the base class.
  • Loading branch information
bendudson committed Feb 3, 2024
1 parent cbfd962 commit 6ee9f74
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 105 deletions.
77 changes: 32 additions & 45 deletions include/bout/interpolation_xz.hxx
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
/**************************************************************************
* Copyright 2010-2020 B.D.Dudson, S.Farley, P. Hill, J.T. Omotani, J.T. Parker,
* M.V.Umansky, X.Q.Xu
* Copyright 2010-2024 BOUT++ contributors
*
* Contact: Ben Dudson, [email protected]
* Contact: Ben Dudson, [email protected]
*
* This file is part of BOUT++.
*
Expand All @@ -21,8 +20,8 @@
*
**************************************************************************/

#ifndef __INTERP_XZ_H__
#define __INTERP_XZ_H__
#ifndef INTERP_XZ_H
#define INTERP_XZ_H

#include "bout/mask.hxx"

Expand Down Expand Up @@ -95,20 +94,30 @@ public:
ASSERT1(has_region);
return getRegion();
}
/// Calculate weights using given offsets in X and Z
virtual void calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") = 0;
virtual void calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") = 0;
void calcWeights(const Field3D& delta_x, const Field3D& delta_z, const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") {
setMask(mask);
calcWeights(delta_x, delta_z, region);
}

/// Use pre-calculated weights
virtual Field3D interpolate(const Field3D& f,
const std::string& region = "RGN_NOBNDRY") const = 0;
virtual Field3D interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") = 0;
virtual Field3D interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z, const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") = 0;

/// Calculate weights then interpolate
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") {
calcWeights(delta_x, delta_z, region);
return interpolate(f, region);
}
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask, const std::string& region = "RGN_NOBNDRY") {
calcWeights(delta_x, delta_z, mask, region);
return interpolate(f, region);
}

// Interpolate using the field at (x,y+y_offset,z), rather than (x,y,z)
void setYOffset(int offset) { y_offset = offset; }
Expand Down Expand Up @@ -162,18 +171,10 @@ public:

void calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") override;
void calcWeights(const Field3D& delta_x, const Field3D& delta_z, const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") override;

// Use precalculated weights
Field3D interpolate(const Field3D& f,
const std::string& region = "RGN_NOBNDRY") const override;
// Calculate weights and interpolate
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") override;
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") override;

std::vector<ParallelTransform::PositionsAndWeights>
getWeightsForYApproximation(int i, int j, int k, int yoffset) override;
};
Expand Down Expand Up @@ -208,6 +209,10 @@ class XZLagrange4pt : public XZInterpolation {

Field3D t_x, t_z;

BoutReal lagrange_4pt(BoutReal v2m, BoutReal vm, BoutReal vp, BoutReal v2p,
BoutReal offset) const;
BoutReal lagrange_4pt(const BoutReal v[], BoutReal offset) const;

public:
XZLagrange4pt(Mesh* mesh = nullptr) : XZLagrange4pt(0, mesh) {}
XZLagrange4pt(int y_offset = 0, Mesh* mesh = nullptr);
Expand All @@ -218,21 +223,10 @@ public:

void calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") override;
void calcWeights(const Field3D& delta_x, const Field3D& delta_z, const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") override;

// Use precalculated weights
Field3D interpolate(const Field3D& f,
const std::string& region = "RGN_NOBNDRY") const override;
// Calculate weights and interpolate
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") override;
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") override;
BoutReal lagrange_4pt(BoutReal v2m, BoutReal vm, BoutReal vp, BoutReal v2p,
BoutReal offset) const;
BoutReal lagrange_4pt(const BoutReal v[], BoutReal offset) const;
};

class XZBilinear : public XZInterpolation {
Expand All @@ -251,18 +245,10 @@ public:

void calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") override;
void calcWeights(const Field3D& delta_x, const Field3D& delta_z, const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") override;

// Use precalculated weights
Field3D interpolate(const Field3D& f,
const std::string& region = "RGN_NOBNDRY") const override;
// Calculate weights and interpolate
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const std::string& region = "RGN_NOBNDRY") override;
Field3D interpolate(const Field3D& f, const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask,
const std::string& region = "RGN_NOBNDRY") override;
};

class XZInterpolationFactory
Expand All @@ -279,11 +265,12 @@ public:
ReturnType create(const std::string& type, [[maybe_unused]] Options* options) const {
return Factory::create(type, nullptr);
}

static void ensureRegistered();
};

template <class DerivedType>
using RegisterXZInterpolation = XZInterpolationFactory::RegisterInFactory<DerivedType>;

#endif // __INTERP_XZ_H__
using RegisterUnavailableXZInterpolation =
XZInterpolationFactory::RegisterUnavailableInFactory;

#endif // INTERP_XZ_H
19 changes: 0 additions & 19 deletions src/mesh/interpolation/bilinear_xz.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ void XZBilinear::calcWeights(const Field3D& delta_x, const Field3D& delta_z,
}
}

void XZBilinear::calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask, const std::string& region) {
setMask(mask);
calcWeights(delta_x, delta_z, region);
}

Field3D XZBilinear::interpolate(const Field3D& f, const std::string& region) const {
ASSERT1(f.getMesh() == localmesh);
Field3D f_interp{emptyFrom(f)};
Expand All @@ -113,16 +107,3 @@ Field3D XZBilinear::interpolate(const Field3D& f, const std::string& region) con
}
return f_interp;
}

Field3D XZBilinear::interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z, const std::string& region) {
calcWeights(delta_x, delta_z, region);
return interpolate(f, region);
}

Field3D XZBilinear::interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z, const BoutMask& mask,
const std::string& region) {
calcWeights(delta_x, delta_z, mask, region);
return interpolate(f, region);
}
19 changes: 0 additions & 19 deletions src/mesh/interpolation/hermite_spline_xz.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,6 @@ void XZHermiteSpline::calcWeights(const Field3D& delta_x, const Field3D& delta_z
}
}

void XZHermiteSpline::calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask, const std::string& region) {
setMask(mask);
calcWeights(delta_x, delta_z, region);
}

/*!
* Return position and weight of points needed to approximate the function value at the
* point that the field line through (i,j,k) meets the (j+1)-plane. For the case where
Expand Down Expand Up @@ -223,16 +217,3 @@ Field3D XZHermiteSpline::interpolate(const Field3D& f, const std::string& region
}
return f_interp;
}

Field3D XZHermiteSpline::interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z, const std::string& region) {
calcWeights(delta_x, delta_z, region);
return interpolate(f, region);
}

Field3D XZHermiteSpline::interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z, const BoutMask& mask,
const std::string& region) {
calcWeights(delta_x, delta_z, mask, region);
return interpolate(f, region);
}
19 changes: 0 additions & 19 deletions src/mesh/interpolation/lagrange_4pt_xz.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,6 @@ void XZLagrange4pt::calcWeights(const Field3D& delta_x, const Field3D& delta_z,
}
}

void XZLagrange4pt::calcWeights(const Field3D& delta_x, const Field3D& delta_z,
const BoutMask& mask, const std::string& region) {
setMask(mask);
calcWeights(delta_x, delta_z, region);
}

Field3D XZLagrange4pt::interpolate(const Field3D& f, const std::string& region) const {

ASSERT1(f.getMesh() == localmesh);
Expand Down Expand Up @@ -132,19 +126,6 @@ Field3D XZLagrange4pt::interpolate(const Field3D& f, const std::string& region)
return f_interp;
}

Field3D XZLagrange4pt::interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z, const std::string& region) {
calcWeights(delta_x, delta_z, region);
return interpolate(f, region);
}

Field3D XZLagrange4pt::interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z, const BoutMask& mask,
const std::string& region) {
calcWeights(delta_x, delta_z, mask, region);
return interpolate(f, region);
}

// 4-point Lagrangian interpolation
// offset must be between 0 and 1
BoutReal XZLagrange4pt::lagrange_4pt(const BoutReal v2m, const BoutReal vm,
Expand Down
6 changes: 3 additions & 3 deletions src/mesh/interpolation_xz.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ const Field3D interpolate(const Field3D& f, const Field3D& delta_x,
const Field3D& delta_z) {
TRACE("Interpolating 3D field");
XZLagrange4pt interpolateMethod{f.getMesh()};
return interpolateMethod.interpolate(f, delta_x, delta_z);
// Cast to base pointer so virtual function overload is resolved
return static_cast<XZInterpolation*>(&interpolateMethod)
->interpolate(f, delta_x, delta_z);
}

const Field3D interpolate(const Field2D& f, const Field3D& delta_x,
Expand Down Expand Up @@ -84,8 +86,6 @@ const Field3D interpolate(const Field2D& f, const Field3D& delta_x) {
return result;
}

void XZInterpolationFactory::ensureRegistered() {}

namespace {
RegisterXZInterpolation<XZHermiteSpline> registerinterphermitespline{"hermitespline"};
RegisterXZInterpolation<XZMonotonicHermiteSpline> registerinterpmonotonichermitespline{
Expand Down

0 comments on commit 6ee9f74

Please sign in to comment.