Skip to content

Commit

Permalink
Refactor create_cache (#106)
Browse files Browse the repository at this point in the history
* refactor create_cache

pass and dispatch on boundary condition

put tmp1 to initial cache (always needed for the RelaxationCallback to work)

fix upwind discretization for BBMBBMEquations1D for reflecting boundary conditions and add test

* format

* Update src/equations/bbm_bbm_1d.jl

Co-authored-by: Hendrik Ranocha <[email protected]>

---------

Co-authored-by: Hendrik Ranocha <[email protected]>
  • Loading branch information
JoshuaLampert and ranocha authored May 20, 2024
1 parent 6d27850 commit 338f8f8
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 50 deletions.
65 changes: 36 additions & 29 deletions src/equations/bbm_bbm_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,39 +127,43 @@ function source_terms_manufactured_reflecting(q, x, t, equations::BBMBBMEquation
return SVector(dq1, dq2)
end

function create_cache(mesh,
equations::BBMBBMEquations1D,
solver,
initial_condition,
RealT,
uEltype)
tmp1 = Array{RealT}(undef, nnodes(mesh)) # tmp1 is needed for the `RelaxationCallback`
function create_cache(mesh, equations::BBMBBMEquations1D,
solver, initial_condition,
::BoundaryConditionPeriodic,
RealT, uEltype)
D = equations.D
if solver.D1 isa PeriodicDerivativeOperator ||
solver.D1 isa UniformPeriodicCoupledOperator ||
solver.D1 isa PeriodicUpwindOperators
invImD2 = inv(I - 1 / 6 * D^2 * Matrix(solver.D2))
return (invImD2 = invImD2, tmp1 = tmp1)
elseif solver.D1 isa DerivativeOperator ||
solver.D1 isa UniformCoupledOperator ||
solver.D1 isa UpwindOperators
invImD2 = inv(I - 1 / 6 * D^2 * Matrix(solver.D2))
return (invImD2 = invImD2,)
end

function create_cache(mesh, equations::BBMBBMEquations1D,
solver, initial_condition,
::BoundaryConditionReflecting,
RealT, uEltype)
D = equations.D
N = nnodes(mesh)
M = mass_matrix(solver.D1)
Pd = BandedMatrix((-1 => fill(one(real(mesh)), N - 2),), (N, N - 2))
D2d = (sparse(solver.D2) * Pd)[2:(end - 1), :]
# homogeneous Dirichlet boundary conditions
invImD2d = inv(I - 1 / 6 * D^2 * D2d)
m = diag(M)
m[1] = 0
m[end] = 0
PdM = Diagonal(m)

# homogeneous Neumann boundary conditions
if solver.D1 isa DerivativeOperator ||
solver.D1 isa UniformCoupledOperator
D1_b = BandedMatrix(solver.D1)
M = mass_matrix(solver.D1)
Pd = BandedMatrix((-1 => fill(one(eltype(D1_b)), size(D1_b, 1) - 2),),
(size(D1_b, 1), size(D1_b, 1) - 2))
D2d = (sparse(solver.D2) * Pd)[2:(end - 1), :]
# homogeneous Dirichtlet boundary conditions
invImD2d = inv(I - 1 / 6 * D^2 * D2d)
m = diag(M)
m[1] = 0
m[end] = 0
PdM = Diagonal(m)
# homogeneous Neumann boundary conditions
invImD2n = inv(I + 1 / 6 * D^2 * inv(M) * D1_b' * PdM * D1_b)
return (invImD2d = invImD2d, invImD2n = invImD2n, tmp1 = tmp1)
elseif solver.D1 isa UpwindOperators
D1plus_b = BandedMatrix(solver.D1.plus)
invImD2n = inv(I + 1 / 6 * D^2 * inv(M) * D1plus_b' * PdM * D1plus_b)
else
@error "unknown type of first-derivative operator: $(typeof(solver.D1))"
end
return (invImD2d = invImD2d, invImD2n = invImD2n)
end

# Discretization that conserves the mass (for eta and v) and the energy for periodic boundary conditions, see
Expand All @@ -186,6 +190,9 @@ function rhs!(du_ode, u_ode, t, mesh, equations::BBMBBMEquations1D, initial_cond
@timeit timer() "dv hyperbolic" dv[:]=-solver.D1 *
(equations.gravity * eta + 0.5 * v .^ 2)
elseif solver.D1 isa PeriodicUpwindOperators
# Note that the upwind operators here are not actually used
# We would need to define two different matrices `invImD2` for eta and v for energy conservation
# To really use the upwind operators, we can use them with `BBMBBMVariableEquations1D`
@timeit timer() "deta hyperbolic" deta[:]=-solver.D1.central * (D * v + eta .* v)
@timeit timer() "dv hyperbolic" dv[:]=-solver.D1.central *
(equations.gravity * eta + 0.5 * v .^ 2)
Expand Down Expand Up @@ -225,8 +232,8 @@ function rhs!(du_ode, u_ode, t, mesh, equations::BBMBBMEquations1D, initial_cond
@timeit timer() "dv hyperbolic" dv[:]=-solver.D1 *
(equations.gravity * eta + 0.5 * v .^ 2)
elseif solver.D1 isa UpwindOperators
@timeit timer() "deta hyperbolic" deta[:]=-solver.D1.central * (D * v + eta .* v)
@timeit timer() "dv hyperbolic" dv[:]=-solver.D1.central *
@timeit timer() "deta hyperbolic" deta[:]=-solver.D1.minus * (D * v + eta .* v)
@timeit timer() "dv hyperbolic" dv[:]=-solver.D1.plus *
(equations.gravity * eta + 0.5 * v .^ 2)
else
@error "unknown type of first-derivative operator: $(typeof(solver.D1))"
Expand Down
16 changes: 6 additions & 10 deletions src/equations/bbm_bbm_variable_bathymetry_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,10 @@ function initial_condition_dingemans(x, t, equations::BBMBBMVariableEquations1D,
return SVector(eta, v, D)
end

function create_cache(mesh,
equations::BBMBBMVariableEquations1D,
solver,
initial_condition,
RealT,
uEltype)
function create_cache(mesh, equations::BBMBBMVariableEquations1D,
solver, initial_condition,
::BoundaryConditionPeriodic,
RealT, uEltype)
# Assume D is independent of time and compute D evaluated at mesh points once.
D = Array{RealT}(undef, nnodes(mesh))
x = grid(solver)
Expand All @@ -164,15 +162,13 @@ function create_cache(mesh,
if solver.D1 isa PeriodicDerivativeOperator ||
solver.D1 isa UniformPeriodicCoupledOperator
invImDKD = inv(I - 1 / 6 * Matrix(solver.D1) * K * Matrix(solver.D1))
invImD2K = inv(I - 1 / 6 * Matrix(solver.D2) * K)
elseif solver.D1 isa PeriodicUpwindOperators
invImDKD = inv(I - 1 / 6 * Matrix(solver.D1.minus) * K * Matrix(solver.D1.plus))
invImD2K = inv(I - 1 / 6 * Matrix(solver.D2) * K)
else
@error "unknown type of first-derivative operator: $(typeof(solver.D1))"
end
tmp1 = Array{RealT}(undef, nnodes(mesh)) # tmp1 is needed for the `RelaxationCallback`
return (invImDKD = invImDKD, invImD2K = invImD2K, D = D, tmp1 = tmp1)
invImD2K = inv(I - 1 / 6 * Matrix(solver.D2) * K)
return (invImDKD = invImDKD, invImD2K = invImD2K, D = D)
end

# Discretization that conserves the mass (for eta and v) and the energy for periodic boundary conditions, see
Expand Down
13 changes: 5 additions & 8 deletions src/equations/svaerd_kalisch_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,10 @@ function source_terms_manufactured(q, x, t, equations::SvaerdKalischEquations1D)
return SVector(dq1, dq2, zero(dq1))
end

function create_cache(mesh,
equations::SvaerdKalischEquations1D,
solver,
initial_condition,
RealT,
uEltype)
function create_cache(mesh, equations::SvaerdKalischEquations1D,
solver, initial_condition,
::BoundaryConditionPeriodic,
RealT, uEltype)
# Assume D is independent of time and compute D evaluated at mesh points once.
D = Array{RealT}(undef, nnodes(mesh))
x = grid(solver)
Expand All @@ -178,7 +176,6 @@ function create_cache(mesh,
alpha_hat = sqrt.(equations.alpha * sqrt.(equations.gravity * D) .* D .^ 2)
beta_hat = equations.beta * D .^ 3
gamma_hat = equations.gamma * sqrt.(equations.gravity * D) .* D .^ 3
tmp1 = similar(h)
tmp2 = similar(h)
hmD1betaD1 = Array{RealT}(undef, nnodes(mesh), nnodes(mesh))
if solver.D1 isa PeriodicDerivativeOperator ||
Expand All @@ -194,7 +191,7 @@ function create_cache(mesh,
end
return (hmD1betaD1 = hmD1betaD1, D1betaD1 = D1betaD1, D = D, h = h, hv = hv,
alpha_hat = alpha_hat, beta_hat = beta_hat, gamma_hat = gamma_hat,
tmp1 = tmp1, tmp2 = tmp2, D1_central = D1_central, D1 = solver.D1)
tmp2 = tmp2, D1_central = D1_central, D1 = solver.D1)
end

# Discretization that conserves the mass (for eta and for flat bottom hv) and the energy for periodic boundary conditions, see
Expand Down
8 changes: 5 additions & 3 deletions src/semidiscretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end
boundary_conditions=boundary_condition_periodic,
RealT=real(solver),
uEltype=RealT,
initial_cache=NamedTuple())
initial_cache=(tmp1 = Array{RealT}(undef, nnodes(mesh)),))
Construct a semidiscretization of a PDE.
"""
Expand All @@ -56,9 +56,11 @@ function Semidiscretization(mesh, equations, initial_condition, solver;
# `RealT` is used as real type for node locations etc.
# while `uEltype` is used as element type of solutions etc.
RealT = real(solver), uEltype = RealT,
initial_cache = NamedTuple())
# tmp1 is needed for the `RelaxationCallback`
initial_cache = (tmp1 = Array{RealT}(undef, nnodes(mesh)),))
cache = (;
create_cache(mesh, equations, solver, initial_condition, RealT, uEltype)...,
create_cache(mesh, equations, solver, initial_condition, boundary_conditions,
RealT, uEltype)...,
initial_cache...)

Semidiscretization{typeof(mesh), typeof(equations), typeof(initial_condition),
Expand Down
30 changes: 30 additions & 0 deletions test/test_bbm_bbm_1d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,36 @@ EXAMPLES_DIR = joinpath(examples_dir(), "bbm_bbm_1d")
change_waterheight=4.2697991261385974e-11,
change_velocity=0.5469460931577577,
change_entropy=130.69415963528576)

# test upwind operators
using SummationByPartsOperators: upwind_operators, Mattsson2017
using SparseArrays: sparse
using OrdinaryDiffEq: solve
D1 = upwind_operators(Mattsson2017; derivative_order = 1,
accuracy_order = accuracy_order, xmin = mesh.xmin,
xmax = mesh.xmax,
N = mesh.N)
D2 = sparse(D1.plus) * sparse(D1.minus)
solver = Solver(D1, D2)
semi = Semidiscretization(mesh, equations, initial_condition, solver,
boundary_conditions = boundary_conditions,
source_terms = source_terms)
ode = semidiscretize(semi, (0.0, 1.0))
sol = solve(ode, Tsit5(), abstol = 1e-7, reltol = 1e-7,
save_everystep = false, callback = callbacks, saveat = saveat)
atol = 1e-12
rtol = 1e-12
errs = errors(analysis_callback)
l2 = [6.465599803116574e-6 2.268226230557415e-8]
l2_measured = errs.l2_error[:, end]
for (l2_expected, l2_actual) in zip(l2, l2_measured)
@test isapprox(l2_expected, l2_actual, atol = atol, rtol = rtol)
end
linf = [0.00015506984862057038 8.639888086914294e-8]
linf_measured = errs.linf_error[:, end]
for (linf_expected, linf_actual) in zip(linf, linf_measured)
@test isapprox(linf_expected, linf_actual, atol = atol, rtol = rtol)
end
end
end

Expand Down

0 comments on commit 338f8f8

Please sign in to comment.