Skip to content

Commit

Permalink
EAMxx: add horizontal average diagnostic field
Browse files Browse the repository at this point in the history
  • Loading branch information
mahf708 committed Dec 9, 2024
1 parent b13a08f commit 6ec6f3f
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 1 deletion.
1 change: 1 addition & 0 deletions components/eamxx/src/diagnostics/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ set(DIAGNOSTIC_SRCS
field_at_height.cpp
field_at_level.cpp
field_at_pressure_level.cpp
horiz_avg.cpp
longwave_cloud_forcing.cpp
number_path.cpp
potential_temperature.cpp
Expand Down
65 changes: 65 additions & 0 deletions components/eamxx/src/diagnostics/horiz_avg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "diagnostics/horiz_avg.hpp"

#include "share/field/field_utils.hpp"

namespace scream {

HorizAvgDiag::HorizAvgDiag(const ekat::Comm &comm,
const ekat::ParameterList &params)
: AtmosphereDiagnostic(comm, params) {
const auto &fname = m_params.get<std::string>("field_name");
m_diag_name = fname + "_horiz_avg";
}

void HorizAvgDiag::set_grids(
const std::shared_ptr<const GridsManager> grids_manager) {
const auto &fn = m_params.get<std::string>("field_name");
const auto &gn = m_params.get<std::string>("grid_name");
const auto g = grids_manager->get_grid("Physics");

add_field<Required>(fn, gn);

// first clone the area unscaled, we will scale it later in initialize_impl
m_scaled_area = g->get_geometry_data("area").clone();
}

void HorizAvgDiag::initialize_impl(const RunType /*run_type*/) {
using namespace ShortFieldTagsNames;
const auto &f = get_fields_in().front();
const auto &fid = f.get_header().get_identifier();
const auto &layout = fid.get_layout();

EKAT_REQUIRE_MSG(layout.rank() >= 1 && layout.rank() <= 3,
"Error! Field rank not supported by HorizAvgDiag.\n"
" - field name: " +
fid.name() +
"\n"
" - field layout: " +
layout.to_string() + "\n");
EKAT_REQUIRE_MSG(layout.tags()[0] == COL,
"Error! HorizAvgDiag diagnostic expects a layout starting "
"with the 'COL' tag.\n"
" - field name : " +
fid.name() +
"\n"
" - field layout: " +
layout.to_string() + "\n");

FieldIdentifier d_fid(m_diag_name, layout.clone().strip_dim(COL),
fid.get_units(), fid.get_grid_name());
m_diagnostic_output = Field(d_fid);
m_diagnostic_output.allocate_view();

// scale the area field
auto total_area = field_sum<Real>(m_scaled_area, &m_comm);
m_scaled_area.scale(sp(1.0) / total_area);
}

void HorizAvgDiag::compute_diagnostic_impl() {
const auto &f = get_fields_in().front();
const auto &d = m_diagnostic_output;
// Call the horiz_contraction impl that will take care of everything
horiz_contraction<Real>(d, f, m_scaled_area, &m_comm);
}

} // namespace scream
43 changes: 43 additions & 0 deletions components/eamxx/src/diagnostics/horiz_avg.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef EAMXX_HORIZ_AVERAGE_HPP
#define EAMXX_HORIZ_AVERAGE_HPP

#include "share/atm_process/atmosphere_diagnostic.hpp"

namespace scream {

/*
* This diagnostic will calculate the area-weighted average of a field
* across the COL tag dimension, producing an N-1 dimensional field
* that is area-weighted average of the input field.
*/

class HorizAvgDiag : public AtmosphereDiagnostic {
public:
// Constructors
HorizAvgDiag(const ekat::Comm &comm, const ekat::ParameterList &params);

// The name of the diagnostic
std::string name() const { return m_diag_name; }

// Set the grid
void set_grids(const std::shared_ptr<const GridsManager> grids_manager);

protected:
#ifdef KOKKOS_ENABLE_CUDA
public:
#endif
void compute_diagnostic_impl();

protected:
void initialize_impl(const RunType /*run_type*/);

// Name of each field (because the diagnostic impl is generic)
std::string m_diag_name;

// Need area field, let's store it scaled by its norm
Field m_scaled_area;
};

} // namespace scream

#endif // EAMXX_HORIZ_AVERAGE_HPP
2 changes: 2 additions & 0 deletions components/eamxx/src/diagnostics/register_diagnostics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "diagnostics/number_path.hpp"
#include "diagnostics/aerocom_cld.hpp"
#include "diagnostics/atm_backtend.hpp"
#include "diagnostics/horiz_avg.hpp"

namespace scream {

Expand Down Expand Up @@ -51,6 +52,7 @@ inline void register_diagnostics () {
diag_factory.register_product("NumberPath",&create_atmosphere_diagnostic<NumberPathDiagnostic>);
diag_factory.register_product("AeroComCld",&create_atmosphere_diagnostic<AeroComCld>);
diag_factory.register_product("AtmBackTendDiag",&create_atmosphere_diagnostic<AtmBackTendDiag>);
diag_factory.register_product("HorizAvgDiag",&create_atmosphere_diagnostic<HorizAvgDiag>);
}

} // namespace scream
Expand Down
3 changes: 3 additions & 0 deletions components/eamxx/src/diagnostics/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,6 @@ CreateDiagTest(aerocom_cld "aerocom_cld_test.cpp")

# Test atm_tend
CreateDiagTest(atm_backtend "atm_backtend_test.cpp")

# Test horizontal averaging
CreateDiagTest(horiz_avg "horiz_avg_test.cpp")
197 changes: 197 additions & 0 deletions components/eamxx/src/diagnostics/tests/horiz_avg_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#include "catch2/catch.hpp"
#include "diagnostics/register_diagnostics.hpp"
#include "share/field/field_utils.hpp"
#include "share/grid/mesh_free_grids_manager.hpp"
#include "share/util/scream_setup_random_test.hpp"
#include "share/util/scream_universal_constants.hpp"

namespace scream {

std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols,
const int nlevs) {
const int num_global_cols = ncols * comm.size();

using vos_t = std::vector<std::string>;
ekat::ParameterList gm_params;
gm_params.set("grids_names", vos_t{"Point Grid"});
auto &pl = gm_params.sublist("Point Grid");
pl.set<std::string>("type", "point_grid");
pl.set("aliases", vos_t{"Physics"});
pl.set<int>("number_of_global_columns", num_global_cols);
pl.set<int>("number_of_vertical_levels", nlevs);

auto gm = create_mesh_free_grids_manager(comm, gm_params);
gm->build_grids();

return gm;
}

TEST_CASE("horiz_avg") {
using namespace ShortFieldTagsNames;
using namespace ekat::units;
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>;
using TeamMember = typename TeamPolicy::member_type;
using KT = ekat::KokkosTypes<DefaultDevice>;
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;

// A numerical tolerance
auto tol = std::numeric_limits<Real>::epsilon() * 100;

// A world comm
ekat::Comm comm(MPI_COMM_WORLD);

// A time stamp
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0});

// Create a grids manager - single column for these tests
constexpr int nlevs = 3;
constexpr int dim3 = 4;
const int ngcols = 6 * comm.size();

auto gm = create_gm(comm, ngcols, nlevs);
auto grid = gm->get_grid("Physics");

// Input (randomized) qc
FieldLayout scalar1d_layout{{COL}, {ngcols}};
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};

FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid->name());
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name());
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid->name());

Field qc1(qc1_fid);
Field qc2(qc2_fid);
Field qc3(qc3_fid);

qc1.allocate_view();
qc2.allocate_view();
qc3.allocate_view();

// Construct random number generator stuff
using RPDF = std::uniform_real_distribution<Real>;
RPDF pdf(sp(0.0), sp(200.0));

auto engine = scream::setup_random_test();

// Construct the Diagnostics
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags;
auto &diag_factory = AtmosphereDiagnosticFactory::instance();
register_diagnostics();

ekat::ParameterList params;
REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
params)); // No 'field_name' parameter

// Set time for qc and randomize its values
qc1.get_header().get_tracking().update_time_stamp(t0);
qc2.get_header().get_tracking().update_time_stamp(t0);
qc3.get_header().get_tracking().update_time_stamp(t0);
randomize(qc1, engine, pdf);
randomize(qc2, engine, pdf);
randomize(qc3, engine, pdf);

// Create and set up the diagnostic
params.set("grid_name", grid->name());
params.set<std::string>("field_name", "qc");
auto diag1 = diag_factory.create("HorizAvgDiag", comm, params);
auto diag2 = diag_factory.create("HorizAvgDiag", comm, params);
auto diag3 = diag_factory.create("HorizAvgDiag", comm, params);
diag1->set_grids(gm);
diag2->set_grids(gm);
diag3->set_grids(gm);

auto area = grid->get_geometry_data("area");

diag1->set_required_field(qc1);
diag1->initialize(t0, RunType::Initial);

diag1->compute_diagnostic();
auto diag1_f = diag1->get_diagnostic();

FieldIdentifier diag0_fid("qc_horiz_avg_manual",
scalar1d_layout.clone().strip_dim(COL), kg / kg,
grid->name());
Field diag0(diag0_fid);
diag0.allocate_view();
auto diag0_v = diag0.get_view<Real>();

auto qc1_v = qc1.get_view<Real *>();
auto area_v = area.get_view<const Real *>();

// calculate total area
Real atot = field_sum<Real>(area, &comm);
// calculate weighted avg
Real wavg = sp(0.0);
Kokkos::parallel_reduce(
"HorizAvgDiag::compute_diagnostic_impl::weighted_sum", ngcols,
KOKKOS_LAMBDA(const int icol, Real &local_wavg) {
local_wavg += (area_v[icol] / atot) * qc1_v[icol];
},
wavg);
Kokkos::deep_copy(diag0_v, wavg);

diag1_f.sync_to_host();
auto diag1_v_h = diag1_f.get_view<Real, Host>();
REQUIRE(diag1_v_h() == wavg);

// Try known cases
// Set qc1_v to 1.0 to get weighted average of 1.0
wavg = sp(1.0);
Kokkos::deep_copy(qc1_v, wavg);
diag1->compute_diagnostic();
auto diag1_v2_host = diag1_f.get_view<Real, Host>();
REQUIRE_THAT(diag1_v2_host(),
Catch::Matchers::WithinRel(
wavg, tol)); // Catch2's floating point comparison

// other diags
// Set qc2_v to 5.0 to get weighted average of 5.0
wavg = sp(5.0);
auto qc2_v = qc2.get_view<Real **>();
Kokkos::deep_copy(qc2_v, wavg);

diag2->set_required_field(qc2);
diag2->initialize(t0, RunType::Initial);
diag2->compute_diagnostic();
auto diag2_f = diag2->get_diagnostic();

auto diag2_v_host = diag2_f.get_view<Real *, Host>();

for(int i = 0; i < nlevs; ++i) {
REQUIRE_THAT(diag2_v_host(i), Catch::Matchers::WithinRel(wavg, tol));
}

auto qc3_v = qc3.get_view<Real ***>();
FieldIdentifier diag3_manual_fid("qc_horiz_avg_manual",
scalar3d_layout.clone().strip_dim(COL),
kg / kg, grid->name());
Field diag3_manual(diag3_manual_fid);
diag3_manual.allocate_view();
auto diag3_manual_v = diag3_manual.get_view<Real **>();
// calculate diag3_manual by hand
auto p = ESU::get_default_team_policy(dim3 * nlevs, ngcols);
Kokkos::parallel_for(
"HorizAvgDiag::compute_diagnostic_impl::manual_diag3", p,
KOKKOS_LAMBDA(const TeamMember &m) {
const int idx = m.league_rank();
const int j = idx / nlevs;
const int k = idx % nlevs;
Real sum = sp(0.0);
Kokkos::parallel_reduce(
Kokkos::TeamThreadRange(m, ngcols),
[&](const int icol, Real &accum) {
accum += (area_v(icol) / atot) * qc3_v(icol, j, k);
},
sum);
Kokkos::single(Kokkos::PerTeam(m),
[&]() { diag3_manual_v(j, k) = sum; });
});
diag3->set_required_field(qc3);
diag3->initialize(t0, RunType::Initial);
diag3->compute_diagnostic();
auto diag3_f = diag3->get_diagnostic();
REQUIRE(views_are_equal(diag3_f, diag3_manual));
}

} // namespace scream
11 changes: 10 additions & 1 deletion components/eamxx/src/share/io/scream_io_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ create_diagnostic (const std::string& diag_field_name,
std::regex backtend ("([A-Za-z0-9_]+)_atm_backtend$");
std::regex pot_temp ("(Liq)?PotentialTemperature$");
std::regex vert_layer ("(z|geopotential|height)_(mid|int)$");
std::regex horiz_avg ("([A-Za-z0-9_]+)_horiz_avg$");

std::string diag_name;
std::smatch matches;
Expand Down Expand Up @@ -191,7 +192,15 @@ create_diagnostic (const std::string& diag_field_name,
diag_name = "VerticalLayer";
params.set<std::string>("diag_name","dz");
params.set<std::string>("vert_location","mid");
} else {
}
else if (std::regex_search(diag_field_name,matches,horiz_avg)) {
diag_name = "HorizAvgDiag";
// Set the grid_name
params.set("grid_name",grid->name());
params.set<std::string>("field_name",matches[1].str());
}
else
{
// No existing special regex matches, so we assume that the diag field name IS the diag name.
diag_name = diag_field_name;
}
Expand Down

0 comments on commit 6ec6f3f

Please sign in to comment.