Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup BOUT_HOST_DEVICE qualifiers #3040

Open
wants to merge 2 commits into
base: next
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions include/bout/field2d.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,12 @@ public:
return std::end(getRegion("RGN_ALL"));
};

BoutReal& BOUT_HOST_DEVICE operator[](const Ind2D& d) { return data[d.ind]; }
const BoutReal& BOUT_HOST_DEVICE operator[](const Ind2D& d) const {
return data[d.ind];
}
BoutReal& BOUT_HOST_DEVICE operator[](const Ind3D& d);
BoutReal& operator[](const Ind2D& d) { return data[d.ind]; }
const BoutReal& operator[](const Ind2D& d) const { return data[d.ind]; }
BoutReal& operator[](const Ind3D& d);
// const BoutReal& operator[](const Ind3D &d) const;

const BoutReal& BOUT_HOST_DEVICE operator[](const Ind3D& d) const;
const BoutReal& operator[](const Ind3D& d) const;
/*!
* Access to the underlying data array.
*
Expand All @@ -199,7 +197,7 @@ public:
* If CHECK > 2 then both \p jx and \p jy are bounds checked. This will
* significantly reduce performance.
*/
BOUT_HOST_DEVICE inline BoutReal& operator()(int jx, int jy) {
inline BoutReal& operator()(int jx, int jy) {
#if CHECK > 2 && !BOUT_HAS_CUDA
if (!isAllocated()) {
throw BoutException("Field2D: () operator on empty data");
Expand All @@ -213,7 +211,7 @@ public:

return data[jx * ny + jy];
}
BOUT_HOST_DEVICE inline const BoutReal& operator()(int jx, int jy) const {
inline const BoutReal& operator()(int jx, int jy) const {
#if CHECK > 2 && !BOUT_HAS_CUDA
if (!isAllocated()) {
throw BoutException("Field2D: () operator on empty data");
Expand All @@ -232,10 +230,8 @@ public:
* DIrect access to underlying array. This version is for compatibility
* with Field3D objects
*/
BOUT_HOST_DEVICE BoutReal& operator()(int jx, int jy, int UNUSED(jz)) {
return operator()(jx, jy);
}
BOUT_HOST_DEVICE const BoutReal& operator()(int jx, int jy, int UNUSED(jz)) const {
BoutReal& operator()(int jx, int jy, int UNUSED(jz)) { return operator()(jx, jy); }
const BoutReal& operator()(int jx, int jy, int UNUSED(jz)) const {
return operator()(jx, jy);
}

Expand Down Expand Up @@ -357,7 +353,7 @@ inline Field2D DC(const Field2D& f) { return f; }
/// Returns a reference to the time-derivative of a field \p f
///
/// Wrapper around member function f.timeDeriv()
BOUT_HOST_DEVICE inline Field2D& ddt(Field2D& f) { return *(f.timeDeriv()); }
inline Field2D& ddt(Field2D& f) { return *(f.timeDeriv()); }

/// toString template specialisation
/// Defined in utils.hxx
Expand Down
18 changes: 8 additions & 10 deletions include/bout/field3d.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ public:
* The first time this is called, a new field will be
* allocated. Subsequent calls return the same field
*/
BOUT_HOST_DEVICE Field3D* timeDeriv();
Field3D* timeDeriv();

/*!
* Return the number of nx points
Expand Down Expand Up @@ -330,16 +330,14 @@ public:
return std::end(getRegion("RGN_ALL"));
};

BoutReal& BOUT_HOST_DEVICE operator[](const Ind3D& d) { return data[d.ind]; }
const BoutReal& BOUT_HOST_DEVICE operator[](const Ind3D& d) const {
return data[d.ind];
}
BoutReal& operator[](const Ind3D& d) { return data[d.ind]; }
const BoutReal& operator[](const Ind3D& d) const { return data[d.ind]; }

BoutReal& BOUT_HOST_DEVICE operator()(const IndPerp& d, int jy);
const BoutReal& BOUT_HOST_DEVICE operator()(const IndPerp& d, int jy) const;
BoutReal& operator()(const IndPerp& d, int jy);
const BoutReal& operator()(const IndPerp& d, int jy) const;

BoutReal& BOUT_HOST_DEVICE operator()(const Ind2D& d, int jz);
const BoutReal& BOUT_HOST_DEVICE operator()(const Ind2D& d, int jz) const;
BoutReal& operator()(const Ind2D& d, int jz);
const BoutReal& operator()(const Ind2D& d, int jz) const;

/*!
* Direct access to the underlying data array
Expand Down Expand Up @@ -636,7 +634,7 @@ inline void invalidateGuards(Field3D& UNUSED(var)) {}
/// Returns a reference to the time-derivative of a field \p f
///
/// Wrapper around member function f.timeDeriv()
BOUT_HOST_DEVICE inline Field3D& ddt(Field3D& f) { return *(f.timeDeriv()); }
inline Field3D& ddt(Field3D& f) { return *(f.timeDeriv()); }

/// toString template specialisation
/// Defined in utils.hxx
Expand Down
2 changes: 1 addition & 1 deletion include/bout/mesh.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ public:
}

/// Converts an Ind3D to an Ind2D representing a 2D index using a lookup -- to be used with care
BOUT_HOST_DEVICE Ind2D map3Dto2D(const Ind3D& ind3D) {
Ind2D map3Dto2D(const Ind3D& ind3D) {
return {indexLookup3Dto2D[ind3D.ind], LocalNy, 1};
}

Expand Down
7 changes: 6 additions & 1 deletion include/bout/utils.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,12 @@ inline BoutReal randomu() {
* i.e. t * t
*/
template <typename T>
BOUT_HOST_DEVICE inline T SQ(const T& t) {
inline T SQ(const T& t) {
return t * t;
}

template <>
BOUT_HOST_DEVICE inline BoutReal SQ(const BoutReal& t) {
return t * t;
}

Expand Down
6 changes: 3 additions & 3 deletions src/field/field2d.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ Field2D& Field2D::allocate() {
return *this;
}

BOUT_HOST_DEVICE Field2D* Field2D::timeDeriv() {
Field2D* Field2D::timeDeriv() {
if (deriv == nullptr) {
deriv = new Field2D{emptyFrom(*this)};
}
Expand All @@ -129,11 +129,11 @@ const Region<Ind2D>& Field2D::getRegion(const std::string& region_name) const {
}

// Not in header because we need to access fieldmesh
BOUT_HOST_DEVICE BoutReal& Field2D::operator[](const Ind3D& d) {
BoutReal& Field2D::operator[](const Ind3D& d) {
return operator[](fieldmesh->map3Dto2D(d));
}

BOUT_HOST_DEVICE const BoutReal& Field2D::operator[](const Ind3D& d) const {
const BoutReal& Field2D::operator[](const Ind3D& d) const {
return operator[](fieldmesh->map3Dto2D(d));
}

Expand Down
2 changes: 1 addition & 1 deletion src/field/field3d.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Field3D& Field3D::allocate() {
return *this;
}

BOUT_HOST_DEVICE Field3D* Field3D::timeDeriv() {
Field3D* Field3D::timeDeriv() {
if (deriv == nullptr) {
deriv = new Field3D{emptyFrom(*this)};
}
Expand Down
4 changes: 2 additions & 2 deletions src/field/field_data.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ CELL_LOC FieldData::getLocation() const {
return location;
}

BOUT_HOST_DEVICE Coordinates* FieldData::getCoordinates() const {
Coordinates* FieldData::getCoordinates() const {
auto fieldCoordinates_shared = fieldCoordinates.lock();
if (fieldCoordinates_shared) {
return fieldCoordinates_shared.get();
Expand All @@ -242,7 +242,7 @@ BOUT_HOST_DEVICE Coordinates* FieldData::getCoordinates() const {
return fieldCoordinates.lock().get();
}

BOUT_HOST_DEVICE Coordinates* FieldData::getCoordinates(CELL_LOC loc) const {
Coordinates* FieldData::getCoordinates(CELL_LOC loc) const {
if (loc == CELL_DEFAULT) {
return getCoordinates();
}
Expand Down