Feat : NOMAD #27

Merged · 7 commits · Aug 22, 2024

Changes from 2 commits
2 changes: 2 additions & 0 deletions src/NeuralOperators.jl
@@ -20,10 +20,12 @@ include("layers.jl")

include("fno.jl")
include("deeponet.jl")
include("nomad.jl")

export FourierTransform
export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
export FourierNeuralOperator
export DeepONet
export NOMAD

end
103 changes: 103 additions & 0 deletions src/nomad.jl
@@ -0,0 +1,103 @@
"""
    NOMAD(approximator, decoder, concatenate)

Constructs a NOMAD from the `approximator` and `decoder` architectures. Make sure the output of
`approximator`, concatenated with the coordinates, has a size compatible with the input of `decoder`.

## Arguments

- `approximator`: `Lux` network to be used as the approximator net.
- `decoder`: `Lux` network to be used as the decoder net.
- `concatenate`: function that concatenates the output of `approximator` with the coordinate
  dimension. Defaults to concatenation along the first dimension after vectorizing the tensors
  (the two-argument constructor passes `__merge`).

## References

[1] Jacob H. Seidman, Georgios Kissas, Paris Perdikaris, and George J. Pappas, "NOMAD:
Nonlinear Manifold Decoders for Operator Learning", arXiv:2206.03551.

## Example

```jldoctest
julia> approximator_net = Chain(Dense(8 => 32), Dense(32 => 32), Dense(32 => 16));

julia> decoder_net = Chain(Dense(18 => 16), Dense(16 => 16), Dense(16 => 8));

julia> nomad = NOMAD(approximator_net, decoder_net);

julia> ps, st = Lux.setup(Xoshiro(), nomad);

julia> u = rand(Float32, 8, 5);

julia> y = rand(Float32, 2, 5);

julia> size(first(nomad((u, y), ps, st)))
(8, 5)
```
"""
@concrete struct NOMAD <: AbstractExplicitContainerLayer{(:approximator, :decoder)}
    approximator
    decoder
    concatenate <: Function
end

NOMAD(approximator, decoder) = NOMAD(approximator, decoder, __merge)
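
Because `concatenate` is just the third field, the fusion step is pluggable. A minimal sketch of supplying a custom function (hypothetical, not part of this PR; `__merge` remains the default):

```julia
# Any (a, y) -> array function works, provided the result's first dimension
# matches the decoder's input size (here 16 + 2 = 18).
custom_concat(a, y) = vcat(a, y)  # equivalent to the 2-D `__merge` method

approximator_net = Chain(Dense(8 => 32), Dense(32 => 32), Dense(32 => 16))
decoder_net = Chain(Dense(18 => 16), Dense(16 => 16), Dense(16 => 8))
nomad = NOMAD(approximator_net, decoder_net, custom_concat)
```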

"""
    NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8),
        approximator_activation = identity, decoder_activation = identity,
        concatenate = __merge)

Constructs a NOMAD composed of `Dense` layers. Make sure that the last node of `approximator`
plus the coordinate length equals the first node of `decoder` (16 + 2 = 18 in the example below).

## Keyword Arguments

- `approximator`: tuple of integers giving the number of nodes in each layer of the approximator net
- `decoder`: tuple of integers giving the number of nodes in each layer of the decoder net
- `approximator_activation`: activation function for the approximator net
- `decoder_activation`: activation function for the decoder net
- `concatenate`: function that concatenates the output of `approximator` with the coordinate
  dimension; defaults to concatenation along the first dimension after vectorizing the tensors

## References

[1] Jacob H. Seidman, Georgios Kissas, Paris Perdikaris, and George J. Pappas, "NOMAD:
Nonlinear Manifold Decoders for Operator Learning", arXiv:2206.03551.

## Example

```jldoctest
julia> nomad = NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8));

julia> ps, st = Lux.setup(Xoshiro(), nomad);

julia> u = rand(Float32, 8, 5);

julia> y = rand(Float32, 2, 5);

julia> size(first(nomad((u, y), ps, st)))
(8, 5)
```
"""
function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8),
        approximator_activation=identity, decoder_activation=identity,
        concatenate=__merge)
    approximator_net = Chain([Dense(approximator[i] => approximator[i + 1],
                                  approximator_activation)
                              for i in 1:(length(approximator) - 1)]...)

    decoder_net = Chain([Dense(decoder[i] => decoder[i + 1], decoder_activation)
                         for i in 1:(length(decoder) - 1)]...)

    return NOMAD(approximator_net, decoder_net, concatenate)
end
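
For reference, the tuple expansion above with the default arguments is equivalent to building the chains by hand (a sketch, assuming `Lux`'s `Chain` and `Dense` are in scope):

```julia
# approximator = (8, 32, 32, 16) expands to
approximator_net = Chain(Dense(8 => 32), Dense(32 => 32), Dense(32 => 16))
# decoder = (18, 16, 8, 8) expands to
decoder_net = Chain(Dense(18 => 16), Dense(16 => 8), Dense(8 => 8))
```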

function (nomad::NOMAD)(x, ps, st::NamedTuple)
    a, st_a = nomad.approximator(x[1], ps.approximator, st.approximator)
    out, st_d = nomad.decoder(nomad.concatenate(a, x[2]), ps.decoder, st.decoder)

    return out, (approximator=st_a, decoder=st_d)
end
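
To trace why the docstring examples return an `(8, 5)` array, here is the shape flow through this forward pass, reusing the names from the example above (a sketch with the default `__merge` and batch size 5):

```julia
# u :: 8 x 5   (sampled input function)    y :: 2 x 5   (query coordinates)
# approximator: 8 x 5  -> 16 x 5
# __merge:      vcat(16 x 5, 2 x 5) -> 18 x 5   (16 + 2 = decoder input width)
# decoder:      18 x 5 -> 8 x 5
out, st_new = nomad((u, y), ps, st)
@assert size(out) == (8, 5)
```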
55 changes: 43 additions & 12 deletions src/utils.jl
@@ -1,13 +1,13 @@
-@inline function __project(
-        b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
+@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::NoOpLayer,
+        _) where {T1, T2}
    # b : p x nb
    # t : p x N x nb
    b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
    return dropdims(sum(b_ .* t; dims=1); dims=1), () # N x nb
end

-@inline function __project(
-        b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
+@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::NoOpLayer,
+        _) where {T1, T2}
    # b : p x u x nb
    # t : p x N x nb
    if size(b, 2) == 1 || size(t, 2) == 1
@@ -17,8 +17,8 @@
    end
end

-@inline function __project(
-        b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2, N}
+@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::NoOpLayer,
+        _) where {T1, T2, N}
    # b : p x u_size x nb
    # t : p x N x nb
    u_size = size(b)[2:(end - 1)]
@@ -32,16 +32,16 @@
    return dropdims(sum(b_ .* t_; dims=1); dims=1), () # u_size x N x nb
end

-@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
-        additional::T, params) where {T1, T2, T}
+@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T,
+        params) where {T1, T2, T}
    # b : p x nb
    # t : p x N x nb
    b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
    return additional(b_ .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
end

-@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
-        additional::T, params) where {T1, T2, T}
+@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T,
+        params) where {T1, T2, T}
    # b : p x u x nb
    # t : p x N x nb

@@ -55,8 +55,8 @@
    end
end

-@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
-        additional::T, params) where {T1, T2, N, T}
+@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, additional::T,
+        params) where {T1, T2, N, T}
    # b : p x u_size x nb
    # t : p x N x nb
    u_size = size(b)[2:(end - 1)]
@@ -69,3 +69,34 @@

    return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb
end
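
The `__project` methods above predate this PR; only their line wrapping changes here. For context, a standalone sketch of what the first no-op method computes, namely a contraction of branch and trunk outputs over the latent dimension `p` (names and sizes here are illustrative only):

```julia
p, N, nb = 4, 6, 5
b = rand(Float32, p, nb)                       # branch output: p x nb
t = rand(Float32, p, N, nb)                    # trunk output:  p x N x nb
b_ = reshape(b, p, 1, nb)                      # p x 1 x nb, broadcastable against t
out = dropdims(sum(b_ .* t; dims=1); dims=1)   # N x nb
@assert size(out) == (N, nb)
```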

+@inline function __batch_vectorize(x::AbstractArray{T, N}) where {T, N}
+    dim_length = ndims(x) - 1
+    nb = size(x)[end]
+
+    slice = [Colon() for _ in 1:dim_length]
+    return reduce(hcat, [vec(view(x, slice..., i)) for i in 1:nb])
+end
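
`__batch_vectorize` flattens all non-batch dimensions of each sample into a single column, keeping the batch dimension as columns. A quick shape sketch:

```julia
x = rand(Float32, 4, 3, 2)    # two samples, each 4 x 3
x_ = __batch_vectorize(x)     # each sample becomes a 12-element column
@assert size(x_) == (12, 2)
```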

+@inline function __merge(x::AbstractArray{T1, 2}, y::AbstractArray{T2, 2}) where {T1, T2}
+    return cat(x, y; dims=1)
+end

+@inline function __merge(x::AbstractArray{T1, N1},
+        y::AbstractArray{T2, 2}) where {T1, T2, N1}
+    x_ = __batch_vectorize(x)
+    return cat(x_, y; dims=1)
+end

+@inline function __merge(x::AbstractArray{T1, 2},
+        y::AbstractArray{T2, N2}) where {T1, T2, N2}
+    y_ = __batch_vectorize(y)
+    return cat(x, y_; dims=1)
+end

+@inline function __merge(x::AbstractArray{T1, N1},
+        y::AbstractArray{T2, N2}) where {T1, T2, N1, N2}
+    x_ = __batch_vectorize(x)
+    y_ = __batch_vectorize(y)
+    return cat(x_, y_; dims=1)
+end
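
Together, the four `__merge` methods cover every rank combination: any input with more than two dimensions is vectorized per batch before concatenation along the first dimension. A shape-level sketch:

```julia
a = rand(Float32, 4, 2, 5)                # higher-rank approximator output
y = rand(Float32, 3, 5)                   # 2-D coordinates
@assert size(__merge(a, y)) == (11, 5)    # 4 * 2 + 3 = 11 rows; batch of 5 kept
```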
27 changes: 27 additions & 0 deletions test/nomad_tests.jl
@@ -0,0 +1,27 @@
@testitem "NOMAD" setup=[SharedTestSetup] begin @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = StableRNG(12345)

setups = [
(u_size=(1, 5), y_size=(1, 5), out_size=(1, 5), approximator=(1, 16, 16, 15),
decoder=(16, 8, 4, 1), name="Scalar"),
(u_size=(8, 5), y_size=(2, 5), out_size=(8, 5), approximator=(8, 32, 32, 16),
decoder=(18, 16, 8, 8), name="Vector"),
]

@testset "$(setup.name)" for setup in setups
u = rand(Float32, setup.u_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType
nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder)

ps, st = Lux.setup(rng, nomad) |> dev
@inferred first(nomad((u, y), ps, st))
@jet first(nomad((u, y), ps, st))

pred = first(nomad((u, y), ps, st))
@test setup.out_size == size(pred)

__f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st)))
test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3,
skip_backends=[AutoEnzyme()])
end
end end