diff --git a/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl b/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl index b546d02..f357ed3 100644 --- a/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl +++ b/lib/ModelingToolkitTearing/src/clock_inference/clock_inference.jl @@ -3,6 +3,7 @@ Equation(Int) InitEquation(Int) Clock(SciMLBase.AbstractClock) + AssertDiscrete end struct ClockInference{S <: StateSelection.TransformationState} @@ -147,6 +148,10 @@ function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_e InferredClock.InferredDiscrete(i) => begin relative_edge = get!(Set{ClockVertex.Type}, relative_hyperedges, i) union!(relative_edge, arg_hyperedge) + # Ensure that this clock partition will be discrete. This is a separate + # variant because I don't want to give `InferredDiscrete` too many meanings. + push!(arg_hyperedge, ClockVertex.AssertDiscrete()) + add_edge!(inference_graph, arg_hyperedge) end end end @@ -163,6 +168,7 @@ function (iec::InferEquationClosure)(ieq::Int, eq::Equation, is_initialization_e union!(hyperedge, buffer) delete!(relative_hyperedges, i) end + push!(hyperedge, ClockVertex.AssertDiscrete()) end end else @@ -220,6 +226,9 @@ function infer_clocks!(ci::ClockInference) for partition in clock_partitions clockidxs = findall(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Clock), partition) if isempty(clockidxs) + if any(isequal(ClockVertex.AssertDiscrete()), partition) + throw(ExpectedDiscreteClockPartitionError(ts, partition, true)) + end push!(partition, ClockVertex.Clock(SciMLBase.ContinuousClock())) push!(clockidxs, length(partition)) end @@ -237,12 +246,16 @@ function infer_clocks!(ci::ClockInference) clock = Moshi.Match.@match partition[only(clockidxs)] begin ClockVertex.Clock(clk) => clk end + if clock == SciMLBase.ContinuousClock() && any(isequal(ClockVertex.AssertDiscrete()), partition) + throw(ExpectedDiscreteClockPartitionError(ts, partition, false)) + end for vert in partition Moshi.Match.@match vert begin ClockVertex.Variable(i) => (var_domain[i] = clock) ClockVertex.Equation(i) => (eq_domain[i] = clock) ClockVertex.InitEquation(i) => (init_eq_domain[i] = clock) ClockVertex.Clock(_) => nothing + ClockVertex.AssertDiscrete() => nothing end end end @@ -251,6 +264,49 @@ function infer_clocks!(ci::ClockInference) return ci end +struct ExpectedDiscreteClockPartitionError <: Exception + state::TearingState + partition::Vector{ClockVertex.Type} + has_no_clock::Bool +end + +function Base.showerror(io::IO, err::ExpectedDiscreteClockPartitionError) + if err.has_no_clock + println(io, """ + Found a clock partition that must be discrete (due to the presence of an \ + `InferredDiscrete`) but does not have any associated clock (and would otherwise \ + then default to being on the continuous clock). This likely means that the \ + partition was not assigned a valid discrete clock and the model is incorrect. + """) + else + println(io, """ + Found a clock partition that must be discrete (due to the presence of an \ + `InferredDiscrete`) but is associated with a continuous clock. This is likely \ + a modeling error. + """) + end + + vars = filter(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Variable), err.partition) + println(io, "Variables in the partition:") + for var in vars + println(io, " ", err.state.fullvars[var.:1]) + end + println(io) + + eqs = filter(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.Equation), err.partition) + println(io, "Equations in the partition:") + for eq in eqs + println(io, " ", equations(err.state)[eq.:1]) + end + println(io) + + ieqs = filter(Base.Fix2(Moshi.Data.isa_variant, ClockVertex.InitEquation), err.partition) + println(io, "Initialization equations in the partition:") + for ieq in ieqs + println(io, " ", initialization_equations(err.state.sys)[ieq.:1]) + end +end + function resize_or_push!(v, val, idx) n = length(v) if idx > n