Skip to content

Commit

Permalink
Allow passing io to AnalysisCallback (#49)
Browse files Browse the repository at this point in the history
* allow passing io to AnalysisCallback

* format

* write output from AnalysisCallback to devnull in docs
  • Loading branch information
JoshuaLampert authored Sep 26, 2023
1 parent b9ad564 commit 7fe7444
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 39 deletions.
9 changes: 6 additions & 3 deletions docs/src/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ ode = semidiscretize(semi, tspan)
analysis_callback = AnalysisCallback(semi; interval = 10,
extra_analysis_errors = (:conservation_error,),
extra_analysis_integrals = (waterheight_total,
velocity, entropy))
velocity, entropy),
io = devnull)
callbacks = CallbackSet(analysis_callback)
saveat = range(tspan..., length = 100)
Expand Down Expand Up @@ -139,7 +140,8 @@ To obtain entropy-conserving time-stepping schemes DispersiveShallowWater.jl use
analysis_callback = AnalysisCallback(semi; interval = 10,
extra_analysis_errors = (:conservation_error,),
extra_analysis_integrals = (waterheight_total,
velocity, entropy))
velocity, entropy),
io = devnull)
relaxation_callback = RelaxationCallback(invariant = entropy)
callbacks = CallbackSet(relaxation_callback, analysis_callback)
sol = solve(ode, Tsit5(), abstol = 1e-7, reltol = 1e-7,
Expand Down Expand Up @@ -204,7 +206,8 @@ As before, we can run the simulation by
analysis_callback = AnalysisCallback(semi; interval = 10,
extra_analysis_errors = (:conservation_error,),
extra_analysis_integrals = (waterheight_total,
velocity, entropy))
velocity, entropy),
io = devnull)
relaxation_callback = RelaxationCallback(invariant = entropy)
callbacks = CallbackSet(relaxation_callback, analysis_callback)
sol = solve(ode, Tsit5(), abstol = 1e-7, reltol = 1e-7,
Expand Down
84 changes: 48 additions & 36 deletions src/callbacks_step/analysis.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
AnalysisCallback(semi; interval=0,
extra_analysis_errors=Symbol[],
extra_analysis_integrals=())
extra_analysis_integrals=(),
io=stdout)
Analyze a numerical solution every `interval` time steps.
The L2- and the L∞-norm for each component are computed by default.
Expand All @@ -14,6 +15,8 @@ You can also write your own function with the same signature as the examples lis
pass it via `extra_analysis_integrals`.
The computed errors and intergrals are saved for each timestep and can be obtained by calling
[`errors`](@ref) and [`integrals`](@ref).
During the Simulation, the `AnalysisCallback` will print information to `io`.
"""
mutable struct AnalysisCallback{T, AnalysisIntegrals, InitialStateIntegrals}
start_time::Float64
Expand All @@ -25,6 +28,7 @@ mutable struct AnalysisCallback{T, AnalysisIntegrals, InitialStateIntegrals}
tstops::Vector{Float64}
errors::Vector{Matrix{T}}
integrals::Vector{Vector{T}}
io::IO
end

function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:AnalysisCallback})
Expand Down Expand Up @@ -68,7 +72,8 @@ function AnalysisCallback(mesh, equations::AbstractEquations, solver;
extra_analysis_errors),
extra_analysis_integrals = (),
analysis_integrals = union(default_analysis_integrals(equations),
extra_analysis_integrals))
extra_analysis_integrals),
io = stdout)
# Decide when the callback is activated.
# With error-based step size control, some steps can be rejected. Thus,
# `integrator.iter >= integrator.stats.naccept`
Expand All @@ -90,7 +95,8 @@ function AnalysisCallback(mesh, equations::AbstractEquations, solver;
Val(nvariables(equations)))),
Vector{Float64}(),
Vector{Matrix{real(solver)}}(),
Vector{Vector{real(solver)}}())
Vector{Vector{real(solver)}}(),
io)

DiscreteCallback(condition, analysis_callback,
save_positions = (false, false),
Expand Down Expand Up @@ -174,7 +180,8 @@ function (analysis_callback::AnalysisCallback)(integrator)
semi = integrator.p
mesh, equations, solver = mesh_equations_solver(semi)

l2_error, linf_error = analysis_callback(integrator.u, integrator, semi)
l2_error, linf_error = analysis_callback(analysis_callback.io, integrator.u, integrator,
semi)

# avoid re-evaluating possible FSAL stages
u_modified!(integrator, false)
Expand All @@ -185,7 +192,7 @@ end

# This method is just called internally from `(analysis_callback::AnalysisCallback)(integrator)`
# and serves as a function barrier. Additionally, it makes the code easier to profile and optimize.
function (analysis_callback::AnalysisCallback)(u_ode, integrator, semi)
function (analysis_callback::AnalysisCallback)(io, u_ode, integrator, semi)
_, equations, solver = mesh_equations_solver(semi)
@unpack analysis_errors, analysis_integrals, tstops, errors, integrals = analysis_callback
@unpack t, dt = integrator
Expand Down Expand Up @@ -215,75 +222,80 @@ function (analysis_callback::AnalysisCallback)(u_ode, integrator, semi)
# Source: https://github.com/JuliaLang/julia/blob/b540315cb4bd91e6f3a3e4ab8129a58556947628/base/timing.jl#L86-L97
memory_use = Base.gc_live_bytes() / 2^20 # bytes -> MiB

println()
println(""^100)
println("Simulation running '", get_name(equations), "' with '", semi.initial_condition,
println(io)
println(io, ""^100)
println(io, "Simulation running '", get_name(equations), "' with '",
semi.initial_condition,
"'")
println(""^100)
println(" #timesteps: " * @sprintf("% 14d", iter) *
println(io, ""^100)
println(io,
" #timesteps: " * @sprintf("% 14d", iter) *
" " *
" run time: " * @sprintf("%10.8e s", runtime_absolute))
println(" Δt: " * @sprintf("%10.8e", dt) *
println(io,
" Δt: " * @sprintf("%10.8e", dt) *
" " *
" └── GC time: " *
@sprintf("%10.8e s (%5.3f%%)", gc_time_absolute, gc_time_percentage))
println(" sim. time: " * @sprintf("%10.8e (%5.3f%%)", t, t / t_final*100))
println(" #DOF: " * @sprintf("% 14d", nnodes(semi)) *
println(io, " sim. time: " * @sprintf("%10.8e (%5.3f%%)", t, t / t_final*100))
println(io,
" #DOF: " * @sprintf("% 14d", nnodes(semi)) *
" " *
" alloc'd memory: " * @sprintf("%14.3f MiB", memory_use))
println()
println(io)

print(" Variable: ")
print(io, " Variable: ")
for v in eachvariable(equations)
@printf(" %-14s", varnames(prim2prim, equations)[v])
@printf(io, " %-14s", varnames(prim2prim, equations)[v])
end
println()
println(io)

# Calculate L2/Linf errors, which are also returned
l2_error, linf_error = calc_error_norms(u_ode, t, semi)
current_errors = zeros(real(semi), (length(analysis_errors), nvariables(equations)))
current_errors[1, :] = l2_error
current_errors[2, :] = linf_error
print(" L2 error: ")
print(io, " L2 error: ")
for v in eachvariable(equations)
@printf(" % 10.8e", l2_error[v])
@printf(io, " % 10.8e", l2_error[v])
end
println()
println(io)

print(" Linf error: ")
print(io, " Linf error: ")
for v in eachvariable(equations)
@printf(" % 10.8e", linf_error[v])
@printf(io, " % 10.8e", linf_error[v])
end
println()
println(io)

# Conservation error
if :conservation_error in analysis_errors
@unpack initial_state_integrals = analysis_callback
state_integrals = integrate(u_ode, semi)
current_errors[3, :] = abs.(state_integrals - initial_state_integrals)
print(" |q - q₀|: ")
print(io, " |q - q₀|: ")
for v in eachvariable(equations)
@printf(" % 10.8e", current_errors[3, v])
@printf(io, " % 10.8e", current_errors[3, v])
end
println()
println(io)
end
push!(errors, current_errors)

# additional integrals
if length(analysis_integrals) > 0
println()
println(" Integrals: ")
println(io)
println(io, " Integrals: ")
end
current_integrals = zeros(real(semi), length(analysis_integrals))
analyze_integrals!(current_integrals, 1, analysis_integrals, u_ode, t, semi)
analyze_integrals!(io, current_integrals, 1, analysis_integrals, u_ode, t, semi)
push!(integrals, current_integrals)

println(""^100)
println(io, ""^100)
return l2_error, linf_error
end

# Iterate over tuples of analysis integrals in a type-stable way using "lispy tuple programming".
function analyze_integrals!(current_integrals, i, analysis_integrals::NTuple{N, Any}, u_ode,
function analyze_integrals!(io, current_integrals, i, analysis_integrals::NTuple{N, Any},
u_ode,
t, semi) where {N}

# Extract the first analysis integral and process it; keep the remaining to be processed later
Expand All @@ -292,17 +304,17 @@ function analyze_integrals!(current_integrals, i, analysis_integrals::NTuple{N,

res = analyze(quantity, u_ode, t, semi)
current_integrals[i] = res
@printf(" %-12s:", pretty_form_utf(quantity))
@printf(" % 10.8e", res)
println()
@printf(io, " %-12s:", pretty_form_utf(quantity))
@printf(io, " % 10.8e", res)
println(io)

# Recursively call this method with the unprocessed integrals
analyze_integrals!(current_integrals, i + 1, remaining_quantities, u_ode, t, semi)
analyze_integrals!(io, current_integrals, i + 1, remaining_quantities, u_ode, t, semi)
return nothing
end

# terminate the type-stable iteration over tuples
function analyze_integrals!(current_integrals, i, analysis_integrals::Tuple{}, u_ode, t,
function analyze_integrals!(io, current_integrals, i, analysis_integrals::Tuple{}, u_ode, t,
semi)
nothing
end
Expand Down

0 comments on commit 7fe7444

Please sign in to comment.