Skip to content

Commit

Permalink
introduce type parameters in solver (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaLampert authored Dec 13, 2023
1 parent 5ae06b4 commit bc11555
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@ abstract type AbstractSolver end
A `struct` that holds the summation by parts (SBP) operators that are used for the spatial discretization.
"""
struct Solver{RealT <: Real} <: AbstractSolver
D1::AbstractDerivativeOperator{RealT}
D2::Union{AbstractDerivativeOperator{RealT}, AbstractMatrix{RealT}}

function Solver{RealT}(D1::AbstractDerivativeOperator{RealT},
D2::Union{AbstractDerivativeOperator{RealT},
AbstractMatrix{RealT}}) where {RealT}
struct Solver{RealT <: Real, FirstDerivative <: AbstractDerivativeOperator{RealT},
SecondDerivative <:
Union{AbstractDerivativeOperator{RealT}, AbstractMatrix{RealT}}} <:
AbstractSolver
D1::FirstDerivative
D2::SecondDerivative

function Solver{RealT, FirstDerivative, SecondDerivative}(D1::FirstDerivative,
D2::SecondDerivative) where {
RealT,
FirstDerivative,
SecondDerivative
}
@assert derivative_order(D1) == 1
if D2 isa AbstractDerivativeOperator
@assert derivative_order(D2) == 2
Expand All @@ -35,7 +41,7 @@ function Solver(mesh, accuracy_order)
D1 = periodic_derivative_operator(1, accuracy_order, mesh.xmin, mesh.xmax, mesh.N)
D2 = periodic_derivative_operator(2, accuracy_order, mesh.xmin, mesh.xmax, mesh.N)
@assert real(D1) == real(D2)
Solver{real(D1)}(D1, D2)
Solver{real(D1), typeof(D1), typeof(D2)}(D1, D2)
end

# Also allow to pass custom SBP operators (for convenience without explicitly specifying the type)
Expand All @@ -50,7 +56,7 @@ function Solver(D1::AbstractDerivativeOperator{RealT},
D2::Union{AbstractDerivativeOperator{RealT}, AbstractMatrix{RealT}}) where {
RealT
}
Solver{RealT}(D1, D2)
Solver{RealT, typeof(D1), typeof(D2)}(D1, D2)
end

function Base.show(io::IO, solver::Solver{RealT}) where {RealT}
Expand Down

0 comments on commit bc11555

Please sign in to comment.