Feat : NOMAD #27

Merged: 7 commits, merged on Aug 22, 2024
2 changes: 2 additions & 0 deletions src/NeuralOperators.jl
@@ -20,10 +20,12 @@ include("layers.jl")

include("fno.jl")
include("deeponet.jl")
include("nomad.jl")

export FourierTransform
export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
export FourierNeuralOperator
export DeepONet
export NOMAD

end
102 changes: 102 additions & 0 deletions src/nomad.jl
@@ -0,0 +1,102 @@
"""
NOMAD(approximator, decoder, concatenate)

Constructs a NOMAD from `approximator` and `decoder` architectures. Make sure that the output of
`approximator`, once combined with the coordinate dimension, has a size compatible with the input of `decoder`.

## Arguments

- `approximator`: `Lux` network to be used as the approximator net.
- `decoder`: `Lux` network to be used as the decoder net.
- `concatenate`: function that combines the output of `approximator` with the coordinates;
  defaults to concatenation along the first dimension after vectorizing the tensors.

## References

[1] Jacob H. Seidman, Georgios Kissas, Paris Perdikaris and George J. Pappas, "NOMAD: Nonlinear
Manifold Decoders for Operator Learning", arXiv: https://arxiv.org/abs/2206.03551

## Example

```jldoctest
julia> approximator_net = Chain(Dense(8 => 32), Dense(32 => 32), Dense(32 => 16));

julia> decoder_net = Chain(Dense(18 => 16), Dense(16 => 16), Dense(16 => 8));

julia> nomad = NOMAD(approximator_net, decoder_net);

julia> ps, st = Lux.setup(Xoshiro(), nomad);

julia> u = rand(Float32, 8, 5);

julia> y = rand(Float32, 2, 5);

julia> size(first(nomad((u, y), ps, st)))
(8, 5)
```
"""
@concrete struct NOMAD <: AbstractExplicitContainerLayer{(:approximator, :decoder)}
approximator
decoder
concatenate <: Function
end

NOMAD(approximator, decoder) = NOMAD(approximator, decoder, __merge)

"""
NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8),
approximator_activation = identity, decoder_activation = identity, concatenate = __merge)

Constructs a NOMAD composed of `Dense` layers. Make sure that the last node of `approximator`
plus the coordinate length equals the first node of `decoder` (e.g. 16 + 2 = 18 in the example below).

## Keyword arguments:

- `approximator`: Tuple of integers giving the number of nodes in each layer of the approximator net
- `decoder`: Tuple of integers giving the number of nodes in each layer of the decoder net
- `approximator_activation`: activation function for the approximator net
- `decoder_activation`: activation function for the decoder net
- `concatenate`: function that combines the output of `approximator` with the coordinates;
  defaults to concatenation along the first dimension after vectorizing the tensors

## References

[1] Jacob H. Seidman, Georgios Kissas, Paris Perdikaris and George J. Pappas, "NOMAD: Nonlinear
Manifold Decoders for Operator Learning", arXiv: https://arxiv.org/abs/2206.03551

## Example

```jldoctest
julia> nomad = NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8));

julia> ps, st = Lux.setup(Xoshiro(), nomad);

julia> u = rand(Float32, 8, 5);

julia> y = rand(Float32, 2, 5);

julia> size(first(nomad((u, y), ps, st)))
(8, 5)
```
"""
function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8),
approximator_activation=identity, decoder_activation=identity, concatenate=__merge)
approximator_net = Chain([Dense(approximator[i] => approximator[i + 1],
approximator_activation)
for i in 1:(length(approximator) - 1)]...)

decoder_net = Chain([Dense(decoder[i] => decoder[i + 1], decoder_activation)
for i in 1:(length(decoder) - 1)]...)

return NOMAD(approximator_net, decoder_net, concatenate)
end

function (nomad::NOMAD)(x, ps, st::NamedTuple)
a, st_a = nomad.approximator(x[1], ps.approximator, st.approximator)
out, st_d = nomad.decoder(nomad.concatenate(a, x[2]), ps.decoder, st.decoder)

return out, (approximator=st_a, decoder=st_d)
end
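
The forward pass above runs the approximator on `u`, merges its output with the coordinates `y` via `concatenate`, and feeds the result to the decoder. A minimal sketch of passing a custom `concatenate` (the `scaled_concat` helper and the layer sizes are illustrative, not part of this PR):

```julia
using Lux, NeuralOperators, Random

# Sketch only: an illustrative custom merge that scales the coordinates before
# concatenating along the first dimension. Its output width (16 + 2 = 18) must
# match the decoder's input width.
scaled_concat(a, y) = vcat(a, 2.0f0 .* y)

approximator = Chain(Dense(8 => 32, gelu), Dense(32 => 16))
decoder = Chain(Dense(18 => 16, gelu), Dense(16 => 8))
nomad = NOMAD(approximator, decoder, scaled_concat)

ps, st = Lux.setup(Xoshiro(0), nomad)
u = rand(Float32, 8, 5)  # sampled input function: 8 sensors x batch of 5
y = rand(Float32, 2, 5)  # query coordinates: 2 dims x batch of 5
out, _ = nomad((u, y), ps, st)
size(out)                # expected: (8, 5)
```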
31 changes: 31 additions & 0 deletions src/utils.jl
@@ -69,3 +69,34 @@

return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb
end

@inline function __batch_vectorize(x::AbstractArray{T, N}) where {T, N}
dim_length = ndims(x) - 1
nb = size(x)[end]

slice = [Colon() for _ in 1:dim_length]
return reduce(hcat, [vec(view(x, slice..., i)) for i in 1:nb])
end

@inline function __merge(x::AbstractArray{T1, 2}, y::AbstractArray{T2, 2}) where {T1, T2}
return cat(x, y; dims=1)
end

@inline function __merge(
x::AbstractArray{T1, N1}, y::AbstractArray{T2, 2}) where {T1, T2, N1}
x_ = __batch_vectorize(x)
return vcat(x_, y)
end

@inline function __merge(
x::AbstractArray{T1, 2}, y::AbstractArray{T2, N2}) where {T1, T2, N2}
y_ = __batch_vectorize(y)
return vcat(x, y_)
end

@inline function __merge(
x::AbstractArray{T1, N1}, y::AbstractArray{T2, N2}) where {T1, T2, N1, N2}
x_ = __batch_vectorize(x)
y_ = __batch_vectorize(y)
return vcat(x_, y_)
end
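
For intuition, `__batch_vectorize` flattens every dimension except the last (batch) dimension, and `__merge` vectorizes each input as needed before stacking along the first dimension. A shape-only sketch (assuming the internals are reached through the package module, as the tests in this PR do):

```julia
using NeuralOperators

# Illustration only: shapes mirror the "Tensor" case in the utils tests below.
x = rand(Float32, 4, 2, 3, 5)               # (4, 2, 3) per sample, batch of 5
size(NeuralOperators.__batch_vectorize(x))  # (24, 5): 4*2*3 rows, batch preserved

y = rand(Float32, 2, 2, 3, 5)               # (2, 2, 3) per sample, batch of 5
size(NeuralOperators.__merge(x, y))         # (36, 5): 24 + 12 rows after vectorizing both
```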
28 changes: 28 additions & 0 deletions test/nomad_tests.jl
@@ -0,0 +1,28 @@
@testitem "NOMAD" setup=[SharedTestSetup] begin
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = StableRNG(12345)

setups = [
(u_size=(1, 5), y_size=(1, 5), out_size=(1, 5),
approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"),
(u_size=(8, 5), y_size=(2, 5), out_size=(8, 5),
approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector")]

@testset "$(setup.name)" for setup in setups
u = rand(Float32, setup.u_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType
nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder)

ps, st = Lux.setup(rng, nomad) |> dev
@inferred first(nomad((u, y), ps, st))
@jet first(nomad((u, y), ps, st))

pred = first(nomad((u, y), ps, st))
@test setup.out_size == size(pred)

__f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st)))
test_gradients(
__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
end
end
end
54 changes: 54 additions & 0 deletions test/utils_tests.jl
@@ -0,0 +1,54 @@
@testitem "utils" setup=[SharedTestSetup] begin
import NeuralOperators: __project, __merge, __batch_vectorize
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = StableRNG(12345)

setups = [
(b_size=(16, 5), t_size=(16, 10, 5), out_size=(10, 5),
additional=NoOpLayer(), name="Scalar"),
(b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(1, 10, 5),
additional=NoOpLayer(), name="Scalar II"),
(b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(3, 10, 5),
additional=NoOpLayer(), name="Vector"),
(b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5),
out_size=(4, 3, 3, 10, 5), additional=NoOpLayer(), name="Tensor"),
(b_size=(16, 5), t_size=(16, 10, 5), out_size=(4, 10, 5),
additional=Dense(16 => 4), name="additional : Scalar"),
(b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(4, 10, 5),
additional=Dense(16 => 4), name="additional : Scalar II"),
(b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(4, 3, 10, 5),
additional=Dense(16 => 4), name="additional : Vector"),
(b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), out_size=(3, 4, 3, 4, 10, 5),
additional=Chain(Dense(16 => 4), ReshapeLayer((3, 4, 3, 4, 10))),
name="additional : Tensor")]

@testset "project : $(setup.name)" for setup in setups
b = rand(Float32, setup.b_size...) |> aType
t = rand(Float32, setup.t_size...) |> aType

ps, st = Lux.setup(rng, setup.additional) |> dev
@inferred first(__project(b, t, setup.additional, (; ps, st)))
@jet first(__project(b, t, setup.additional, (; ps, st)))
@test setup.out_size ==
size(first(__project(b, t, setup.additional, (; ps, st))))
end

setups = [(x_size=(6, 5), y_size=(4, 5), out_size=(10, 5), name="Scalar"),
(x_size=(12, 5), y_size=(8, 5), out_size=(20, 5), name="Vector I"),
(x_size=(4, 6, 5), y_size=(6, 5), out_size=(30, 5), name="Vector II"),
(x_size=(4, 2, 3, 5), y_size=(2, 2, 3, 5), out_size=(36, 5), name="Tensor")]

@testset "merge $(setup.name)" for setup in setups
x = rand(Float32, setup.x_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType

@test setup.out_size == size(__merge(x, y))
end

@testset "batch vectorize" begin
x_size = (4, 2, 3)
x = rand(Float32, x_size..., 5) |> aType
@test size(__batch_vectorize(x)) == (prod(x_size), 5)
end
end
end