DeepONet multi-output fix #11

Closed
wants to merge 17 commits into from
26 changes: 16 additions & 10 deletions src/deeponet.jl
@@ -11,6 +11,8 @@ Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network through which the DeepONet output is passed, to apply
  additional operations such as embeddings; defaults to `nothing`

## References

@@ -33,8 +35,9 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
branch_activation=identity, trunk_activation=identity)
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=nothing)

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
@@ -47,7 +50,7 @@ function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net)
return DeepONet(branch_net, trunk_net; additional=additional)
end
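
For reference, a minimal usage sketch of the new `additional` keyword (a sketch, assuming this package's `DeepONet` and `Lux` are loaded; `Dense(16 => 4)` is just an illustrative head, and the shapes mirror the updated tests):

```julia
using Lux, Random

# Branch and trunk both end in 16 latent nodes; `additional` maps those 16 channels to 4.
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
    additional=Dense(16 => 4))

u = rand(Float32, 64, 5)     # sensor_points x nb
y = rand(Float32, 1, 10, 5)  # ndims x N x nb
ps, st = Lux.setup(Random.default_rng(), deeponet)

pred = first(deeponet((u, y), ps, st))
size(pred)  # expected (4, 10, 5): out_dims x N x nb
```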

"""
@@ -61,6 +64,11 @@ nets output should have the same first dimension.
- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.

## Keyword Arguments

- `additional`: `Lux` network through which the DeepONet output is passed, to apply
  additional operations such as embeddings; defaults to `nothing`

## References

[1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators for
@@ -86,17 +94,15 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(branch::L1, trunk::L2) where {L1, L2}
return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb...
b = branch(u) # p x nb...
function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb
b = branch(u) # p x u_size... x nb

@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
amount of nodes in the last layer. Otherwise \
Σᵢ bᵢⱼ tᵢₖ won't work."

b_ = ndims(t) == ndims(b) ? b : reshape(b, size(b, 1), 1, size(b)[2:end]...)
@return dropdims(sum(t .* b_; dims=1); dims=1)
@return __project(b, t, additional)
end
end
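
With explicit networks the multi-output path works the same way; below is a sketch of the "vector" case from the updated tests, using `Chain`s equivalent to the tuple form (again assuming this package's `DeepONet` and `Lux` are in scope):

```julia
using Lux, Random

branch = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk = Chain(Dense(4 => 8), Dense(8 => 8), Dense(8 => 16))
deeponet = DeepONet(branch, trunk)

u = rand(Float32, 64, 3, 5)  # sensor_points x u x nb
y = rand(Float32, 4, 10, 5)  # ndims x N x nb
ps, st = Lux.setup(Random.default_rng(), deeponet)

size(first(deeponet((u, y), ps, st)))  # expected (3, 10, 5): u x N x nb
```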
72 changes: 72 additions & 0 deletions src/utils.jl
@@ -9,3 +9,75 @@
# FIXME: This is not good for performance but that is okay for now
return stack(*, eachslice(x; dims=3), eachslice(y; dims=3))
end

@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb
end

@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
# b : p x u x nb
# t : p x N x nb
if size(b, 2) == 1 || size(t, 2) == 1
return sum(b .* t; dims=1) # 1 x N x nb
else
return __batched_mul(batched_adjoint(b), t) # u x N x nb
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2, N}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]

b_ = reshape(b, size(b, 1), u_size..., 1, size(b)[end])
# p x u_size x 1 x nb

t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return dropdims(sum(b_ .* t_; dims=1); dims=1) # u_size x N x nb
end
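
A quick shape walk-through of the `additional === nothing` projections above (a sketch that calls the internal helper directly, assuming it and its NNlib-backed `__batched_mul` are in scope):

```julia
p, N, nb = 16, 10, 5
t = rand(Float32, p, N, nb)              # trunk output: p x N x nb

b2 = rand(Float32, p, nb)                # branch output: p x nb
size(__project(b2, t, nothing))          # (10, 5): N x nb

b3 = rand(Float32, p, 3, nb)             # branch output: p x u x nb
size(__project(b3, t, nothing))          # (3, 10, 5): u x N x nb (batched matmul path)

b5 = rand(Float32, p, 4, 3, 3, nb)       # branch output: p x u_size... x nb
size(__project(b5, t, nothing))          # (4, 3, 3, 10, 5): u_size... x N x nb
```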

@inline function __project(
b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return additional(b_ .* t) # p x N x nb => out_dims x N x nb
end

@inline function __project(
b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
# b : p x u x nb
# t : p x N x nb

if size(b, 2) == 1 || size(t, 2) == 1
return additional(b .* t) # p x N x nb => out_dims x N x nb
else
b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb
t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # p x 1 x N x nb

return additional(b_ .* t_) # p x u x N x nb => out_size x N x nb
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::T) where {T1, T2, N, T}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]

b_ = reshape(b, size(b, 1), u_size..., 1, size(b)[end])
# p x u_size x 1 x nb

t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return additional(b_ .* t_) # p x u_size x N x nb => out_size x N x nb
end
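
When `additional` is supplied, the reduction over the latent `p` axis is deferred to that network; below is a sketch with a plain closure standing in for the `Lux` layer that `@compact` would normally manage (the weight matrix `W` is hypothetical):

```julia
p, N, nb, out_dims = 16, 10, 5, 4

# Stand-in for an embedding head: acts on the first (latent) axis, like Dense(p => out_dims).
W = rand(Float32, out_dims, p)
additional = x -> reshape(W * reshape(x, p, :), out_dims, size(x)[2:end]...)

t = rand(Float32, p, N, nb)               # p x N x nb

b2 = rand(Float32, p, nb)                 # p x nb
size(__project(b2, t, additional))        # (4, 10, 5): out_dims x N x nb

b3 = rand(Float32, p, 3, nb)              # p x u x nb
size(__project(b3, t, additional))        # (4, 3, 10, 5): out_dims x u x N x nb
```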
94 changes: 54 additions & 40 deletions test/deeponet_tests.jl
@@ -2,49 +2,63 @@
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = StableRNG(12345)

u = rand(Float32, 64, 5) |> aType # sensor_points x nb
y = rand(Float32, 1, 10, 5) |> aType # ndims x N x nb
out_size = (10, 5)

deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
display(deeponet)
ps, st = Lux.setup(rng, deeponet) |> dev

@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)

pred = first(deeponet((u, y), ps, st))
@test size(pred) == out_size

deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
display(deeponet)
ps, st = Lux.setup(rng, deeponet) |> dev

@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)

pred = first(deeponet((u, y), ps, st))
@test size(pred) == out_size

deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
display(deeponet)
ps, st = Lux.setup(rng, deeponet) |> dev
@test_throws ArgumentError deeponet((u, y), ps, st)

@testset "higher-dim input #7" begin
u = ones(Float32, 10, 10, 10) |> aType
v = ones(Float32, 1, 10, 10) |> aType
deeponet = DeepONet(; branch=(10, 10, 10), trunk=(1, 10, 10))
display(deeponet)
setups = [
(u_size=(64, 5), y_size=(1, 10, 5), out_size=(10, 5),
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar"),
(u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(1, 10, 5),
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), name="Scalar II"),
(u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(3, 10, 5),
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"),
(u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(4, 3, 3, 10, 5),
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor")]

@testset "$(setup.name)" for setup in setups
u = rand(Float32, setup.u_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType
deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk)

ps, st = Lux.setup(rng, deeponet) |> dev
@inferred first(deeponet((u, y), ps, st))
@jet first(deeponet((u, y), ps, st))

pred = first(deeponet((u, y), ps, st))
@test setup.out_size == size(pred)
end

setups = [
(u_size=(64, 5), y_size=(1, 10, 5), out_size=(4, 10, 5),
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
additional=Dense(16 => 4), name="Scalar"),
(u_size=(64, 1, 5), y_size=(1, 10, 5), out_size=(4, 10, 5),
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
additional=Dense(16 => 4), name="Scalar II"),
(u_size=(64, 3, 5), y_size=(8, 10, 5), out_size=(4, 3, 10, 5),
branch=(64, 32, 32, 16), trunk=(8, 8, 8, 16),
additional=Dense(16 => 4), name="Vector")]

@testset "Additional layer: $(setup.name)" for setup in setups
u = rand(Float32, setup.u_size...) |> aType
y = rand(Float32, setup.y_size...) |> aType
deeponet = DeepONet(;
branch=setup.branch, trunk=setup.trunk, additional=setup.additional)

ps, st = Lux.setup(rng, deeponet) |> dev
@inferred first(deeponet((u, y), ps, st))
@jet first(deeponet((u, y), ps, st))

pred = first(deeponet((u, y), ps, st))
@test setup.out_size == size(pred)
end

@testset "Embedding layer mismatch" begin
u = rand(Float32, 64, 5) |> aType
y = rand(Float32, 1, 10, 5) |> aType

y, st_ = deeponet((u, v), ps, st)
@test size(y) == (10, 10)
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))

@inferred deeponet((u, v), ps, st)
@jet deeponet((u, v), ps, st)
ps, st = Lux.setup(rng, deeponet) |> dev
@test_throws ArgumentError deeponet((u, y), ps, st)
end
end
end