diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl
index 6a0ca9f..d8c6e1b 100644
--- a/src/NeuralOperators.jl
+++ b/src/NeuralOperators.jl
@@ -20,10 +20,12 @@
 include("layers.jl")
 include("fno.jl")
 include("deeponet.jl")
+include("nomad.jl")
 
 export FourierTransform
 export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
 export FourierNeuralOperator
 export DeepONet
+export NOMAD
 
 end
diff --git a/src/nomad.jl b/src/nomad.jl
new file mode 100644
index 0000000..99daddd
--- /dev/null
+++ b/src/nomad.jl
@@ -0,0 +1,102 @@
+"""
+    NOMAD(approximator, decoder, concatenate)
+
+Constructs a NOMAD from `approximator` and `decoder` architectures. Make sure that the output of
+`approximator`, concatenated with the coordinate dimension, has a size compatible with the input of `decoder`.
+
+## Arguments
+
+  - `approximator`: `Lux` network to be used as the approximator net.
+  - `decoder`: `Lux` network to be used as the decoder net.
+
+## Keyword Arguments
+
+  - `concatenate`: function that defines how the output of `approximator` is combined with the coordinate
+    dimension; defaults to concatenation along the first dimension after vectorizing the tensors.
+
+## References
+
+[1] Jacob H. Seidman, Georgios Kissas, Paris Perdikaris, and George J. Pappas, "NOMAD:
+Nonlinear Manifold Decoders for Operator Learning", arXiv: https://arxiv.org/abs/2206.03551
+
+## Example
+
+```jldoctest
+julia> approximator_net = Chain(Dense(8 => 32), Dense(32 => 32), Dense(32 => 16));
+
+julia> decoder_net = Chain(Dense(18 => 16), Dense(16 => 16), Dense(16 => 8));
+
+julia> nomad = NOMAD(approximator_net, decoder_net);
+
+julia> ps, st = Lux.setup(Xoshiro(), nomad);
+
+julia> u = rand(Float32, 8, 5);
+
+julia> y = rand(Float32, 2, 5);
+
+julia> size(first(nomad((u, y), ps, st)))
+(8, 5)
+```
+"""
+@concrete struct NOMAD <: AbstractExplicitContainerLayer{(:approximator, :decoder)}
+    approximator
+    decoder
+    concatenate <: Function
+end
+
+NOMAD(approximator, decoder) = NOMAD(approximator, decoder, __merge)
+
+"""
+    NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8),
+        approximator_activation = identity, decoder_activation = identity, concatenate = __merge)
+
+Constructs a NOMAD composed of `Dense` layers. Make sure that the last node of `approximator`
+plus the coordinate length equals the first node of `decoder`.
+
+## Keyword Arguments
+
+  - `approximator`: Tuple of integers containing the number of nodes in each layer of the approximator net.
+  - `decoder`: Tuple of integers containing the number of nodes in each layer of the decoder net.
+  - `approximator_activation`: activation function for the approximator net.
+  - `decoder_activation`: activation function for the decoder net.
+  - `concatenate`: function that defines how the output of `approximator` is combined with the coordinate
+    dimension; defaults to concatenation along the first dimension after vectorizing the tensors.
+
+## References
+
+[1] Jacob H. Seidman, Georgios Kissas, Paris Perdikaris, and George J. Pappas, "NOMAD:
+Nonlinear Manifold Decoders for Operator Learning", arXiv: https://arxiv.org/abs/2206.03551
+
+## Example
+
+```jldoctest
+julia> nomad = NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8));
+
+julia> ps, st = Lux.setup(Xoshiro(), nomad);
+
+julia> u = rand(Float32, 8, 5);
+
+julia> y = rand(Float32, 2, 5);
+
+julia> size(first(nomad((u, y), ps, st)))
+(8, 5)
+```
+"""
+function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8),
+        approximator_activation=identity, decoder_activation=identity, concatenate=__merge)
+    approximator_net = Chain([Dense(approximator[i] => approximator[i + 1],
+                                  approximator_activation)
+                              for i in 1:(length(approximator) - 1)]...)
+
+    decoder_net = Chain([Dense(decoder[i] => decoder[i + 1], decoder_activation)
+                         for i in 1:(length(decoder) - 1)]...)
+
+    return NOMAD(approximator_net, decoder_net, concatenate)
+end
+
+function (nomad::NOMAD)(x, ps, st::NamedTuple)
+    a, st_a = nomad.approximator(x[1], ps.approximator, st.approximator)
+    out, st_d = nomad.decoder(nomad.concatenate(a, x[2]), ps.decoder, st.decoder)
+
+    return out, (approximator=st_a, decoder=st_d)
+end
diff --git a/src/utils.jl b/src/utils.jl
index 1d4c278..fd20eb7 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -69,3 +69,34 @@ end
 
     return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb
 end
+
+@inline function __batch_vectorize(x::AbstractArray{T, N}) where {T, N}
+    dim_length = ndims(x) - 1
+    nb = size(x)[end]
+
+    slice = [Colon() for _ in 1:dim_length]
+    return reduce(hcat, [vec(view(x, slice..., i)) for i in 1:nb])
+end
+
+@inline function __merge(x::AbstractArray{T1, 2}, y::AbstractArray{T2, 2}) where {T1, T2}
+    return cat(x, y; dims=1)
+end
+
+@inline function __merge(
+        x::AbstractArray{T1, N1}, y::AbstractArray{T2, 2}) where {T1, T2, N1}
+    x_ = __batch_vectorize(x)
+    return vcat(x_, y)
+end
+
+@inline function __merge(
+        x::AbstractArray{T1, 2}, y::AbstractArray{T2, N2}) where {T1, T2, N2}
+    y_ = __batch_vectorize(y)
+    return vcat(x, y_)
+end
+
+@inline function __merge(
+        x::AbstractArray{T1, N1}, y::AbstractArray{T2, N2}) where {T1, T2, N1, N2}
+    x_ = __batch_vectorize(x)
+    y_ = __batch_vectorize(y)
+    return vcat(x_, y_)
+end
diff --git a/test/nomad_tests.jl b/test/nomad_tests.jl
new file mode 100644
index 0000000..00f0b83
--- /dev/null
+++ b/test/nomad_tests.jl
@@ -0,0 +1,28 @@
+@testitem "NOMAD" setup=[SharedTestSetup] begin
+    @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
+        rng = StableRNG(12345)
+
+        setups = [
+            (u_size=(1, 5), y_size=(1, 5), out_size=(1, 5),
+                approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"),
+            (u_size=(8, 5), y_size=(2, 5), out_size=(8, 5),
+                approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector")]
+
+        @testset "$(setup.name)" for setup in setups
+            u = rand(Float32, setup.u_size...) |> aType
+            y = rand(Float32, setup.y_size...) |> aType
+            nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder)
+
+            ps, st = Lux.setup(rng, nomad) |> dev
+            @inferred first(nomad((u, y), ps, st))
+            @jet first(nomad((u, y), ps, st))
+
+            pred = first(nomad((u, y), ps, st))
+            @test setup.out_size == size(pred)
+
+            __f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st)))
+            test_gradients(
+                __f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
+        end
+    end
+end
diff --git a/test/utils_tests.jl b/test/utils_tests.jl
new file mode 100644
index 0000000..e4901a0
--- /dev/null
+++ b/test/utils_tests.jl
@@ -0,0 +1,54 @@
+@testitem "utils" setup=[SharedTestSetup] begin
+    import NeuralOperators: __project, __merge, __batch_vectorize
+    @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
+        rng = StableRNG(12345)
+
+        setups = [
+            (b_size=(16, 5), t_size=(16, 10, 5), out_size=(10, 5),
+                additional=NoOpLayer(), name="Scalar"),
+            (b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(1, 10, 5),
+                additional=NoOpLayer(), name="Scalar II"),
+            (b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(3, 10, 5),
+                additional=NoOpLayer(), name="Vector"),
+            (b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5),
+                out_size=(4, 3, 3, 10, 5), additional=NoOpLayer(), name="Tensor"),
+            (b_size=(16, 5), t_size=(16, 10, 5), out_size=(4, 10, 5),
+                additional=Dense(16 => 4), name="additional : Scalar"),
+            (b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(4, 10, 5),
+                additional=Dense(16 => 4), name="additional : Scalar II"),
+            (b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(4, 3, 10, 5),
+                additional=Dense(16 => 4), name="additional : Vector"),
+            (b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), out_size=(3, 4, 3, 4, 10, 5),
+                additional=Chain(Dense(16 => 4), ReshapeLayer((3, 4, 3, 4, 10))),
+                name="additional : Tensor")]
+
+        @testset "project : $(setup.name)" for setup in setups
+            b = rand(Float32, setup.b_size...) |> aType
+            t = rand(Float32, setup.t_size...) |> aType
+
+            ps, st = Lux.setup(rng, setup.additional) |> dev
+            @inferred first(__project(b, t, setup.additional, (; ps, st)))
+            @jet first(__project(b, t, setup.additional, (; ps, st)))
+            @test setup.out_size ==
+                  size(first(__project(b, t, setup.additional, (; ps, st))))
+        end
+
+        setups = [(x_size=(6, 5), y_size=(4, 5), out_size=(10, 5), name="Scalar"),
+            (x_size=(12, 5), y_size=(8, 5), out_size=(20, 5), name="Vector I"),
+            (x_size=(4, 6, 5), y_size=(6, 5), out_size=(30, 5), name="Vector II"),
+            (x_size=(4, 2, 3, 5), y_size=(2, 2, 3, 5), out_size=(36, 5), name="Tensor")]
+
+        @testset "merge $(setup.name)" for setup in setups
+            x = rand(Float32, setup.x_size...) |> aType
+            y = rand(Float32, setup.y_size...) |> aType
+
+            @test setup.out_size == size(__merge(x, y))
+        end
+
+        @testset "batch vectorize" begin
+            x_size = (4, 2, 3)
+            x = rand(Float32, x_size..., 5) |> aType
+            @test size(__batch_vectorize(x)) == (prod(x_size), 5)
+        end
+    end
+end
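
Usage note (not part of the patch): the default `concatenate` behaviour documented above, vectorize each sample and then stack along the first dimension, is easiest to see with a higher-dimensional approximator output. A minimal sketch, assuming the non-exported helpers `__merge` and `__batch_vectorize` added to `src/utils.jl` in this diff and accessed through the module:

```julia
using NeuralOperators

# Approximator output with three feature dimensions and a batch of 5,
# plus a 2-row coordinate matrix for the same batch.
a = rand(Float32, 4, 2, 3, 5)
y = rand(Float32, 2, 5)

# Each sample of `a` is flattened to a 4*2*3 = 24 column, then the
# coordinates are stacked below it along the first dimension.
merged = NeuralOperators.__merge(a, y)
size(merged)  # (26, 5)
```

When both inputs are already matrices, `__merge` reduces to plain concatenation along the first dimension, which is the case exercised by the jldoctests in `src/nomad.jl`.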