feat: NOMAD (#27)
* feat: add nomad

* NOMAD docstrings

* format, add utils test

* typo fix

* increase test coverage

* __merge test coverage

* chore: apply suggestions from code review

---------

Co-authored-by: Avik Pal <[email protected]>
ayushinav and avik-pal authored Aug 22, 2024
1 parent d22b523 commit 5fd7028
Showing 5 changed files with 217 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/NeuralOperators.jl
@@ -20,10 +20,12 @@ include("layers.jl")

include("fno.jl")
include("deeponet.jl")
include("nomad.jl")

export FourierTransform
export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
export FourierNeuralOperator
export DeepONet
export NOMAD

end
102 changes: 102 additions & 0 deletions src/nomad.jl
@@ -0,0 +1,102 @@
"""
NOMAD(approximator, decoder, concatenate)
Constructs a NOMAD from `approximator` and `decoder` architectures. Make sure the output from
`approximator` combined with the coordinate dimension has compatible size for input to `decoder`
## Arguments
- `approximator`: `Lux` network to be used as approximator net.
- `decoder`: `Lux` network to be used as decoder net.
## Keyword Arguments
- `concatenate`: function that defines the concatenation of output from `approximator` and the coordinate
dimension, defaults to concatenation along first dimension after vectorizing the tensors
## References
[1] Jacob H. Seidman and Georgios Kissas and Paris Perdikaris and George J. Pappas, "NOMAD:
Nonlinear Manifold Decoders for Operator Learning", doi: https://arxiv.org/abs/2206.03551
## Example
```jldoctest
julia> approximator_net = Chain(Dense(8 => 32), Dense(32 => 32), Dense(32 => 16));
julia> decoder_net = Chain(Dense(18 => 16), Dense(16 => 16), Dense(16 => 8));
julia> nomad = NOMAD(approximator_net, decoder_net);
julia> ps, st = Lux.setup(Xoshiro(), nomad);
julia> u = rand(Float32, 8, 5);
julia> y = rand(Float32, 2, 5);
julia> size(first(nomad((u, y), ps, st)))
(8, 5)
```
"""
@concrete struct NOMAD <: AbstractExplicitContainerLayer{(:approximator, :decoder)}
approximator
decoder
concatenate <: Function
end

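# Two-argument convenience constructor: default `concatenate` to `__merge` (see utils.jl).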
NOMAD(approximator, decoder) = NOMAD(approximator, decoder, __merge)

"""
NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8),
approximator_activation = identity, decoder_activation = identity)
Constructs a NOMAD composed of Dense layers. Make sure that
last node of `approximator` + coordinate length = first node of `decoder`
## Keyword arguments:
- `approximator`: Tuple of integers containing the number of nodes in each layer for approximator net
- `decoder`: Tuple of integers containing the number of nodes in each layer for decoder net
- `approximator_activation`: activation function for approximator net
- `decoder_activation`: activation function for decoder net
- `concatenate`: function that defines the concatenation of output from `approximator` and the coordinate
dimension, defaults to concatenation along first dimension after vectorizing the tensors
## References
[1] Jacob H. Seidman and Georgios Kissas and Paris Perdikaris and George J. Pappas, "NOMAD:
Nonlinear Manifold Decoders for Operator Learning", doi: https://arxiv.org/abs/2206.03551
## Example
```jldoctest
julia> nomad = NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8));
julia> ps, st = Lux.setup(Xoshiro(), nomad);
julia> u = rand(Float32, 8, 5);
julia> y = rand(Float32, 2, 5);
julia> size(first(nomad((u, y), ps, st)))
(8, 5)
```
"""
function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8),
approximator_activation=identity, decoder_activation=identity, concatenate=__merge)
approximator_net = Chain([Dense(approximator[i] => approximator[i + 1],
approximator_activation)
for i in 1:(length(approximator) - 1)]...)

decoder_net = Chain([Dense(decoder[i] => decoder[i + 1], decoder_activation)
for i in 1:(length(decoder) - 1)]...)

return NOMAD(approximator_net, decoder_net, concatenate)
end

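# Forward pass: run the approximator on the input function values `x[1]`, splice in the
# query coordinates `x[2]` via `concatenate`, and decode; layer states are threaded through.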
function (nomad::NOMAD)(x, ps, st::NamedTuple)
a, st_a = nomad.approximator(x[1], ps.approximator, st.approximator)
out, st_d = nomad.decoder(nomad.concatenate(a, x[2]), ps.decoder, st.decoder)

return out, (approximator=st_a, decoder=st_d)
end
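
Because `concatenate` is just a function argument, the default `__merge` behavior can be swapped out. Below is a minimal sketch of that (not part of this commit; `my_concat` and the layer sizes are illustrative):

```julia
using Lux, NeuralOperators, Random

# Illustrative custom concatenation: equivalent to the default `__merge`
# for matrix inputs, written out explicitly.
my_concat(a, y) = vcat(a, y)

approximator = Chain(Dense(8 => 32), Dense(32 => 16))
decoder = Chain(Dense(18 => 16), Dense(16 => 8))  # 18 = 16 (approximator out) + 2 (coords)
nomad = NOMAD(approximator, decoder, my_concat)

ps, st = Lux.setup(Xoshiro(0), nomad)
u = rand(Float32, 8, 5)  # sampled input function: 8 features × batch of 5
y = rand(Float32, 2, 5)  # query coordinates: 2 dims × batch of 5
out, _ = nomad((u, y), ps, st)
@assert size(out) == (8, 5)
```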
31 changes: 31 additions & 0 deletions src/utils.jl
@@ -69,3 +69,34 @@ end

return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb
end

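# Flatten every dimension except the last (batch) one, returning a
# (prod(non-batch dims), batch) matrix.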
@inline function __batch_vectorize(x::AbstractArray{T, N}) where {T, N}
dim_length = ndims(x) - 1
nb = size(x)[end]

slice = [Colon() for _ in 1:dim_length]
return reduce(hcat, [vec(view(x, slice..., i)) for i in 1:nb])
end

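# `__merge` stacks the approximator output on top of the coordinates along the first
# dimension; four methods cover the matrix/tensor combinations, vectorizing any input
# with more than two dimensions per batch first.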
@inline function __merge(x::AbstractArray{T1, 2}, y::AbstractArray{T2, 2}) where {T1, T2}
    return vcat(x, y)
end

@inline function __merge(
x::AbstractArray{T1, N1}, y::AbstractArray{T2, 2}) where {T1, T2, N1}
x_ = __batch_vectorize(x)
return vcat(x_, y)
end

@inline function __merge(
x::AbstractArray{T1, 2}, y::AbstractArray{T2, N2}) where {T1, T2, N2}
y_ = __batch_vectorize(y)
return vcat(x, y_)
end

@inline function __merge(
x::AbstractArray{T1, N1}, y::AbstractArray{T2, N2}) where {T1, T2, N1, N2}
x_ = __batch_vectorize(x)
y_ = __batch_vectorize(y)
return vcat(x_, y_)
end
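
A quick shape check for the new helpers (illustrative sizes; both functions are internal, so this imports the unexported names directly):

```julia
using NeuralOperators: __merge, __batch_vectorize

x = rand(Float32, 4, 2, 3, 5)  # tensor with batch size 5
y = rand(Float32, 2, 5)        # matrix with batch size 5

# Non-batch dims are flattened: (4, 2, 3, 5) -> (24, 5)
@assert size(__batch_vectorize(x)) == (24, 5)

# The tensor is vectorized per batch, then stacked over the matrix: 24 + 2 = 26 rows
@assert size(__merge(x, y)) == (26, 5)

# Two matrices are simply concatenated along the first dimension
@assert size(__merge(y, y)) == (4, 5)
```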
28 changes: 28 additions & 0 deletions test/nomad_tests.jl
@@ -0,0 +1,28 @@
@testitem "NOMAD" setup=[SharedTestSetup] begin
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = StableRNG(12345)

setups = [
(u_size=(1, 5), y_size=(1, 5), out_size=(1, 5),
approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"),
(u_size=(8, 5), y_size=(2, 5), out_size=(8, 5),
approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector")]

@testset "$(setup.name)" for setup in setups
u = rand(Float32, setup.u_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType
nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder)

ps, st = Lux.setup(rng, nomad) |> dev
@inferred first(nomad((u, y), ps, st))
@jet first(nomad((u, y), ps, st))

pred = first(nomad((u, y), ps, st))
@test setup.out_size == size(pred)

__f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st)))
test_gradients(
__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
end
end
end
54 changes: 54 additions & 0 deletions test/utils_tests.jl
@@ -0,0 +1,54 @@
@testitem "utils" setup=[SharedTestSetup] begin
import NeuralOperators: __project, __merge, __batch_vectorize
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = StableRNG(12345)

setups = [
(b_size=(16, 5), t_size=(16, 10, 5), out_size=(10, 5),
additional=NoOpLayer(), name="Scalar"),
(b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(1, 10, 5),
additional=NoOpLayer(), name="Scalar II"),
(b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(3, 10, 5),
additional=NoOpLayer(), name="Vector"),
(b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5),
out_size=(4, 3, 3, 10, 5), additional=NoOpLayer(), name="Tensor"),
(b_size=(16, 5), t_size=(16, 10, 5), out_size=(4, 10, 5),
additional=Dense(16 => 4), name="additional : Scalar"),
(b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(4, 10, 5),
additional=Dense(16 => 4), name="additional : Scalar II"),
(b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(4, 3, 10, 5),
additional=Dense(16 => 4), name="additional : Vector"),
(b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), out_size=(3, 4, 3, 4, 10, 5),
additional=Chain(Dense(16 => 4), ReshapeLayer((3, 4, 3, 4, 10))),
name="additional : Tensor")]

@testset "project : $(setup.name)" for setup in setups
b = rand(Float32, setup.b_size...) |> aType
t = rand(Float32, setup.t_size...) |> aType

ps, st = Lux.setup(rng, setup.additional) |> dev
@inferred first(__project(b, t, setup.additional, (; ps, st)))
@jet first(__project(b, t, setup.additional, (; ps, st)))
@test setup.out_size ==
size(first(__project(b, t, setup.additional, (; ps, st))))
end

setups = [(x_size=(6, 5), y_size=(4, 5), out_size=(10, 5), name="Scalar"),
(x_size=(12, 5), y_size=(8, 5), out_size=(20, 5), name="Vector I"),
(x_size=(4, 6, 5), y_size=(6, 5), out_size=(30, 5), name="Vector II"),
(x_size=(4, 2, 3, 5), y_size=(2, 2, 3, 5), out_size=(36, 5), name="Tensor")]

@testset "merge $(setup.name)" for setup in setups
x = rand(Float32, setup.x_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType

@test setup.out_size == size(__merge(x, y))
end

@testset "batch vectorize" begin
x_size = (4, 2, 3)
x = rand(Float32, x_size..., 5) |> aType
@test size(__batch_vectorize(x)) == (prod(x_size), 5)
end
end
end
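
Assuming the suite runs on ReTestItems.jl (implied by the `@testitem` blocks), the new test items can be run in isolation with a name filter; the regex below is an assumption matched to the names in this commit:

```julia
using ReTestItems

# Run only the test items touched by this commit (hypothetical filter).
runtests("test/"; name = r"NOMAD|utils")
```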
