-
Notifications
You must be signed in to change notification settings - Fork 369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
EAMxx: add horizontal average diagnostic field #6788
Open
mahf708
wants to merge
3
commits into
master
Choose a base branch
from
mahf708/eamxx/horiz-avg-diag
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
#include "diagnostics/horiz_avg.hpp" | ||
|
||
#include "share/field/field_utils.hpp" | ||
|
||
namespace scream { | ||
|
||
HorizAvgDiag::HorizAvgDiag(const ekat::Comm &comm, | ||
const ekat::ParameterList ¶ms) | ||
: AtmosphereDiagnostic(comm, params) { | ||
const auto &fname = m_params.get<std::string>("field_name"); | ||
m_diag_name = fname + "_horiz_avg"; | ||
} | ||
|
||
void HorizAvgDiag::set_grids( | ||
const std::shared_ptr<const GridsManager> grids_manager) { | ||
const auto &fn = m_params.get<std::string>("field_name"); | ||
const auto &gn = m_params.get<std::string>("grid_name"); | ||
const auto g = grids_manager->get_grid("Physics"); | ||
|
||
add_field<Required>(fn, gn); | ||
|
||
// first clone the area unscaled, we will scale it later in initialize_impl | ||
m_scaled_area = g->get_geometry_data("area").clone(); | ||
} | ||
|
||
void HorizAvgDiag::initialize_impl(const RunType /*run_type*/) { | ||
using namespace ShortFieldTagsNames; | ||
const auto &f = get_fields_in().front(); | ||
const auto &fid = f.get_header().get_identifier(); | ||
const auto &layout = fid.get_layout(); | ||
|
||
EKAT_REQUIRE_MSG(layout.rank() >= 1 && layout.rank() <= 3, | ||
"Error! Field rank not supported by HorizAvgDiag.\n" | ||
" - field name: " + | ||
fid.name() + | ||
"\n" | ||
" - field layout: " + | ||
layout.to_string() + "\n"); | ||
EKAT_REQUIRE_MSG(layout.tags()[0] == COL, | ||
"Error! HorizAvgDiag diagnostic expects a layout starting " | ||
"with the 'COL' tag.\n" | ||
" - field name : " + | ||
fid.name() + | ||
"\n" | ||
" - field layout: " + | ||
layout.to_string() + "\n"); | ||
|
||
FieldIdentifier d_fid(m_diag_name, layout.clone().strip_dim(COL), | ||
fid.get_units(), fid.get_grid_name()); | ||
m_diagnostic_output = Field(d_fid); | ||
m_diagnostic_output.allocate_view(); | ||
|
||
// scale the area field | ||
auto total_area = field_sum<Real>(m_scaled_area, &m_comm); | ||
m_scaled_area.scale(sp(1.0) / total_area); | ||
} | ||
|
||
void HorizAvgDiag::compute_diagnostic_impl() { | ||
const auto &f = get_fields_in().front(); | ||
const auto &d = m_diagnostic_output; | ||
// Call the horiz_contraction impl that will take care of everything | ||
horiz_contraction<Real>(d, f, m_scaled_area, &m_comm); | ||
} | ||
|
||
} // namespace scream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#ifndef EAMXX_HORIZ_AVERAGE_HPP | ||
#define EAMXX_HORIZ_AVERAGE_HPP | ||
|
||
#include "share/atm_process/atmosphere_diagnostic.hpp" | ||
|
||
namespace scream { | ||
|
||
/* | ||
* This diagnostic will calculate the area-weighted average of a field | ||
* across the COL tag dimension, producing an N-1 dimensional field | ||
* that is area-weighted average of the input field. | ||
*/ | ||
|
||
class HorizAvgDiag : public AtmosphereDiagnostic { | ||
public: | ||
// Constructors | ||
HorizAvgDiag(const ekat::Comm &comm, const ekat::ParameterList ¶ms); | ||
|
||
// The name of the diagnostic | ||
std::string name() const { return m_diag_name; } | ||
|
||
// Set the grid | ||
void set_grids(const std::shared_ptr<const GridsManager> grids_manager); | ||
|
||
protected: | ||
#ifdef KOKKOS_ENABLE_CUDA | ||
public: | ||
#endif | ||
void compute_diagnostic_impl(); | ||
|
||
protected: | ||
void initialize_impl(const RunType /*run_type*/); | ||
|
||
// Name of each field (because the diagnostic impl is generic) | ||
std::string m_diag_name; | ||
|
||
// Need area field, let's store it scaled by its norm | ||
Field m_scaled_area; | ||
}; | ||
|
||
} // namespace scream | ||
|
||
#endif // EAMXX_HORIZ_AVERAGE_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
170 changes: 170 additions & 0 deletions
170
components/eamxx/src/diagnostics/tests/horiz_avg_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
#include "catch2/catch.hpp" | ||
#include "diagnostics/register_diagnostics.hpp" | ||
#include "share/field/field_utils.hpp" | ||
#include "share/grid/mesh_free_grids_manager.hpp" | ||
#include "share/util/scream_setup_random_test.hpp" | ||
#include "share/util/scream_universal_constants.hpp" | ||
|
||
namespace scream { | ||
|
||
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols, | ||
const int nlevs) { | ||
const int num_global_cols = ncols * comm.size(); | ||
|
||
using vos_t = std::vector<std::string>; | ||
ekat::ParameterList gm_params; | ||
gm_params.set("grids_names", vos_t{"Point Grid"}); | ||
auto &pl = gm_params.sublist("Point Grid"); | ||
pl.set<std::string>("type", "point_grid"); | ||
pl.set("aliases", vos_t{"Physics"}); | ||
pl.set<int>("number_of_global_columns", num_global_cols); | ||
pl.set<int>("number_of_vertical_levels", nlevs); | ||
|
||
auto gm = create_mesh_free_grids_manager(comm, gm_params); | ||
gm->build_grids(); | ||
|
||
return gm; | ||
} | ||
|
||
TEST_CASE("horiz_avg") { | ||
using namespace ShortFieldTagsNames; | ||
using namespace ekat::units; | ||
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>; | ||
using TeamMember = typename TeamPolicy::member_type; | ||
using KT = ekat::KokkosTypes<DefaultDevice>; | ||
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>; | ||
|
||
// A numerical tolerance | ||
auto tol = std::numeric_limits<Real>::epsilon() * 100; | ||
|
||
// A world comm | ||
ekat::Comm comm(MPI_COMM_WORLD); | ||
|
||
// A time stamp | ||
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0}); | ||
|
||
// Create a grids manager - single column for these tests | ||
constexpr int nlevs = 3; | ||
constexpr int dim3 = 4; | ||
const int ngcols = 6 * comm.size(); | ||
|
||
auto gm = create_gm(comm, ngcols, nlevs); | ||
auto grid = gm->get_grid("Physics"); | ||
|
||
// Input (randomized) qc | ||
FieldLayout scalar1d_layout{{COL}, {ngcols}}; | ||
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}}; | ||
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}}; | ||
|
||
FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid->name()); | ||
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name()); | ||
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid->name()); | ||
|
||
Field qc1(qc1_fid); | ||
Field qc2(qc2_fid); | ||
Field qc3(qc3_fid); | ||
|
||
qc1.allocate_view(); | ||
qc2.allocate_view(); | ||
qc3.allocate_view(); | ||
|
||
// Construct random number generator stuff | ||
using RPDF = std::uniform_real_distribution<Real>; | ||
RPDF pdf(sp(0.0), sp(200.0)); | ||
auto engine = scream::setup_random_test(); | ||
|
||
// Construct the Diagnostics | ||
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags; | ||
auto &diag_factory = AtmosphereDiagnosticFactory::instance(); | ||
register_diagnostics(); | ||
|
||
ekat::ParameterList params; | ||
REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm, | ||
params)); // No 'field_name' parameter | ||
|
||
// Set time for qc and randomize its values | ||
qc1.get_header().get_tracking().update_time_stamp(t0); | ||
qc2.get_header().get_tracking().update_time_stamp(t0); | ||
qc3.get_header().get_tracking().update_time_stamp(t0); | ||
randomize(qc1, engine, pdf); | ||
randomize(qc2, engine, pdf); | ||
randomize(qc3, engine, pdf); | ||
|
||
// Create and set up the diagnostic | ||
params.set("grid_name", grid->name()); | ||
params.set<std::string>("field_name", "qc"); | ||
auto diag1 = diag_factory.create("HorizAvgDiag", comm, params); | ||
auto diag2 = diag_factory.create("HorizAvgDiag", comm, params); | ||
auto diag3 = diag_factory.create("HorizAvgDiag", comm, params); | ||
diag1->set_grids(gm); | ||
diag2->set_grids(gm); | ||
diag3->set_grids(gm); | ||
|
||
// Clone the area field | ||
auto area = grid->get_geometry_data("area").clone(); | ||
|
||
// Test the horiz contraction of qc1 | ||
// Get the diagnostic field | ||
diag1->set_required_field(qc1); | ||
diag1->initialize(t0, RunType::Initial); | ||
diag1->compute_diagnostic(); | ||
auto diag1_f = diag1->get_diagnostic(); | ||
|
||
// Manual calculation | ||
FieldIdentifier diag0_fid("qc_horiz_avg_manual", | ||
scalar1d_layout.clone().strip_dim(COL), kg / kg, | ||
grid->name()); | ||
Field diag0(diag0_fid); | ||
diag0.allocate_view(); | ||
|
||
// calculate total area | ||
Real atot = field_sum<Real>(area, &comm); | ||
// scale the area field | ||
area.scale(1 / atot); | ||
|
||
// calculate weighted avg | ||
horiz_contraction<Real>(diag0, qc1, area, &comm); | ||
// Compare | ||
REQUIRE(views_are_equal(diag1_f, diag0)); | ||
|
||
// Try other known cases | ||
// Set qc1_v to 1.0 to get weighted average of 1.0 | ||
Real wavg = 1; | ||
qc1.deep_copy(wavg); | ||
diag1->compute_diagnostic(); | ||
auto diag1_v2_host = diag1_f.get_view<Real, Host>(); | ||
REQUIRE_THAT(diag1_v2_host(), | ||
Catch::Matchers::WithinRel( | ||
wavg, tol)); // Catch2's floating point comparison | ||
|
||
// other diags | ||
// Set qc2_v to 5.0 to get weighted average of 5.0 | ||
wavg = sp(5.0); | ||
qc2.deep_copy(wavg); | ||
diag2->set_required_field(qc2); | ||
diag2->initialize(t0, RunType::Initial); | ||
diag2->compute_diagnostic(); | ||
auto diag2_f = diag2->get_diagnostic(); | ||
|
||
auto diag2_v_host = diag2_f.get_view<Real *, Host>(); | ||
|
||
for(int i = 0; i < nlevs; ++i) { | ||
REQUIRE_THAT(diag2_v_host(i), Catch::Matchers::WithinRel(wavg, tol)); | ||
} | ||
|
||
// Try a random case with qc3 | ||
auto qc3_v = qc3.get_view<Real ***>(); | ||
FieldIdentifier diag3_manual_fid("qc_horiz_avg_manual", | ||
scalar3d_layout.clone().strip_dim(COL), | ||
kg / kg, grid->name()); | ||
Field diag3_manual(diag3_manual_fid); | ||
diag3_manual.allocate_view(); | ||
horiz_contraction<Real>(diag3_manual, qc3, area, &comm); | ||
diag3->set_required_field(qc3); | ||
diag3->initialize(t0, RunType::Initial); | ||
diag3->compute_diagnostic(); | ||
auto diag3_f = diag3->get_diagnostic(); | ||
REQUIRE(views_are_equal(diag3_f, diag3_manual)); | ||
} | ||
|
||
} // namespace scream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry to be annoying here, but I wanted all three checks to be as close as possible. I will edit these entirely in a future PR, when I address the composability