-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: add nomad * NOMAD docstrings * format, add utils test * typo fix * increase test coverage * __merge test coverage * chore: apply suggestions from code review --------- Co-authored-by: Avik Pal <[email protected]>
- Loading branch information
Showing
5 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
""" | ||
NOMAD(approximator, decoder, concatenate) | ||
Constructs a NOMAD from `approximator` and `decoder` architectures. Make sure the output from | ||
`approximator` combined with the coordinate dimension has compatible size for input to `decoder` | ||
## Arguments | ||
- `approximator`: `Lux` network to be used as approximator net. | ||
- `decoder`: `Lux` network to be used as decoder net. | ||
## Keyword Arguments | ||
- `concatenate`: function that defines the concatenation of output from `approximator` and the coordinate | ||
dimension, defaults to concatenation along first dimension after vectorizing the tensors | ||
## References | ||
[1] Jacob H. Seidman and Georgios Kissas and Paris Perdikaris and George J. Pappas, "NOMAD: | ||
Nonlinear Manifold Decoders for Operator Learning", doi: https://arxiv.org/abs/2206.03551 | ||
## Example | ||
```jldoctest | ||
julia> approximator_net = Chain(Dense(8 => 32), Dense(32 => 32), Dense(32 => 16)); | ||
julia> decoder_net = Chain(Dense(18 => 16), Dense(16 => 16), Dense(16 => 8)); | ||
julia> nomad = NOMAD(approximator_net, decoder_net); | ||
julia> ps, st = Lux.setup(Xoshiro(), nomad); | ||
julia> u = rand(Float32, 8, 5); | ||
julia> y = rand(Float32, 2, 5); | ||
julia> size(first(nomad((u, y), ps, st))) | ||
(8, 5) | ||
``` | ||
""" | ||
@concrete struct NOMAD <: AbstractExplicitContainerLayer{(:approximator, :decoder)} | ||
approximator | ||
decoder | ||
concatenate <: Function | ||
end | ||
|
||
# Convenience constructor: fall back to `__merge` (vectorize and concatenate along the
# first dimension) when no custom concatenation function is supplied.
function NOMAD(approximator, decoder)
    return NOMAD(approximator, decoder, __merge)
end
|
||
""" | ||
NOMAD(; approximator = (8, 32, 32, 16), decoder = (18, 16, 8, 8), | ||
approximator_activation = identity, decoder_activation = identity) | ||
Constructs a NOMAD composed of Dense layers. Make sure that | ||
last node of `approximator` + coordinate length = first node of `decoder` | ||
## Keyword arguments: | ||
- `approximator`: Tuple of integers containing the number of nodes in each layer for approximator net | ||
- `decoder`: Tuple of integers containing the number of nodes in each layer for decoder net | ||
- `approximator_activation`: activation function for approximator net | ||
- `decoder_activation`: activation function for decoder net | ||
- `concatenate`: function that defines the concatenation of output from `approximator` and the coordinate | ||
dimension, defaults to concatenation along first dimension after vectorizing the tensors | ||
## References | ||
[1] Jacob H. Seidman and Georgios Kissas and Paris Perdikaris and George J. Pappas, "NOMAD: | ||
Nonlinear Manifold Decoders for Operator Learning", doi: https://arxiv.org/abs/2206.03551 | ||
## Example | ||
```jldoctest | ||
julia> nomad = NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8)); | ||
julia> ps, st = Lux.setup(Xoshiro(), nomad); | ||
julia> u = rand(Float32, 8, 5); | ||
julia> y = rand(Float32, 2, 5); | ||
julia> size(first(nomad((u, y), ps, st))) | ||
(8, 5) | ||
``` | ||
""" | ||
function NOMAD(; approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), | ||
approximator_activation=identity, decoder_activation=identity, concatenate=__merge) | ||
approximator_net = Chain([Dense(approximator[i] => approximator[i + 1], | ||
approximator_activation) | ||
for i in 1:(length(approximator) - 1)]...) | ||
|
||
decoder_net = Chain([Dense(decoder[i] => decoder[i + 1], decoder_activation) | ||
for i in 1:(length(decoder) - 1)]...) | ||
|
||
return NOMAD(approximator_net, decoder_net, concatenate) | ||
end | ||
|
||
# Forward pass: run the approximator on the input function samples, fuse its output with
# the query coordinates via `concatenate`, and decode. Returns the prediction together
# with the updated sub-layer states.
function (nomad::NOMAD)(x, ps, st::NamedTuple)
    u, y = x
    approx_out, approx_st = nomad.approximator(u, ps.approximator, st.approximator)
    decoder_in = nomad.concatenate(approx_out, y)
    decoder_out, decoder_st = nomad.decoder(decoder_in, ps.decoder, st.decoder)

    return decoder_out, (approximator=approx_st, decoder=decoder_st)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
@testitem "NOMAD" setup=[SharedTestSetup] begin | ||
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES | ||
rng = StableRNG(12345) | ||
|
||
setups = [ | ||
(u_size=(1, 5), y_size=(1, 5), out_size=(1, 5), | ||
approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"), | ||
(u_size=(8, 5), y_size=(2, 5), out_size=(8, 5), | ||
approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector")] | ||
|
||
@testset "$(setup.name)" for setup in setups | ||
u = rand(Float32, setup.u_size...) |> aType | ||
y = rand(Float32, setup.y_size...) |> aType | ||
nomad = NOMAD(; approximator=setup.approximator, decoder=setup.decoder) | ||
|
||
ps, st = Lux.setup(rng, nomad) |> dev | ||
@inferred first(nomad((u, y), ps, st)) | ||
@jet first(nomad((u, y), ps, st)) | ||
|
||
pred = first(nomad((u, y), ps, st)) | ||
@test setup.out_size == size(pred) | ||
|
||
__f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st))) | ||
test_gradients( | ||
__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()]) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
@testitem "utils" setup=[SharedTestSetup] begin | ||
import NeuralOperators: __project, __merge, __batch_vectorize | ||
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES | ||
rng = StableRNG(12345) | ||
|
||
setups = [ | ||
(b_size=(16, 5), t_size=(16, 10, 5), out_size=(10, 5), | ||
additional=NoOpLayer(), name="Scalar"), | ||
(b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(1, 10, 5), | ||
additional=NoOpLayer(), name="Scalar II"), | ||
(b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(3, 10, 5), | ||
additional=NoOpLayer(), name="Vector"), | ||
(b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), | ||
out_size=(4, 3, 3, 10, 5), additional=NoOpLayer(), name="Tensor"), | ||
(b_size=(16, 5), t_size=(16, 10, 5), out_size=(4, 10, 5), | ||
additional=Dense(16 => 4), name="additional : Scalar"), | ||
(b_size=(16, 1, 5), t_size=(16, 10, 5), out_size=(4, 10, 5), | ||
additional=Dense(16 => 4), name="additional : Scalar II"), | ||
(b_size=(16, 3, 5), t_size=(16, 10, 5), out_size=(4, 3, 10, 5), | ||
additional=Dense(16 => 4), name="additional : Vector"), | ||
(b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), out_size=(3, 4, 3, 4, 10, 5), | ||
additional=Chain(Dense(16 => 4), ReshapeLayer((3, 4, 3, 4, 10))), | ||
name="additional : Tensor")] | ||
|
||
@testset "project : $(setup.name)" for setup in setups | ||
b = rand(Float32, setup.b_size...) |> aType | ||
t = rand(Float32, setup.t_size...) |> aType | ||
|
||
ps, st = Lux.setup(rng, setup.additional) |> dev | ||
@inferred first(__project(b, t, setup.additional, (; ps, st))) | ||
@jet first(__project(b, t, setup.additional, (; ps, st))) | ||
@test setup.out_size == | ||
size(first(__project(b, t, setup.additional, (; ps, st)))) | ||
end | ||
|
||
setups = [(x_size=(6, 5), y_size=(4, 5), out_size=(10, 5), name="Scalar"), | ||
(x_size=(12, 5), y_size=(8, 5), out_size=(20, 5), name="Vector I"), | ||
(x_size=(4, 6, 5), y_size=(6, 5), out_size=(30, 5), name="Vector II"), | ||
(x_size=(4, 2, 3, 5), y_size=(2, 2, 3, 5), out_size=(36, 5), name="Tensor")] | ||
|
||
@testset "merge $(setup.name)" for setup in setups | ||
x_size = rand(Float32, setup.x_size...) |> aType | ||
y_size = rand(Float32, setup.y_size...) |> aType | ||
|
||
@test setup.out_size == size(__merge(x_size, y_size)) | ||
end | ||
|
||
@testset "batch vectorize" begin | ||
x_size = (4, 2, 3) | ||
x = rand(Float32, x_size..., 5) |> aType | ||
@test size(__batch_vectorize(x)) == (prod(x_size), 5) | ||
end | ||
end | ||
end |