Skip to content

Commit

Permalink
Merge pull request #56 from PALEOtoolkit/reactionmethod_despecialize
Browse files Browse the repository at this point in the history
ReactionMethod etc updates to reduce startup time
  • Loading branch information
sjdaines authored Sep 23, 2022
2 parents 1703941 + 327ccc9 commit 06fd1e0
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PALEOboxes"
uuid = "804b410e-d900-4b2a-9ecd-f5a06d4c1fd4"
authors = ["Stuart Daines <[email protected]>"]
version = "0.21.3"
version = "0.21.4"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Expand Down
9 changes: 6 additions & 3 deletions src/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,10 @@ function initialize_reactiondata!(

check_modeldata(model, modeldata)

# using Ref here seems to trade off time to create ReactionMethodDispatchList
# TODO using Ref here seems to trade off time to create ReactionMethodDispatchList
# and time for first call to do_deriv ??
# (Ref gives fast ReactionMethodDispatchList creation, but slow first do_deriv)
# NB: passing Ref to call_method seems to speed up first do_deriv

modeldata.sorted_methodsdata_setup =
[
Expand Down Expand Up @@ -549,7 +550,7 @@ function _create_dispatch_methodlist(methodsdata, cellranges, generated_dispatch
end # timeit

if generated_dispatch
@timeit "ReactionMethodDispatchList" rmdl = ReactionMethodDispatchList(Tuple(methods), Tuple(vardatas), Tuple(crs))
@timeit "ReactionMethodDispatchList" rmdl = ReactionMethodDispatchList(methods, vardatas, crs)
else
@timeit "ReactionMethodDispatchListNoGen" rmdl = ReactionMethodDispatchListNoGen(methods, vardatas, crs)
end
Expand Down Expand Up @@ -615,7 +616,9 @@ function emits unrolled code with a function call for each Tuple element.
push!(ex.args,
quote
# let
call_method(dl.methods[$j][], dl.vardatas[$j][], dl.cellranges[$j], deltat)
# call_method(dl.methods[$j][], dl.vardatas[$j][], dl.cellranges[$j], deltat)
# pass Ref to function to reduce compile time
call_method(dl.methods[$j], dl.vardatas[$j], dl.cellranges[$j], deltat)
# end
end
)
Expand Down
9 changes: 7 additions & 2 deletions src/PALEOboxes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,15 @@ function precompile_reaction(rdict, classname)
end

# create and take a timestep for a test configuration
function run_model(configfile, configname)
function run_model(configfile::AbstractString, configname::AbstractString)

model = create_model_from_config(configfile, configname)

run_model(model)
return nothing
end

function run_model(model::Model)
modeldata = create_modeldata(model)
allocate_variables!(model, modeldata)

Expand All @@ -114,7 +119,7 @@ function run_model(configfile, configname)
dispatch_setup(model, :norm_value, modeldata)
dispatch_setup(model, :initial_value, modeldata)

# take a time step
# take a time step - TODO, can be model dependent on missing setup
# dispatchlists = modeldata.dispatchlists_all
# do_deriv(dispatchlists)

Expand Down
50 changes: 27 additions & 23 deletions src/ReactionMethod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ the data types of `vardata` or to cache expensive calculations:
This is called after model arrays are allocated, and prior to setup.
"""
struct ReactionMethod{M, R, P, V, Nargs} <: AbstractReactionMethod
struct ReactionMethod{M, R, P, Nargs} <: AbstractReactionMethod
"callback from Model framework"
methodfn::M

Expand All @@ -53,9 +53,10 @@ struct ReactionMethod{M, R, P, V, Nargs} <: AbstractReactionMethod
"a descriptive name, eg generated from the name of methodfn"
name::String

"Tuple{Vararg{AbstractVarList}} of [`VariableReaction`](@ref)s.
Corresponding Variable accessors `vardata` (views on Arrays) will be provided to the `methodfn` callback."
varlists::V
"Tuple of VarLists, each representing a list of [`VariableReaction`](@ref)s.
Corresponding Variable accessors `vardata` (views on Arrays) will be provided to the `methodfn` callback.
NB: not concretely typed to reduce compile time, as not performance-critical"
varlists::Tuple{Vararg{AbstractVarList}}

"optional context field (of arbitrary type) to store data needed by methodfn."
p::P
Expand All @@ -71,22 +72,22 @@ struct ReactionMethod{M, R, P, V, Nargs} <: AbstractReactionMethod
methodfn::M,
reaction::R,
name,
varlists::V,
varlists::Tuple{Vararg{AbstractVarList}},
p::P,
operatorID::Vector{Int64},
domain;
preparefn = (m, vardata) -> vardata,
) where {M <: Function, R <: AbstractReaction, P, V <: Tuple{Vararg{AbstractVarList}}}
) where {M <: Function, R <: AbstractReaction, P}

# Find number of arguments that methodfn takes
# (in order to support two forms of 'methodfn', with and without Parameters)
nargs = fieldcount(methods(methodfn).ms[1].sig) - 1

newmethod = new{M, R, P, V, nargs}(
newmethod = new{M, R, P, nargs}(
methodfn,
reaction,
name,
deepcopy(varlists),
varlists,
p,
operatorID,
domain,
Expand All @@ -102,20 +103,23 @@ struct ReactionMethod{M, R, P, V, Nargs} <: AbstractReactionMethod
end
end

get_nargs(method::ReactionMethod{M, R, P, V, Nargs}) where {M, R, P, V, Nargs} = Nargs
get_nargs(methodref::Ref{ReactionMethod{M, R, P, V, Nargs}}) where {M, R, P, V, Nargs} = Nargs
get_nargs(method::ReactionMethod{M, R, P, Nargs}) where {M, R, P, Nargs} = Nargs
get_nargs(methodref::Ref{ReactionMethod{M, R, P, Nargs}}) where {M, R, P, Nargs} = Nargs
# deprecated form without pars
@inline call_method(method::ReactionMethod{M, R, P, V, 4}, vardata, cr, modelctxt) where {M, R, P, V} =
@inline call_method(method::ReactionMethod{M, R, P, 4}, vardata, cr, modelctxt) where {M, R, P} =
method.methodfn(method, vardata, cr, modelctxt)
# updated form with pars
@inline call_method(method::ReactionMethod{M, R, P, V, 5}, vardata, cr, modelctxt) where {M, R, P, V} =
@inline call_method(method::ReactionMethod{M, R, P, 5}, vardata, cr, modelctxt) where {M, R, P} =
method.methodfn(method, method.reaction.pars, vardata, cr, modelctxt)

@noinline call_method(methodref::Ref{<: ReactionMethod}, vardataref::Ref, cr, modelctxt) =
call_method(methodref[], vardataref[], cr, modelctxt)

# for benchmarking etc: apply codefn to the ReactionMethod methodfn (without this, will just apply to the call_method wrapper)
# codefn=code_warntype, code_llvm, code_native
call_method_codefn(io::IO, codefn, method::ReactionMethod{M, R, P, V, 4}, vardata, cr, modelctxt; kwargs...) where {M, R, P, V} =
call_method_codefn(io::IO, codefn, method::ReactionMethod{M, R, P, 4}, vardata, cr, modelctxt; kwargs...) where {M, R, P} =
codefn(io, method.methodfn, (typeof(method), typeof(vardata), typeof(cr), typeof(modelctxt)); kwargs...)
call_method_codefn(io::IO, codefn, method::ReactionMethod{M, R, P, V, 5}, vardata, cr, modelctxt; kwargs...) where {M, R, P, V} =
call_method_codefn(io::IO, codefn, method::ReactionMethod{M, R, P, 5}, vardata, cr, modelctxt; kwargs...) where {M, R, P} =
codefn(io, method.methodfn, (typeof(method), typeof(method.reaction.pars), typeof(vardata), typeof(cr), typeof(modelctxt)); kwargs...)


Expand All @@ -124,15 +128,15 @@ call_method_codefn(io::IO, codefn, method::ReactionMethod{M, R, P, V, 5}, vardat
Get all [`VariableReaction`](@ref)s from `method` as a Tuple of `Vector{VariableReaction}`
"""
get_variables_tuple(method::AbstractReactionMethod) = Tuple(get_variables(vl) for vl in method.varlists)
get_variables_tuple(@nospecialize(method::ReactionMethod)) = Tuple(get_variables(vl) for vl in method.varlists)

"""
get_variables(method::AbstractReactionMethod; filterfn = v -> true) -> Vector{VariableReaction}
Get VariableReactions from `method.varlists` as a flat Vector, optionally restricting to those that match `filterfn`
"""
function get_variables(
method::AbstractReactionMethod;
@nospecialize(method::ReactionMethod);
filterfn = v -> true
)
vars = VariableReaction[]
Expand All @@ -153,7 +157,7 @@ Get a single VariableReaction `v` by `localname`.
If `localname` not present, returns `nothing` if `allow_not_found==true` otherwise errors.
"""
function get_variable(
method::AbstractReactionMethod, localname::AbstractString;
@nospecialize(method::ReactionMethod), localname::AbstractString;
allow_not_found=false
)
matchvars = get_variables(method; filterfn = v -> v.localname==localname)
Expand All @@ -164,20 +168,20 @@ function get_variable(
return isempty(matchvars) ? nothing : matchvars[1]
end

fullname(method::AbstractReactionMethod) = fullname(method.reaction)*"."*method.name
fullname(@nospecialize(method::ReactionMethod)) = fullname(method.reaction)*"."*method.name

is_method_setup(method::AbstractReactionMethod) = (method in method.reaction.methods_setup)
is_method_initialize(method::AbstractReactionMethod) = (method in method.reaction.methods_initialize)
is_method_do(method::AbstractReactionMethod) = (method in method.reaction.methods_do)
is_method_setup(@nospecialize(method::ReactionMethod)) = (method in base(method.reaction).methods_setup)
is_method_initialize(@nospecialize(method::ReactionMethod)) = (method in base(method.reaction).methods_initialize)
is_method_do(@nospecialize(method::ReactionMethod)) = (method in base(method.reaction).methods_do)

get_rate_stoichiometry(m::ReactionMethod) = []
get_rate_stoichiometry(@nospecialize(m::ReactionMethod)) = []

###########################################
# Pretty printing
############################################

"compact form"
function Base.show(io::IO, method::ReactionMethod)
function Base.show(io::IO, @nospecialize(method::ReactionMethod))
print(
io,
"ReactionMethod(fullname='", fullname(method),
Expand Down
3 changes: 3 additions & 0 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ struct ReactionMethodDispatchList{M <:Tuple, V <:Tuple, C <: Tuple}
cellranges::C
end

ReactionMethodDispatchList(methods::Vector, vardatas::Vector, cellranges::Vector) =
ReactionMethodDispatchList(Tuple(methods), Tuple(vardatas), Tuple(cellranges))

struct ReactionMethodDispatchListNoGen
methods::Vector
vardatas::Vector
Expand Down
37 changes: 26 additions & 11 deletions src/VariableReaction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ const VarDepT = VariableReaction{VT_ReactDependency}
const VarTargetT = VariableReaction{VT_ReactTarget}
const VarContribT = VariableReaction{VT_ReactContributor}

function Base.copy(v::VariableReaction{T}) where T
vcopy = VariableReaction{T}(
method = v.method,
localname = v.localname,
attributes = copy(v.attributes), # NB: no deepcopy
# attributes = Dict{Symbol, Any}(k=>copy(v) for (k, v) in v.attributes), #
linkreq_domain = v.linkreq_domain,
linkreq_subdomain = v.linkreq_subdomain,
linkreq_name = v.linkreq_name,
link_optional = v.link_optional,
linkvar = v.linkvar,
)
return vcopy
end

"""
get_domvar_attribute(var::VariableReaction, name::Symbol, missing_value=missing) -> value
Expand Down Expand Up @@ -206,7 +221,7 @@ VarDep(v::VarDepT) = v
function VarDep(v::Union{VarPropT, VarTargetT})
vdep = VarDepT(
localname = v.localname,
attributes = deepcopy(v.attributes),
attributes = copy(v.attributes), # NB: no deepcopy
linkreq_domain = v.linkreq_domain,
linkreq_subdomain = v.linkreq_subdomain,
linkreq_name = v.linkreq_name
Expand All @@ -233,7 +248,7 @@ VarContrib(localname, units, description; kwargs... ) =
VarContrib(v::VarContribT) = v
VarContrib(v::VarTargetT) = VarContribT(
localname = v.localname,
attributes = deepcopy(v.attributes),
attributes = copy(v.attributes), # NB: no deepcopy
linkreq_domain = v.linkreq_domain,
linkreq_subdomain = v.linkreq_subdomain,
linkreq_name = v.linkreq_name
Expand Down Expand Up @@ -268,7 +283,7 @@ VarState(localname, units, description; attributes::Tuple=(), kwargs... ) =
# TODO: define a VarInit Type. Currently (ab)using VarDep
VarInit(v::Union{VarPropT, VarTargetT, VarContribT}) = VarDepT(
localname = v.localname,
attributes = deepcopy(v.attributes),
attributes = copy(v.attributes), # NB: no deepcopy
linkreq_domain = v.linkreq_domain,
linkreq_subdomain = v.linkreq_subdomain,
linkreq_name = v.linkreq_name
Expand Down Expand Up @@ -337,7 +352,7 @@ Create a `VarList_single` describing a single `VariableReaction`,
struct VarList_single <: AbstractVarList
var::VariableReaction
components::Bool
VarList_single(var; components=false) = new(var, components)
VarList_single(var; components=false) = new(copy(var), components)
end

get_variables(vl::VarList_single) = [vl.var]
Expand All @@ -356,7 +371,7 @@ struct VarList_components <: AbstractVarList
vars::Vector{VariableReaction}
allow_unlinked::Bool
VarList_components(varcollection; allow_unlinked=false) =
new([v for v in varcollection], allow_unlinked)
new([copy(v) for v in varcollection], allow_unlinked)
end

get_variables(vl::VarList_components) = vl.vars
Expand Down Expand Up @@ -400,7 +415,7 @@ If `components = true`, each NamedTuple field will be a Vector of data array com
"""
function VarList_namedtuple(varcollection; components=false)
keys = Symbol.([v.localname for v in varcollection])
vars = [v for v in varcollection]
vars = [copy(v) for v in varcollection]

return VarList_namedtuple(vars, keys, components)
end
Expand All @@ -420,7 +435,7 @@ function VarList_namedtuple_fields(objectwithvars; components=false)
if getproperty(objectwithvars, f) isa VariableReaction]

keys = [f for (f,v) in fieldnamesvars]
vars = [v for (f,v) in fieldnamesvars]
vars = [copy(v) for (f,v) in fieldnamesvars]
return VarList_namedtuple(vars, keys, components)
end

Expand All @@ -441,7 +456,7 @@ If `components = true`, each Tuple field will be a Vector of data array componen
struct VarList_tuple <: AbstractVarList
vars::Vector{VariableReaction}
components::Bool
VarList_tuple(varcollection; components=false) = new([v for v in varcollection], components)
VarList_tuple(varcollection; components=false) = new([copy(v) for v in varcollection], components)
end

get_variables(vl::VarList_tuple) = vl.vars
Expand All @@ -465,7 +480,7 @@ struct VarList_vector <: AbstractVarList
components::Bool
forceview::Bool
VarList_vector(varcollection; components=false, forceview=false) =
new([v for v in varcollection], components, forceview)
new([copy(v) for v in varcollection], components, forceview)
end

get_variables(vl::VarList_vector) = vl.vars
Expand All @@ -485,7 +500,7 @@ If `components = true`, each Vector of Vectors element will be a Vector of data
struct VarList_vvector <: AbstractVarList
vars::Vector{Vector{VariableReaction}}
components::Bool
VarList_vvector(vars; components=false) = new(vars, components)
VarList_vvector(vars; components=false) = new([[copy(v) for v in vv] for vv in vars], components)
end

get_variables(vl::VarList_vvector) = vcat(vl.vars...)
Expand Down Expand Up @@ -526,7 +541,7 @@ Create a `VarList_fields` describing a collection of `VariableReaction`s,
"""
struct VarList_fields <: AbstractVarList
vars::Vector{VariableReaction}
VarList_fields(varcollection) = new([v for v in varcollection])
VarList_fields(varcollection) = new([copy(v) for v in varcollection])
end

get_variables(vl::VarList_fields) = vl.vars
Expand Down
12 changes: 7 additions & 5 deletions src/reactioncatalog/Reservoirs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ A single scalar biogeochemical reservoir with optional paired isotope reservoir,
Creates State and associated Variables, depending on parameter settings:
- `const=false`: usual case, create state variable `R` (units mol, with attribute `vfunction=VF_StateExplicit`)
and `R_sms` (units mol yr-1, with attribute `vfunction=VF_Deriv`).
- `const=true`: a constant value, create `R` (a Property)
- `const=true`: a constant value, create `R` (a Property), and `R_sms` (a Target)
In addition:
- a Property `R_norm` (normalized value) is always created.
Expand Down Expand Up @@ -58,7 +58,7 @@ Base.@kwdef mutable struct ReactionReservoirScalar{P} <: PB.AbstractReaction
allowed_values=PB.IsotopeTypes,
description="disable / enable isotopes and specify isotope type"),
PB.ParBool("const", false,
description="true to provide constant value with no _sms Variable"),
description="true to provide a constant value: R is not a state variable, fluxes in R_sms Variable are ignored"),
)

norm_value::Float64 = NaN
Expand Down Expand Up @@ -94,15 +94,17 @@ function PB.register_methods!(rj::ReactionReservoirScalar)
force_initial_norm_value=true, # setup :norm_value, :initial_value to get norm_value callback, even though R is not a state Variable
setup_callback=setup_callback
)
# no _sms variable
R_sms = PB.VarTargetScalar( "R_sms", "mol yr-1", "scalar reservoir source-sinks", attributes=(:field_data =>rj.pars.field_data[],))
else
R = PB.VarStateExplicitScalar("R", "mol", "scalar reservoir", attributes=(:field_data =>rj.pars.field_data[],))
PB.add_method_setup_initialvalue_vars_default!(rj, [R], setup_callback=setup_callback)

R_sms = PB.VarDerivScalar( "R_sms", "mol yr-1", "scalar reservoir source-sinks", attributes=(:field_data =>rj.pars.field_data[],))
# sms variable not used by us, but must appear in a method to be linked and created
PB.add_method_do_nothing!(rj, [R_sms])
end
PB.setfrozen!(rj.pars.const)

# sms variable not used by us, but must appear in a method to be linked and created
PB.add_method_do_nothing!(rj, [R_sms])

do_vars = [PB.VarDep(R), PB.VarPropScalar("R_norm", "", "scalar reservoir normalized")]
if rj.pars.field_data[] <: PB.AbstractIsotopeScalar
Expand Down
Loading

2 comments on commit 06fd1e0

@sjdaines
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/68832

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.21.4 -m "<description of version>" 06fd1e0d61e50d66de26e528ba04058cec54102e
git push origin v0.21.4

Please sign in to comment.