Deeponet multi-output fix #11

Closed
wants to merge 17 commits into from
5 changes: 3 additions & 2 deletions src/LuxNeuralOperators.jl
@@ -7,9 +7,9 @@ using PrecompileTools: @recompile_invalidations
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using FFTW: FFTW, irfft, rfft
using Lux
using Lux: _print_wrapper_model
using LuxCore: LuxCore, AbstractExplicitLayer
using NNlib: NNlib, ⊠
using NNlib: NNlib, ⊠, batched_adjoint
using Random: Random, AbstractRNG
using Reexport: @reexport
end
@@ -26,6 +26,7 @@ include("layers.jl")

include("fno.jl")
include("deeponet.jl")
include("display.jl")

export FourierTransform
export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
107 changes: 67 additions & 40 deletions src/deeponet.jl
@@ -11,6 +11,8 @@ Constructs a DeepONet composed of Dense layers. Make sure the last node of `bran
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network through which the output of the DeepONet is passed, allowing
  additional operations such as embeddings; defaults to `nothing`

## References

@@ -22,34 +24,28 @@ operators", doi: https://arxiv.org/abs/1910.03193

```jldoctest
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
@compact(
branch = Chain(
Branch net :
(
Chain(
layer_1 = Dense(64 => 32), # 2_080 parameters
layer_2 = Dense(32 => 32), # 1_056 parameters
layer_3 = Dense(32 => 16), # 528 parameters
),
trunk = Chain(
)

Trunk net :
(
Chain(
layer_1 = Dense(1 => 8), # 16 parameters
layer_2 = Dense(8 => 8), # 72 parameters
layer_3 = Dense(8 => 16), # 144 parameters
),
) do (u, y)
t = trunk(y)
b = branch(u)
@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
@argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
b_ = if ndims(t) == ndims(b)
b
else
reshape(b, size(b, 1), 1, (size(b))[2:end]...)
end
return dropdims(sum(t .* b_; dims = 1); dims = 1)
end # Total: 3_896 parameters,
# plus 0 states.
)
```
"""
function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
branch_activation=identity, trunk_activation=identity)
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=nothing)

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
@@ -62,7 +58,7 @@ function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net)
return DeepONet(branch_net, trunk_net; additional=additional)
end
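A minimal forward-pass sketch for this keyword constructor, assuming the package is loaded as `LuxNeuralOperators`; the shapes mirror the "Scalar" case in `test/deeponet_tests.jl` and the RNG choice is illustrative:

```julia
using Lux, LuxNeuralOperators, Random

don = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
ps, st = Lux.setup(Random.default_rng(), don)

u = rand(Float32, 64, 5)     # sensor points x batch
y = rand(Float32, 1, 10, 5)  # coordinate dims x query points x batch

pred, _ = don((u, y), ps, st)
size(pred)                   # (10, 5): query points x batch
```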

"""
@@ -76,6 +72,11 @@ nets output should have the same first dimension.
- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.

## Keyword Arguments

- `additional`: `Lux` network through which the output of the DeepONet is passed, allowing
  additional operations such as embeddings; defaults to `nothing`

## References

[1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators for
@@ -90,43 +91,69 @@ julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> deeponet = DeepONet(branch_net, trunk_net)
@compact(
branch = Chain(
Branch net :
(
Chain(
layer_1 = Dense(64 => 32), # 2_080 parameters
layer_2 = Dense(32 => 32), # 1_056 parameters
layer_3 = Dense(32 => 16), # 528 parameters
),
trunk = Chain(
)

Trunk net :
(
Chain(
layer_1 = Dense(1 => 8), # 16 parameters
layer_2 = Dense(8 => 8), # 72 parameters
layer_3 = Dense(8 => 16), # 144 parameters
),
) do (u, y)
t = trunk(y)
b = branch(u)
@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
@argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
b_ = if ndims(t) == ndims(b)
b
else
reshape(b, size(b, 1), 1, (size(b))[2:end]...)
end
return dropdims(sum(t .* b_; dims = 1); dims = 1)
end # Total: 3_896 parameters,
# plus 0 states.
)

julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> additional = Chain(Dense(1 => 4));
Member:
The input for the additional layer should be the size of the inner embedding.

Member:
Then it would not need a reduction/sum/dropdims before the additional layer. It should be `additional = Chain(Dense(16 => 4));` here; otherwise it creates a bottleneck and we lose information.

Contributor (author):
Should be fixed now. Using a linear layer as the additional layer in the cases where none is supplied did not seem ideal to me, because it would imply a weighted sum whose weights are learnt during training; by default DeepONets take the dot product, i.e. an unweighted sum, which many users may require.


julia> deeponet = DeepONet(branch_net, trunk_net; additional=additional)
Branch net :
(
Chain(
layer_1 = Dense(64 => 32), # 2_080 parameters
layer_2 = Dense(32 => 32), # 1_056 parameters
layer_3 = Dense(32 => 16), # 528 parameters
),
)

Trunk net :
(
Chain(
layer_1 = Dense(1 => 8), # 16 parameters
layer_2 = Dense(8 => 8), # 72 parameters
layer_3 = Dense(8 => 16), # 144 parameters
),
)

Additional net :
(
Dense(1 => 4), # 8 parameters
)
```
"""
function DeepONet(branch::L1, trunk::L2) where {L1, L2}
return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y)
function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb...
b = branch(u) # p x nb...

@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
amount of nodes in the last layer. Otherwise \
Σᵢ bᵢⱼ tᵢₖ won't work."

b_ = ndims(t) == ndims(b) ? b : reshape(b, size(b, 1), 1, size(b)[2:end]...)
@return dropdims(sum(t .* b_; dims=1); dims=1)
if isnothing(additional)
out_ = __project(b, t)
else
out_ = additional(__project(b, t))
end
@return out_
end
end
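A sketch of the `additional` path, mirroring the "Additional layer" testset in this PR (all shapes are taken from `test/deeponet_tests.jl`; the RNG choice and variable names are illustrative). In this configuration the `additional` network is applied after the branch-trunk projection, so its input dimension is the leading singleton left by the summation:

```julia
using Lux, LuxNeuralOperators, Random

branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk_net  = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))
additional = Chain(Dense(1 => 4))  # applied to the projected output, not to the embedding

don = DeepONet(branch_net, trunk_net; additional=additional)
ps, st = Lux.setup(Random.default_rng(), don)

u = rand(Float32, 64, 1, 5)   # sensor points x 1 x batch
y = rand(Float32, 1, 10, 5)   # coordinate dims x query points x batch

pred, _ = don((u, y), ps, st)
size(pred)                    # (4, 10, 5)
```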
22 changes: 22 additions & 0 deletions src/display.jl
@@ -0,0 +1,22 @@
# function Base.show(io::IO, model::conv) where {conv <: OperatorConv}
# # print(io, model.name*"() # "*string(Lux.parameterlength(model))*" parameters")
# print(io, model.name)
# end

# function Base.show(io::IO, ::MIME"text/plain", model::conv) where {conv <: OperatorConv}
# show(io, model.name)
# end
Member:
Remove these; printing was fixed upstream.


function Base.show(io::IO, model::Lux.CompactLuxLayer{:DeepONet})
_print_wrapper_model(io, "Branch net :\n", model.layers.branch)
print(io, "\n \n")
_print_wrapper_model(io, "Trunk net :\n", model.layers.trunk)
if :additional in keys(model.layers)
print(io, "\n \n")
_print_wrapper_model(io, "Additional net :\n", model.layers.additional)
end
end
ayushinav marked this conversation as resolved.

function Base.show(io::IO, ::MIME"text/plain", x::CompactLuxLayer{:DeepONet})
show(io, x)
end
32 changes: 32 additions & 0 deletions src/utils.jl
@@ -1,2 +1,34 @@
# Temporarily capture certain calls like AMDGPU for ComplexFloats
@inline __batched_mul(x, y) = x ⊠ y

@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}) where {T1, T2}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb
end

@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}) where {T1, T2}
# b : p x u x nb
# t : p x N x nb
if size(b, 2) == 1 || size(t, 2) == 1
return sum(b .* t; dims=1) # 1 x N x nb
else
return __batched_mul(batched_adjoint(t), b) # N x u x nb
end
end

@inline function __project(
b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}) where {T1, T2, N}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]

b_ = reshape(b, size(b, 1), 1, u_size..., size(b)[end])
# p x 1 x u_size x nb

t_ = reshape(t, size(t)[1:2]..., ones(eltype(u_size), length(u_size))..., size(t)[end])
# p x N x (1,1,1...) x nb

return dropdims(sum(b_ .* t_; dims=1); dims=1) # N x u_size x nb
end
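A shape-only sketch of what the two most common projections above compute, written against plain NNlib so it runs without the internal helpers (the `p`, `N`, `u`, `nb` sizes are illustrative):

```julia
using NNlib: batched_adjoint, ⊠

p, N, u, nb = 16, 10, 3, 5
t = rand(Float32, p, N, nb)    # trunk output: embedding x query points x batch

# 2D branch (scalar case): broadcast, contract over p, drop the singleton -> N x nb
b2 = rand(Float32, p, nb)
size(dropdims(sum(reshape(b2, p, 1, nb) .* t; dims=1); dims=1))  # (10, 5)

# 3D branch (vector case): batched contraction over p -> N x u x nb
b3 = rand(Float32, p, u, nb)
size(batched_adjoint(t) ⊠ b3)                                    # (10, 3, 5)
```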
70 changes: 38 additions & 32 deletions test/deeponet_tests.jl
@@ -2,47 +2,53 @@
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = StableRNG(12345)

u = rand(Float32, 64, 5) |> aType # sensor_points x nb
y = rand(Float32, 1, 10, 5) |> aType # ndims x N x nb
out_size = (10, 5)
setups = [
(u_size=(64, 5), y_size=(1, 10, 5), out_size=(10, 5),
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar"),
(u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(10, 3, 5),
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"),
(u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(10, 4, 3, 3, 5),
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor")]

@testset "$(setup.name)" for setup in setups
u = rand(Float32, setup.u_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType
deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk)

deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))

ps, st = Lux.setup(rng, deeponet) |> dev

@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)
ps, st = Lux.setup(rng, deeponet) |> dev
@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)

pred = first(deeponet((u, y), ps, st))
@test size(pred) == out_size
pred = first(deeponet((u, y), ps, st))
@test setup.out_size == size(pred)
end

deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
@testset "Additonal layer" begin
u = rand(Float32, 64, 1, 5) |> aType # sensor_points x nb
y = rand(Float32, 1, 10, 5) |> aType # ndims x N x nb
out_size = (4, 10, 5)

ps, st = Lux.setup(rng, deeponet) |> dev
branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))
additional = Chain(Dense(1 => 4))
deeponet = DeepONet(branch_net, trunk_net; additional=additional)

@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)
ps, st = Lux.setup(rng, deeponet) |> dev
@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)

pred = first(deeponet((u, y), ps, st))
@test size(pred) == out_size
pred = first(deeponet((u, y), ps, st))
@test out_size == size(pred)
end

deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
ps, st = Lux.setup(rng, deeponet) |> dev
@test_throws ArgumentError deeponet((u, y), ps, st)
@testset "Embedding layer mismatch" begin
u = rand(Float32, 64, 5) |> aType
y = rand(Float32, 1, 10, 5) |> aType

@testset "higher-dim input #7" begin
u = ones(Float32, 10, 10, 10) |> aType
v = ones(Float32, 1, 10, 10) |> aType
deeponet = DeepONet(; branch=(10, 10, 10), trunk=(1, 10, 10))
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
ps, st = Lux.setup(rng, deeponet) |> dev

y, st_ = deeponet((u, v), ps, st)
@test size(y) == (10, 10)

@inferred deeponet((u, v), ps, st)
@jet deeponet((u, v), ps, st)
@test_throws ArgumentError deeponet((u, y), ps, st)
end
end
end