Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Oct 30, 2024
1 parent 0617f42 commit 9c857a5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
1 change: 0 additions & 1 deletion src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ Interpolation of the solution of the ODE using a trained neural network.
* `phi`: The neural network
* `θ`: The parameters of the neural network.
```
"""
@concrete struct PINOODEInterpolation{T <: PINOPhi, T2}
phi::T
Expand Down
32 changes: 17 additions & 15 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,31 +174,31 @@ end
end

#multiple parameters Сhain
@testitem "Example multiple parameters Сhain du = p1 * cos(p2 * t) + p3" tags=[:pinoode] setup=[PINOODETestSetup] begin
@testitem "Example multiple parameters Сhain du = p1 * cos(p2 * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin
using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
equation = (u, p, t) -> p[1] * cos(p[2] * t) #+ p[3]
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(equation, u0, tspan)

input_branch_size = 3
input_branch_size = 2
chain = Chain(
Dense(input_branch_size + 1 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast), Dense(10 => 1))

x = rand(Float32, 4, 1000, 10)
x = rand(Float32, 3, 1000, 10)
θ, st = Lux.setup(Random.default_rng(), chain)
c = chain(x, θ, st)[1]

bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
bounds = [(1.0, pi), (1.0, 2.0)]#, (2.0, 3.0)]
number_of_parameters = 200
strategy = StochasticTraining(200)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 4000)
sol = solve(prob, alg, verbose = false, maxiters = 5000)

ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) #+ p[3] * t

function ground_solution_f(p, t)
reduce(hcat,
Expand All @@ -217,9 +217,9 @@ end
end

#multiple parameters DeepOnet
@testitem "Example multiple parameters DeepOnet du = p1 * cos(p2 * t) + p3" tags=[:pinoode] setup=[PINOODETestSetup] begin
@testitem "Example multiple parameters DeepOnet du = p1 * cos(p2 * t)" tags=[:pinoode] setup=[PINOODETestSetup] begin
using NeuralPDE, Lux, OptimizationOptimisers, NeuralOperators, Random
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
equation = (u, p, t) -> p[1] * cos(p[2] * t) #+ p[3]
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(equation, u0, tspan)
Expand All @@ -231,18 +231,18 @@ end
Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast)))

u = rand(3, 50)
u = rand(2, 50)
v = rand(1, 40, 1)
θ, st = Lux.setup(Random.default_rng(), deeponet)
c = deeponet((u, v), θ, st)[1]

bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
number_of_parameters = 50
number_of_parameters = 100
strategy = StochasticTraining(50)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 5000)
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) #+ p[3] * t
function ground_solution_f(p, t)
reduce(hcat,
[[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
Expand Down Expand Up @@ -270,10 +270,12 @@ end
chain = Chain(
Dense(input_branch_size + 1 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast), Dense(10 => 2))

bounds = [(pi, 2pi)]
number_of_parameters = 300
strategy = StochasticTraining(300)
number_of_parameters = 100
strategy = StochasticTraining(100)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 6000)
Expand Down Expand Up @@ -305,5 +307,5 @@ end
predict = sol.interp(p, t)
@test ground_solution_[1, :, :]predict[1, :, :] rtol=0.05
@test ground_solution_[2, :, :]predict[2, :, :] rtol=0.05
@test ground_solution_predict rtol=0.3
@test ground_solution_predict rtol=0.05
end

0 comments on commit 9c857a5

Please sign in to comment.