Skip to content

Commit

Permalink
ReactionReservoirScalar allow const and state_norm (#148)
Browse files Browse the repository at this point in the history
* ReactionReservoirScalar allow const and state_norm

Handle case where both 'const=true' and 'state_norm=true' parameters are set.

* Add tests for ReactionReservoirScalar const and state_norm options
  • Loading branch information
sjdaines authored Dec 20, 2024
1 parent 6deddd6 commit 0d94b07
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 18 deletions.
32 changes: 18 additions & 14 deletions src/reactioncatalog/Reservoirs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,26 @@ function setup_reactionreservoirscalar(m::PB.AbstractReactionMethod, pars, (R, R

rj.norm_value = PB.get_attribute(R_domvar, :norm_value)

if pars.const[] && (attribute_name == :setup)
PB.init_field!(
only(R), :initial_value, R_domvar, (_, _)->1.0, [], cellrange, (PB.fullname(R_domvar), "", "")
)
elseif attribute_name in (:norm_value, :initial_value)
if pars.state_norm[]
R_solve_var = only(R_solve_vars)
R_solve_domvar = R_solve_var.linkvar
PB.init_field!(
only(R_solve), attribute_name, R_domvar, (_, _)->1/rj.norm_value, [], cellrange, (PB.fullname(R_solve_domvar), " / $(rj.norm_value)", " [from $(PB.fullname(R_domvar))]")
)
else
if pars.const[]
if attribute_name == :setup
PB.init_field!(
only(R), attribute_name, R_domvar, (_, _)->1.0, [], cellrange, (PB.fullname(R_domvar), "", "")
only(R), :initial_value, R_domvar, (_, _)->1.0, [], cellrange, (PB.fullname(R_domvar), "", "")
)
end
else
if attribute_name in (:norm_value, :initial_value)
if pars.state_norm[]
R_solve_var = only(R_solve_vars)
R_solve_domvar = R_solve_var.linkvar
PB.init_field!(
only(R_solve), attribute_name, R_domvar, (_, _)->1/rj.norm_value, [], cellrange, (PB.fullname(R_solve_domvar), " / $(rj.norm_value)", " [from $(PB.fullname(R_domvar))]")
)
else
PB.init_field!(
only(R), attribute_name, R_domvar, (_, _)->1.0, [], cellrange, (PB.fullname(R_domvar), "", "")
)
end
end
end

return nothing
Expand All @@ -172,7 +176,7 @@ end
function do_reactionreservoirscalar(m::PB.AbstractReactionMethod, pars, (vars, ), cr::PB.AbstractCellRange, deltat)
rj = m.reaction

if pars.state_norm[]
if pars.state_norm[] && !pars.const[]
vars.R[] = vars.R_solve[]*rj.norm_value
vars.R_norm[] = PB.get_total(vars.R_solve[])
else
Expand Down
23 changes: 23 additions & 0 deletions test/configreservoirs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,29 @@ model1:
R:norm_value: 10.0
R:initial_value: 1.0

reservoir_NormS:
class: ReactionReservoirScalar

parameters:
state_norm: true
variable_links:
R*: NormS*
variable_attributes:
R:norm_value: 10.0
R:initial_value: 1.0

reservoir_ConstNormS:
class: ReactionReservoirScalar

parameters:
const: true
state_norm: true
variable_links:
R*: ConstNormS*
variable_attributes:
R:norm_value: 10.0
R:initial_value: 1.0

scalar_sum:
class: ReactionSum
parameters:
Expand Down
12 changes: 8 additions & 4 deletions test/runreservoirtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end
modeldata = PB.create_modeldata(model)
PB.allocate_variables!(model, modeldata, 1; hostdep=false, check_units_opt=:error)

@test length( PB.get_unallocated_variables(global_domain, modeldata, 1)) == 4
@test length( PB.get_unallocated_variables(global_domain, modeldata, 1)) == 6
@test PB.check_ready(global_domain, modeldata, throw_on_error=false) == false

# allocate arrays for host dependencies and set data pointers
Expand All @@ -53,7 +53,7 @@ end
# check state variables
stateexplicit_vars, stateexplicit_sms_vars =
PB.get_host_variables(global_domain, PB.VF_StateExplicit, match_deriv_suffix="_sms")
@test length(stateexplicit_vars) == 2 # A and O
@test length(stateexplicit_vars) == 3 # A, O, NormS_solve
# get global state variable aggregator
global_stateexplicit_va = PB.VariableAggregator(stateexplicit_vars, fill(nothing, length(stateexplicit_vars)), modeldata, 1)

Expand All @@ -65,7 +65,7 @@ end

# get host-dependent variables
global_hostdep_vars_vec = PB.get_variables(global_domain, hostdep=true)
@test length(global_hostdep_vars_vec) == 4
@test length(global_hostdep_vars_vec) == 6

ocean_hostdep_vars_vec = PB.get_variables(ocean_domain, hostdep=true)
@test length(ocean_hostdep_vars_vec) == 4
Expand All @@ -78,11 +78,12 @@ end
PB.dispatch_setup(model, :initial_value, modeldata)

@info "global stateexplicit variables:\n"
@test all_data.global.NormS_solve[] == 0.1
@test all_data.global.A.v[] == 3.193e18 # A
@test all_data.global.A.v_moldelta[] == 2.0*3.193e18 # A_moldelta
# test access via flattened vector (as used eg by an ODE solver)
global_A_indices = PB.get_indices(global_stateexplicit_va, "global.A")
@test global_A_indices == 2:3
@test global_A_indices == 3:4 # NormS_solve, O, A - NB: order is implementation-dependent, not well defined
stateexplicit_vector = PB.get_data(global_stateexplicit_va)
@test stateexplicit_vector[global_A_indices] == [all_data.global.A.v[], all_data.global.A.v_moldelta[]]
# test VariableAggregatorNamed created from VariableAggregator
Expand All @@ -91,6 +92,7 @@ end

@info "global const variables:"
@test all_data.global.ConstS[] == 1.0
@test all_data.global.ConstNormS[] == 1.0

@info "ocean host-dependent variable initialisation:\n"
@test all_data.ocean.T[1] == 1.0*10.0
Expand All @@ -103,6 +105,8 @@ end
@info "global model-created variables:\n"
@test all_data.global.A_norm[] == 10.0
@test all_data.global.A_delta[] == 2.0
@test all_data.global.NormS[] == 1.0
@test all_data.global.NormS_norm[] == 0.1

@info "ocean model-created variables:\n"
@test all_data.ocean.T_conc == fill(1.0, ocean_length)
Expand Down

0 comments on commit 0d94b07

Please sign in to comment.