diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 8c9178c683..1a405daf45 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -37,6 +37,7 @@ jobs: - "Forward" - "DGM" - "NNODE" + - "PINOODE" - "NeuralAdapter" - "IntegroDiff" uses: "SciML/.github/.github/workflows/tests.yml@v1" diff --git a/Project.toml b/Project.toml index 304d6efb97..46fb9f6804 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" @@ -79,12 +80,13 @@ LogDensityProblems = "2" Lux = "1.1.0" LuxCUDA = "0.3.3" LuxCore = "1.0.1" -LuxLib = "1.3.2" +LuxLib = "1.3" MCMCChains = "6" MLDataDevices = "1.2.0" MethodOfLines = "0.11.6" ModelingToolkit = "9.46" MonteCarloMeasurements = "1.1" +NeuralOperators = "0.5" Optimisers = "0.3.3" Optimization = "4" OptimizationOptimJL = "0.4" diff --git a/docs/Project.toml b/docs/Project.toml index b8bbab2416..935421b02a 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -16,6 +16,7 @@ MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" NeuralPDE = "315f7962-48a3-4962-8226-d0f33b1235f0" +NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" @@ -44,6 +45,7 @@ MethodOfLines = "0.11" ModelingToolkit = "9.7" MonteCarloMeasurements = "1" NeuralPDE = "5" +NeuralOperators = "0.5" Optimization = "4" OptimizationOptimJL = "0.4" OptimizationOptimisers = "0.3" @@ -53,4 +55,4 @@ Plots = "1.36" QuasiMonteCarlo = "0.3.2" Random = "1" Roots = "2.0" -SpecialFunctions = "2.1" +SpecialFunctions = "2.1" \ No newline at end of file diff --git a/docs/pages.jl b/docs/pages.jl index e0a2741a2d..54b706f28c 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -3,6 +3,7 @@ pages = ["index.md", "Bayesian PINNs for Coupled ODEs" => "tutorials/Lotka_Volterra_BPINNs.md", "PINNs DAEs" => "tutorials/dae.md", "Parameter Estimation with PINNs for ODEs" => "tutorials/ode_parameter_estimation.md", + "Physics informed Neural Operator ODEs" => "tutorials/pino_ode.md", "Deep Galerkin Method" => "tutorials/dgm.md" #"examples/nnrode_example.md", # currently incorrect ], "PDE PINN Tutorials" => Any[ @@ -31,6 +32,7 @@ pages = ["index.md", "manual/training_strategies.md", "manual/adaptive_losses.md", "manual/logging.md", - "manual/neural_adapters.md"], + "manual/neural_adapters.md", + "manual/pino_ode.md"], "Developer Documentation" => Any["developer/debugging.md"] ] diff --git a/docs/src/manual/pino_ode.md b/docs/src/manual/pino_ode.md new file mode 100644 index 0000000000..c26ef79582 --- /dev/null +++ b/docs/src/manual/pino_ode.md @@ -0,0 +1,5 @@ +# Physics-Informed Neural Operator (PINO) for ODEs + +```@docs +PINOODE +``` diff --git a/docs/src/tutorials/pino_ode.md b/docs/src/tutorials/pino_ode.md new file mode 100644 index 0000000000..fb4d2790d9 --- /dev/null +++ b/docs/src/tutorials/pino_ode.md @@ -0,0 +1,99 @@ +# Physics Informed Neural Operator for ODEs + +This tutorial provides an example of how to use the Physics Informed Neural Operator (PINO) for solving a family of parametric ordinary differential equations (ODEs). + +## Operator Learning for a family of parametric ODEs + +In this section, we will define a parametric ODE and then learn it with a PINO using [`PINOODE`](@ref). The PINO will be trained to learn the mapping from the parameters of the ODE to its solution. + +```@example pino +using Test +using OptimizationOptimisers +using Lux +using Statistics, Random +using NeuralOperators +using NeuralPDE + +# Define the parametric ODE equation +equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3] +tspan = (0.0, 1.0) +u0 = 1.0 +prob = ODEProblem(equation, u0, tspan) + +# Set the number of parameters for the ODE +number_of_parameter = 3 +# Define the DeepONet architecture for the PINO +deeponet = NeuralOperators.DeepONet( + Chain( + Dense(number_of_parameter => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)), + Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast))) + +# Define the bounds for the parameters +bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)] +number_of_parameter_samples = 50 +# Define the training strategy +strategy = StochasticTraining(20) +# Define the optimizer +opt = OptimizationOptimisers.Adam(0.03) +alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy) +# Solve the ODE problem using the PINOODE algorithm +sol = solve(prob, alg, verbose = false, maxiters = 4000) +``` + +Now let's compare the prediction from the learned operator with the ground truth solution which is obtained by analytic solution of the parametric ODE. + +```@example pino +using Plots + +function get_trainset(bounds, tspan, number_of_parameters, dt) + p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds] + p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...) + t_ = collect(tspan[1]:dt:tspan[2]) + t = collect(reshape(t_, 1, size(t_, 1), 1)) + (p, t) +end + +# Compute the ground truth solution for each parameter +ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t +function ground_solution_f(p, t) + reduce(hcat, + [[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)]) +end + +# generate the solution with new parameters for test the model +(p, t) = get_trainset(bounds, tspan, 50, 0.025) +# compute the ground truth solution +ground_solution_ = ground_solution_f(p, t) +# predict the solution with the PINO model +predict = sol.interp((p, t)) + +# calculate the errors between the ground truth solution and the predicted solution +errors = ground_solution_ - predict +# calculate the mean error and the standard deviation of the errors +mean_error = mean(errors) +# calculate the standard deviation of the errors +std_error = std(errors) + +p, t = get_trainset(bounds, tspan, 100, 0.01) +ground_solution_ = ground_solution_f(p, t) +predict = sol.interp((p, t)) + +errors = ground_solution_ - predict +mean_error = mean(errors) +std_error = std(errors) + +# Plot the predicted solution and the ground truth solution as a filled contour plot +# predict, represents the predicted solution for each parameter value and time +plot(predict, linetype = :contourf) +plot!(ground_solution_, linetype = :contourf) +``` + +```@example pino +# 'i' is the index of the parameter 'p' in the dataset +i = 20 +# 'predict' is the predicted solution from the PINO model +plot(predict[:, i], label = "Predicted") +# 'ground' is the ground truth solution +plot!(ground_solution_[:, i], label = "Ground truth") +``` diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index c0798c6270..23f2f18f2f 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -16,6 +16,7 @@ using IntervalSets: infimum, supremum using LinearAlgebra: Diagonal using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer using Lux: FromFluxAdaptor, recursive_eltype +using NeuralOperators: DeepONet using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer using MLDataDevices: CPUDevice, get_device using Optimisers: Optimisers, Adam @@ -79,7 +80,7 @@ include("adaptive_losses.jl") include("ode_solve.jl") include("dae_solve.jl") - +include("pino_ode_solve.jl") include("transform_inf_integral.jl") include("discretize.jl") @@ -90,6 +91,7 @@ include("PDE_BPINN.jl") include("dgm.jl") +export PINOODE export NNODE, NNDAE export BNNODE, ahmc_bayesian_pinn_ode, ahmc_bayesian_pinn_pde export PhysicsInformedNN, discretize diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl new file mode 100644 index 0000000000..04e65ee871 --- /dev/null +++ b/src/pino_ode_solve.jl @@ -0,0 +1,391 @@ +""" + PINOODE(chain, + opt, + bounds; + init_params = nothing, + strategy = nothing + kwargs...) + +Algorithm for solving paramentric ordinary differential equations using a physics-informed +neural operator, which is used as a solver for a parametrized `ODEProblem`. + +## Positional Arguments + +* `chain`: A neural network architecture, defined as a `AbstractLuxLayer` or `Flux.Chain`. + `Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)` +* `opt`: The optimizer to train the neural network. +* `bounds`: A dictionary containing the bounds for the parameters of the parametric ODE. +* `number_of_parameters`: The number of points of train set in parameters boundaries. + +## Keyword Arguments + +* `init_params`: The initial parameters of the neural network. By default, this is `nothing`, + which thus uses the random initialization provided by the neural network library. +* `strategy`: The strategy for training the neural network. +* `additional_loss`: additional loss function added to the default one. For example, add training on data. +* `kwargs`: Extra keyword arguments are splatted to the Optimization.jl `solve` call. + +## References + +* Sifan Wang "Learning the solution operator of parametric partial differential equations with physics-informed DeepOnets" +* Zongyi Li "Physics-Informed Neural Operator for Learning Partial Differential Equations" +""" +@concrete struct PINOODE + chain + opt + bounds + number_of_parameters::Int + init_params + strategy <: Union{Nothing, AbstractTrainingStrategy} + additional_loss <: Union{Nothing, Function} + kwargs +end + +function PINOODE(chain, + opt, + bounds, + number_of_parameters; + init_params = nothing, + strategy = nothing, + additional_loss = nothing, + kwargs...) + chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain)) + return PINOODE(chain, opt, bounds, number_of_parameters, + init_params, strategy, additional_loss, kwargs) +end + +@concrete struct PINOPhi + model <: AbstractLuxLayer + smodel <: StatefulLuxLayer +end + +function PINOPhi(model::AbstractLuxLayer, st) + return PINOPhi(model, StatefulLuxLayer{false}(model, nothing, st)) +end + +function generate_pino_phi_θ(chain::AbstractLuxLayer, ::Nothing) + θ, st = LuxCore.setup(Random.default_rng(), chain) + PINOPhi(chain, st), θ +end + +function generate_pino_phi_θ(chain::AbstractLuxLayer, init_params) + st = LuxCore.initialstates(Random.default_rng(), chain) + PINOPhi(chain, st), init_params +end + +function (f::PINOPhi{C, T})(x, θ) where {C <: AbstractLuxLayer, T} + dev = safe_get_device(θ) + return f(dev, safe_expand(dev, x), θ) +end + +function (f::PINOPhi{C, T})(dev, x, θ) where {C <: AbstractLuxLayer, T} + f.smodel(dev(x), θ) +end + +function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C <: DeepONet, T} + p, t = x + branch_left, branch_right = p, p + trunk_left, trunk_right = t .+ sqrt(eps(eltype(t))), t + x_left = (branch_left, trunk_left) + x_right = (branch_right, trunk_right) + (phi(x_left, θ) .- phi(x_right, θ)) ./ sqrt(eps(eltype(t))) +end + +function dfdx(phi::PINOPhi{C, T}, x::Array, + θ) where {C <: Lux.Chain, T} + ε = [zeros(eltype(x), size(x)[1] - 1)..., sqrt(eps(eltype(x)))] + (phi(x .+ ε, θ) - phi(x, θ)) ./ sqrt(eps(eltype(x))) +end + +function physics_loss( + phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {C <: DeepONet, T} + p, t = x + f = prob.f + out = phi(x, θ) + if size(p, 1) == 1 + f_vec = reduce(hcat, + [reduce(vcat, [f(out[j, i], p[1, i], t[j]) for j in axes(t, 2)]) + for i in axes(p, 2)]) + else + f_vec = reduce(hcat, + [reduce(vcat, [f(out[j, i], p[:, i], t[j]) for j in axes(t, 2)]) + for i in axes(p, 2)]) + end + du = dfdx(phi, x, θ) + norm = prod(size(du)) + sum(abs2, du .- f_vec) / norm +end + +function physics_loss( + phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where { + C <: Lux.Chain, T} + p, t = x + x_ = reduce(vcat, x) + f = prob.f + out = phi(x_, θ) + if size(p, 1) == 1 && size(out, 1) == 1 + f_vec = f.(out, p, t) + elseif size(p, 1) > 1 + f_vec = reduce(hcat, + [reduce(vcat, [f(out[1, i, j], p[:, i, j], t[1, i, j]) for j in axes(t, 3)]) + for i in axes(p, 2)]) + elseif size(out, 1) > 1 + f_vec = reduce(hcat, + [reduce(vcat, [f(out[:, i, j], p[1, i, j], t[1, i, j]) for j in axes(t, 3)]) + for i in axes(p, 2)]) + end + du = dfdx(phi, x_, θ) + norm = prod(size(out)) + sum(abs2, du .- f_vec) / norm +end + +function initial_condition_loss( + phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where { + C <: DeepONet, T} + p, t = x + t0 = reshape([prob.tspan[1]], (1, 1, 1)) + x0 = (p, t0) + u = phi(x0, θ) + u0 = size(prob.u0, 1) == 1 ? fill(prob.u0, size(u)) : + reduce(vcat, [fill(u0, size(u)) for u0 in prob.u0]) + norm = prod(size(u0)) + sum(abs2, u .- u0) / norm +end + +function initial_condition_loss( + phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where { + C <: Lux.Chain, T} + p, t = x + t0 = fill(prob.tspan[1], size(t)) + x0 = reduce(vcat, (p, t0)) + u = phi(x0, θ) + u0 = size(prob.u0, 1) == 1 ? fill(prob.u0, size(t)) : + reduce(vcat, [fill(u0, size(t)) for u0 in prob.u0]) + norm = prod(size(u0)) + sum(abs2, u .- u0) / norm +end + +function get_trainset( + strategy::GridTraining, chain::DeepONet, bounds, number_of_parameters, tspan) + dt = strategy.dx + p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds] + p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...) + t_ = collect(tspan[1]:dt:tspan[2]) + t = reshape(t_, 1, size(t_, 1), 1) + (p, t) +end + +function get_trainset( + strategy::GridTraining, chain::Chain, bounds, number_of_parameters, tspan) + dt = strategy.dx + tspan_ = tspan[1]:dt:tspan[2] + pspan = [range(start = b[1], length = number_of_parameters, stop = b[2]) + for b in bounds] + x_ = hcat(vec(map( + points -> collect(points), Iterators.product([pspan..., tspan_]...)))...) + x = reshape(x_, size(bounds, 1) + 1, prod(size.(pspan, 1)), size(tspan_, 1)) + p, t = x[1:(end - 1), :, :], x[[end], :, :] + (p, t) +end + +function get_trainset( + strategy::StochasticTraining, chain::Union{DeepONet, Chain}, + bounds, number_of_parameters, tspan) + (number_of_parameters != strategy.points && chain isa Chain) && + throw(error("number_of_parameters should be the same strategy.points for StochasticTraining")) + p = reduce(vcat, + [(bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1] + for bound in bounds]) + t = (tspan[2] .- tspan[1]) .* rand(1, strategy.points, 1) .+ tspan[1] + (p, t) +end + +function generate_loss( + strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan) + x = get_trainset(strategy, phi.smodel.model, bounds, number_of_parameters, tspan) + function loss(θ, _) + initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ) + end +end + +function generate_loss( + strategy::StochasticTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan) + function loss(θ, _) + x = get_trainset(strategy, phi.smodel.model, bounds, number_of_parameters, tspan) + initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ) + end +end + +""" +PINOODEInterpolation(phi, θ) + +Interpolation of the solution of the ODE using a trained neural network. + +## Arguments +* `phi`: The neural network +* `θ`: The parameters of the neural network. +``` +""" +@concrete struct PINOODEInterpolation{T <: PINOPhi, T2} + phi::T + θ::T2 +end + +""" +Override interpolation method for PINOODEInterpolation + +## Arguments +* `x`: Input data on which the solution is to be interpolated. +## Example + +```jldoctest +interp = PINOODEInterpolation(phi, θ) +x = rand(2, 50, 10) +interp(x) +``` +""" +(f::PINOODEInterpolation)(x::AbstractArray) = f.phi(x, f.θ) + +""" +Override interpolation method for PINOODEInterpolation + +## Arguments +# * `p`: The parameters points on which the solution is to be interpolated. +# * `t`: The time points on which the solution is to be interpolated. + +## Example +```jldoctest +interp = PINOODEInterpolation(phi, θ) +p,t = rand(1, 50, 10), rand(1, 50, 10) +interp(p, t) +``` +""" +function (f::PINOODEInterpolation)(p::AbstractArray, t::AbstractArray) + if f.phi.model isa DeepONet + f.phi((p, t), f.θ) + elseif f.phi.model isa Chain + if size(p, 2) != size(t, 2) + error("t should be same size as p") + end + f.phi(reduce(vcat, (p, t)), f.θ) + else + error("Only DeepONet and Chain neural networks are supported with PINO ODE") + end +end + +function (f::PINOODEInterpolation)(p::AbstractArray, t::Number) + if f.phi.model isa DeepONet + t_ = [t] + f.phi((p, t_), f.θ) + elseif f.phi.model isa Chain + t_ = fill(t, size(p)) + f.phi(reduce(vcat, (p, t_)), f.θ) + else + error("Only DeepONet and Chain neural networks are supported with PINO ODE") + end +end + +SciMLBase.interp_summary(::PINOODEInterpolation) = "Trained neural network interpolation" +SciMLBase.allowscomplex(::PINOODE) = true + +function (sol::SciMLBase.AbstractODESolution)(t::Union{Number, AbstractArray}) + sol.interp(sol.prob.p, t) +end + +function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, + alg::PINOODE, + args...; + abstol = 1.0f-8, + reltol = 1.0f-3, + verbose = false, + saveat = nothing, + maxiters = nothing) + (; tspan, u0, f) = prob + (; chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss) = alg + + if !(chain isa AbstractLuxLayer) + error("Only Lux.AbstractLuxLayer neural networks are supported") + + if !(chain isa DeepONet) || !(chain isa Chain) + error("Only DeepONet and Chain neural networks are supported with PINO ODE") + end + end + + phi, init_params = generate_pino_phi_θ(chain, init_params) + + init_params = ComponentArray(init_params) + + isinplace(prob) && + throw(error("The PINOODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t).")) + + try + if chain isa DeepONet + in_dim = chain.branch.layers.layer_1.in_dims + u = rand(in_dim, number_of_parameters) + v = rand(1, 10, 1) + x = (u, v) + phi(x, init_params) + end + if chain isa Chain + in_dim = chain.layers.layer_1.in_dims + x = rand(in_dim, number_of_parameters) + phi(x, init_params) + end + catch err + if isa(err, DimensionMismatch) + throw(DimensionMismatch("Dimensions of input data and chain should match")) + else + throw(err) + end + end + + if strategy === nothing + strategy = StochasticTraining(100) + elseif !(strategy isa GridTraining || strategy isa StochasticTraining) + throw(ArgumentError("Only GridTraining and StochasticTraining strategy is supported")) + end + + inner_f = generate_loss( + strategy, prob, phi, bounds, number_of_parameters, tspan) + + function total_loss(θ, _) + L2_loss = inner_f(θ, nothing) + if !(additional_loss isa Nothing) + L2_loss = L2_loss + additional_loss(phi, θ) + end + L2_loss + end + + # Optimization Algo for Training Strategies + opt_algo = Optimization.AutoZygote() + + # Creates OptimizationFunction Object from total_loss + optf = OptimizationFunction(total_loss, opt_algo) + + iteration = 0 + callback = function (p, l) + iteration += 1 + verbose && println("Current loss is: $l, Iteration: $iteration") + l < abstol + end + + optprob = OptimizationProblem(optf, init_params) + res = solve(optprob, opt; callback, maxiters, alg.kwargs...) + + (p, t) = get_trainset(strategy, phi.smodel.model, bounds, number_of_parameters, tspan) + interp = PINOODEInterpolation(phi, res.u) + u = interp(p, t) + prob_sol = ODEProblem(f.f, u0, tspan, p) + + sol = SciMLBase.build_solution(prob_sol, alg, t, u; + k = res, dense = true, + interp = interp, + calculate_error = false, + retcode = ReturnCode.Success, + original = res, + resid = res.objective) + SciMLBase.has_analytic(prob.f) && + SciMLBase.calculate_solution_errors!(sol; timeseries_errors = true, + dense_errors = false) + sol +end diff --git a/src/training_strategies.jl b/src/training_strategies.jl index 974f2529fa..ca07676f26 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -291,8 +291,8 @@ end """ WeightedIntervalTraining(weights, samples) -A training strategy that generates points for training based on the given inputs. -We split the timespan into equal segments based on the number of weights, +A training strategy that generates points for training based on the given inputs. +We split the timespan into equal segments based on the number of weights, then sample points in each segment based on that segments corresponding weight, such that the total number of sampled points is equivalent to the given samples diff --git a/test/PINO_ode_tests.jl b/test/PINO_ode_tests.jl new file mode 100644 index 0000000000..4f6e7e43f1 --- /dev/null +++ b/test/PINO_ode_tests.jl @@ -0,0 +1,311 @@ +@testsetup module PINOODETestSetup +using Lux, NeuralOperators + +function get_trainset(chain::DeepONet, bounds, number_of_parameters, tspan, dt) + p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds] + p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...) + t_ = collect(tspan[1]:dt:tspan[2]) + t = reshape(t_, 1, size(t_, 1), 1) + (p, t) +end + +function get_trainset(chain::Lux.Chain, bounds, number_of_parameters, tspan, dt) + tspan_ = tspan[1]:dt:tspan[2] + pspan = [range(start = b[1], length = number_of_parameters, stop = b[2]) + for b in bounds] + x_ = hcat(vec(map( + points -> collect(points), Iterators.product([pspan..., tspan_]...)))...) + x = reshape(x_, size(bounds, 1) + 1, prod(size.(pspan, 1)), size(tspan_, 1)) + p, t = x[1:(end - 1), :, :], x[[end], :, :] + (p, t) +end +export get_trainset +end +#Test Chain +@testitem "Example Chain du = cos(p * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin + using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random + equation = (u, p, t) -> cos(p * t) + tspan = (0.0, 1.0) + u0 = 1.0 + prob = ODEProblem(equation, u0, tspan) + chain = Chain( + Dense(2 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 1)) + x = rand(2, 50, 10) + θ, st = Lux.setup(Random.default_rng(), chain) + b = chain(x, θ, st)[1] + + bounds = [(pi, 2pi)] + number_of_parameters = 300 + strategy = StochasticTraining(300) + opt = OptimizationOptimisers.Adam(0.01) + alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy) + sol = solve(prob, alg, verbose = false, maxiters = 5000) + ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p) + p, t = get_trainset(chain, bounds, 50, tspan, 0.025) + ground_solution = ground_analytic.(u0, p, t) + predict_sol = sol.interp(p, t) + @test ground_solution≈predict_sol rtol=0.05 + p, t = get_trainset(chain, bounds, 100, tspan, 0.01) + ground_solution = ground_analytic.(u0, p, t) + predict_sol = sol.interp(p, t) + @test ground_solution≈predict_sol rtol=0.05 + + p = sol.prob.p + ground_solution = ground_analytic.(u0, p, [1.0]) + predict_sol = sol(1.0) + @test ground_solution≈predict_sol rtol=0.05 + + p = sol.prob.p + t = rand(size(p)...) + ground_solution = ground_analytic.(u0, p, t) + predict_sol = sol(t) + @test ground_solution≈predict_sol rtol=0.05 +end + +#Test DeepONet +@testitem "Example DeepONet du = cos(p * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin + using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random + equation = (u, p, t) -> cos(p * t) + tspan = (0.0, 1.0) + u0 = 1.0 + prob = ODEProblem(equation, u0, tspan) + deeponet = NeuralOperators.DeepONet( + Chain( + Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)), + Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast))) + u = rand(Float32, 1, 50) + v = rand(Float32, 1, 40, 1) + branch = deeponet.branch + θ, st = Lux.setup(Random.default_rng(), branch) + b = branch(u, θ, st)[1] + trunk = deeponet.trunk + θ, st = Lux.setup(Random.default_rng(), trunk) + t = trunk(v, θ, st)[1] + θ, st = Lux.setup(Random.default_rng(), deeponet) + deeponet((u, v), θ, st)[1] + + bounds = [(pi, 2pi)] + number_of_parameters = 50 + strategy = StochasticTraining(40) + opt = OptimizationOptimisers.Adam(0.01) + alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy) + sol = solve(prob, alg, verbose = false, maxiters = 3000) + ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p) + p, t = get_trainset(deeponet, bounds, 50, tspan, 0.025) + ground_solution = ground_analytic.(u0, p, vec(t)) + predict_sol = sol.interp(p, t) + @test ground_solution≈predict_sol rtol=0.05 + p, t = get_trainset(deeponet, bounds, 100, tspan, 0.01) + ground_solution = ground_analytic.(u0, p, vec(t)) + predict_sol = sol.interp(p, t) + @test ground_solution≈predict_sol rtol=0.05 + + p, t = sol.prob.p, rand(1, 20, 1) + ground_solution = ground_analytic.(u0, p, vec(t)) + predict_sol = sol(t) + @test ground_solution≈predict_sol rtol=0.05 +end + +@testitem "Example du = cos(p * t) + u" tags=[:pinoode] setup=[PINOODETestSetup] begin + using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random + eq_(u, p, t) = cos(p * t) + u + tspan = (0.0, 1.0) + u0 = 1.0 + prob = ODEProblem(eq_, u0, tspan) + deeponet = NeuralOperators.DeepONet( + Chain( + Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)), + Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast))) + bounds = [(0.1, 2.0)] + number_of_parameters = 40 + dt = (tspan[2] - tspan[1]) / 40 + strategy = GridTraining(0.1) + opt = OptimizationOptimisers.Adam(0.01) + alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy) + sol = solve(prob, alg, verbose = false, maxiters = 4000) + sol.original.objective + #if u0 == 1 + ground_analytic_(u0, p, t) = (p * sin(p * t) - cos(p * t) + (p^2 + 2) * exp(t)) / + (p^2 + 1) + p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) + ground_solution = ground_analytic_.(u0, p, vec(t)) + predict_sol = sol.interp(p, t) + @test ground_solution≈predict_sol rtol=0.05 +end + +@testitem "Example with data du = p*t^2" tags=[:pinoode] setup=[PINOODETestSetup] begin + using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random + equation = (u, p, t) -> p * t^2 + tspan = (0.0, 1.0) + u0 = 0.0 + prob = ODEProblem(equation, u0, tspan) + deeponet = NeuralOperators.DeepONet( + Chain( + Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)), + Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast))) + bounds = [(0.0, 10.0)] + number_of_parameters = 60 + dt = (tspan[2] - tspan[1]) / 40 + strategy = StochasticTraining(60) + opt = OptimizationOptimisers.Adam(0.01) + + #generate data + ground_analytic = (u0, p, t) -> u0 + p * t^3 / 3 + p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) + sol = ground_analytic.(u0, p, vec(t)) + function additional_loss_(phi, θ) + u = phi((p, t), θ) + norm = prod(size(u)) + sum(abs2, u .- sol) / norm + end + + alg = PINOODE( + deeponet, opt, bounds, number_of_parameters; strategy = strategy, + additional_loss = additional_loss_) + sol = solve(prob, alg, verbose = false, maxiters = 3000) + + p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt) + ground_solution = ground_analytic.(u0, p, vec(t)) + predict_sol = sol.interp(p, t) + @test ground_solution≈predict_sol rtol=0.05 +end + +#multiple parameters Сhain +@testitem "Example multiple parameters Сhain du = p1 * cos(p2 * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin + using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random + equation = (u, p, t) -> p[1] * cos(p[2] * t) #+ p[3] + tspan = (0.0, 1.0) + u0 = 1.0 + prob = ODEProblem(equation, u0, tspan) + + input_branch_size = 2 + chain = Chain( + Dense(input_branch_size + 1 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast), Dense(10 => 1)) + + x = rand(Float32, 3, 1000, 10) + θ, st = Lux.setup(Random.default_rng(), chain) + c = chain(x, θ, st)[1] + + bounds = [(1.0, pi), (1.0, 2.0)]#, (2.0, 3.0)] + number_of_parameters = 200 + strategy = StochasticTraining(200) + opt = OptimizationOptimisers.Adam(0.01) + alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy) + sol = solve(prob, alg, verbose = false, maxiters = 5000) + + ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) #+ p[3] * t + + function ground_solution_f(p, t) + reduce(hcat, + [[ground_solution(u0, p[:, i, j], t[1, i, j]) for j in axes(t, 3)] + for i in axes(p, 2)])' + end + (p, t) = get_trainset(chain, bounds, 20, tspan, 0.1) + ground_solution_ = ground_solution_f(p, t) + predict = sol.interp(p, t)[1, :, :] + @test ground_solution_≈predict rtol=0.05 + + p, t = get_trainset(chain, bounds, 50, tspan, 0.025) + ground_solution_ = ground_solution_f(p, t) + predict_sol = sol.interp(p, t)[1, :, :] + @test ground_solution_≈predict_sol rtol=0.05 +end + +#multiple parameters DeepOnet +@testitem "Example multiple parameters DeepOnet du = p1 * cos(p2 * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin + using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random + equation = (u, p, t) -> p[1] * cos(p[2] * t) #+ p[3] + tspan = (0.0, 1.0) + u0 = 1.0 + prob = ODEProblem(equation, u0, tspan) + + input_branch_size = 3 + deeponet = NeuralOperators.DeepONet( + Chain( + Dense(input_branch_size => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)), + Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast))) + + u = rand(2, 50) + v = rand(1, 40, 1) + θ, st = Lux.setup(Random.default_rng(), deeponet) + c = deeponet((u, v), θ, st)[1] + + bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)] + number_of_parameters = 100 + strategy = StochasticTraining(50) + opt = OptimizationOptimisers.Adam(0.01) + alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy) + sol = solve(prob, alg, verbose = false, maxiters = 5000) + ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) #+ p[3] * t + function ground_solution_f(p, t) + reduce(hcat, + [[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)]) + end + + (p, t) = get_trainset(deeponet, bounds, 50, tspan, 0.025) + ground_solution_ = ground_solution_f(p, t) + predict = sol.interp(p, t) + @test ground_solution_≈predict rtol=0.05 + + p, t = get_trainset(deeponet, bounds, 100, tspan, 0.01) + ground_solution_ = ground_solution_f(p, t) + predict = sol.interp(p, t) + @test ground_solution_≈predict rtol=0.05 +end + +#vector output +@testitem "Example du = [cos(p * t), sin(p * t)]" tags=[:pinoode] setup=[PINOODETestSetup] begin + using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random + equation = (u, p, t) -> [cos(p * t), sin(p * t)] + tspan = (0.0, 1.0) + u0 = [1.0, 0.0] + prob = ODEProblem(equation, u0, tspan) + input_branch_size = 1 + chain = Chain( + Dense(input_branch_size + 1 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast), + Dense(10 => 10, Lux.tanh_fast), Dense(10 => 2)) + + bounds = [(pi, 2pi)] + number_of_parameters = 100 + strategy = StochasticTraining(100) + opt = OptimizationOptimisers.Adam(0.01) + alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy) + sol = solve(prob, alg, verbose = false, maxiters = 6000) + + ground_solution = (u0, p, t) -> [1 + sin(p * t) / p, 1 / p - cos(p * t) / p] + function ground_solution_f(p, t) + ans_1 = reduce(hcat, + [reduce(vcat, + [ground_solution(u0, p[1, i, 1], t[1, 1, j])[1] for i in axes(p, 2)]) + for j in axes(t, 3)]) + ans_2 = reduce(hcat, + [reduce(vcat, + [ground_solution(u0, p[1, i, 1], t[1, 1, j])[2] for i in axes(p, 2)]) + for j in axes(t, 3)]) + + ans_1 = reshape(ans_1, 1, size(ans_1)...) + ans_2 = reshape(ans_2, 1, size(ans_2)...) + vcat(ans_1, ans_2) + end + p, t = get_trainset(chain, bounds, 50, tspan, 0.025) + ground_solution_ = ground_solution_f(p, t) + predict = sol.interp(p, t) + @test ground_solution_[1, :, :]≈predict[1, :, :] rtol=0.05 + @test ground_solution_[2, :, :]≈predict[2, :, :] rtol=0.05 + @test ground_solution_≈predict rtol=0.05 + + p, t = get_trainset(chain, bounds, 300, tspan, 0.01) + ground_solution_ = ground_solution_f(p, t) + predict = sol.interp(p, t) + @test ground_solution_[1, :, :]≈predict[1, :, :] rtol=0.05 + @test ground_solution_[2, :, :]≈predict[2, :, :] rtol=0.05 + @test ground_solution_≈predict rtol=0.05 +end