Structs for NeuralOperators #23

Merged (16 commits, Aug 20, 2024)
107 changes: 58 additions & 49 deletions src/deeponet.jl
@@ -1,16 +1,16 @@
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)
DeepONet(branch, trunk, additional)

Constructs a DeepONet composed of Dense layers. Make sure the last layers of `branch` and
`trunk` have the same number of nodes.
Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that the outputs
of both nets share the same first dimension.

## Keyword arguments:
## Arguments

- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.

## Keyword Arguments

- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network applied to the output of the DeepONet for additional operations
  (e.g. embeddings); defaults to `nothing`

@@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example

```jldoctest
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> deeponet = DeepONet(branch_net, trunk_net);

julia> ps, st = Lux.setup(Xoshiro(), deeponet);

@@ -35,37 +39,28 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=nothing)

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net; additional)
@concrete struct DeepONet{L1, L2, L3} <:
Lux.AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
branch::L1
trunk::L2
additional::L3
end

"""
DeepONet(branch, trunk)
DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer())

Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that the outputs
of both nets share the same first dimension.

## Arguments
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)

- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.
Constructs a DeepONet composed of Dense layers. Make sure the last layers of `branch` and
`trunk` have the same number of nodes.

## Keyword Arguments
## Keyword arguments:

- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network applied to the output of the DeepONet for additional operations
  (e.g. embeddings); defaults to `NoOpLayer()`

@@ -78,11 +73,7 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example

```jldoctest
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> deeponet = DeepONet(branch_net, trunk_net);
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));

julia> ps, st = Lux.setup(Xoshiro(), deeponet);

@@ -94,15 +85,33 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb
b = branch(u) # p x u_size... x nb
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=NoOpLayer())

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net, additional)
end

function (deeponet::DeepONet)(
u::T1, y::T2, ps, st::NamedTuple) where {T1 <: AbstractArray, T2 <: AbstractArray}
b, st_b = deeponet.branch(u, ps.branch, st.branch)
t, st_t = deeponet.trunk(y, ps.trunk, st.trunk)

@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
amount of nodes in the last layer. Otherwise \
Σᵢ bᵢⱼ tᵢₖ won't work."
@argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

@return __project(b, t, additional)
end
out, st_a = __project(b, t, deeponet.additional, (; ps=ps.additional, st=st.additional))
return out, (branch=st_b, trunk=st_t, additional=st_a)
end
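For reference, a minimal usage sketch of the struct-based `DeepONet` introduced in this diff. It is not part of the change itself; the shapes and RNG seed are illustrative, and it assumes the positional `model(u, y, ps, st)` call signature used in the updated tests:

```julia
using Lux, NeuralOperators, Random

# Branch and trunk nets whose outputs share the same first dimension (p = 16).
branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))

# `additional` defaults to `NoOpLayer()` through the two-argument constructor.
deeponet = DeepONet(branch_net, trunk_net)
ps, st = Lux.setup(Xoshiro(0), deeponet)

u = rand(Float32, 64, 5)     # 5 input functions sampled at 64 sensor points
y = rand(Float32, 1, 10, 5)  # 10 query locations per batch element

out, st_new = deeponet(u, y, ps, st)
size(out)  # (10, 5)
```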
29 changes: 21 additions & 8 deletions src/fno.jl
@@ -27,7 +27,7 @@ kernels, and two `Dense` layers to project data back to the scalar field of interest.
## Example

```jldoctest
julia> fno = FourierNeuralOperator(gelu; chs=(2, 64, 64, 128, 1), modes=(16,));
julia> fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,));

julia> ps, st = Lux.setup(Xoshiro(), fno);

@@ -37,8 +37,15 @@ julia> size(first(fno(u, ps, st)))
(1, 1024, 5)
```
"""
function FourierNeuralOperator(
σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
@concrete struct FourierNeuralOperator{L1, L2, L3} <:
Lux.AbstractExplicitContainerLayer{(:lifting, :mapping, :project)}
lifting::L1
mapping::L2
project::L3
end

function FourierNeuralOperator(;
σ=gelu, chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm}
@argcheck length(chs) ≥ 5

@@ -52,9 +59,15 @@ function FourierNeuralOperator(
project = perm ? Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) :
Chain(Dense(map₂, σ), Dense(map₃))

return Chain(; lifting,
mapping=Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...),
project,
name="FourierNeuralOperator")
mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...)

return FourierNeuralOperator(lifting, mapping, project)
end

function (fno::FourierNeuralOperator)(x::T, ps, st::NamedTuple) where {T}
lift, st_lift = fno.lifting(x, ps.lifting, st.lifting)
mapping, st_mapping = fno.mapping(lift, ps.mapping, st.mapping)
project, st_project = fno.project(mapping, ps.project, st.project)
return project, (lifting=st_lift, mapping=st_mapping, project=st_project)
end
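A short sketch of the reworked `FourierNeuralOperator`, now an explicit container layer with `lifting`, `mapping`, and `project` sub-layers. This is illustrative only, assuming the keyword-only constructor and the channel-first data layout shown in the jldoctest above:

```julia
using Lux, NeuralOperators, Random

fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,))
ps, st = Lux.setup(Xoshiro(0), fno)

u = rand(Float32, 2, 1024, 5)  # (channels, spatial points, batch)
out, st_new = fno(u, ps, st)   # lifting -> spectral mapping -> projection
size(out)  # (1, 1024, 5)
```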
26 changes: 18 additions & 8 deletions src/layers.jl
@@ -116,18 +116,28 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(

```
"""
struct OperatorKernel{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:lin, :conv)}
lin::L2
conv::L1
activation::Function
end

function OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
act::A=identity; allow_fast_activation::Bool=false, permuted::Val{perm}=Val(false),
kwargs...) where {N, TR <: AbstractTransform{<:Number}, perm, A}
act = allow_fast_activation ? NNlib.fast_act(act) : act
l₁ = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
l₂ = OperatorConv(ch, modes, transform; permuted, kwargs...)

return @compact(; l₁, l₂, activation=act, dispatch=:OperatorKernel) do x::AbstractArray
l₁x = l₁(x)
l₂x = l₂(x)
@return @. activation(l₁x + l₂x)
end
lin = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
conv = OperatorConv(ch, modes, transform; permuted, kwargs...)

return OperatorKernel(lin, conv, act)
end

function (op::OperatorKernel)(x::T1, ps, st::NamedTuple) where {T1 <: AbstractArray}
x_conv, st_conv = op.conv(x, ps.conv, st.conv)
x_lin, st_lin = op.lin(x, ps.lin, st.lin)

out = @. op.activation(x_conv + x_lin)
return out, (lin=st_lin, conv=st_conv)
end

"""
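A small sketch of how the new `OperatorKernel` struct would be exercised, assuming the positional constructor and the `FourierTransform{ComplexF64}` transform from its docstring; the input shape and expected output shape are illustrative:

```julia
using Lux, NeuralOperators, Random

# Non-permuted layout: (channels, spatial..., batch). `lin` is a Dense bypass,
# `conv` the spectral OperatorConv; their outputs are summed and passed through gelu.
layer = OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}, gelu)
ps, st = Lux.setup(Xoshiro(0), layer)

x = rand(Float32, 2, 128, 4)
y, st_new = layer(x, ps, st)
size(y)  # expected (5, 128, 4)
```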
28 changes: 14 additions & 14 deletions src/utils.jl
@@ -1,24 +1,24 @@
@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
additional::NoOpLayer, ::NamedTuple) where {T1, T2}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb
return dropdims(sum(b_ .* t; dims=1); dims=1), () # N x nb
end

@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
additional::NoOpLayer, ::NamedTuple) where {T1, T2}
# b : p x u x nb
# t : p x N x nb
if size(b, 2) == 1 || size(t, 2) == 1
return sum(b .* t; dims=1) # 1 x N x nb
return sum(b .* t; dims=1), () # 1 x N x nb
else
return batched_matmul(batched_adjoint(b), t) # u x N x b
return batched_matmul(batched_adjoint(b), t), () # u x N x b
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2, N}
additional::NoOpLayer, ::NamedTuple) where {T1, T2, N}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]
@@ -29,34 +29,34 @@ end
t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return dropdims(sum(b_ .* t_; dims=1); dims=1) # u_size x N x nb
return dropdims(sum(b_ .* t_; dims=1); dims=1), () # u_size x N x nb
end

@inline function __project(
b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T, params::NamedTuple) where {T1, T2, T}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return additional(b_ .* t) # p x N x nb => out_dims x N x nb
return additional(b_ .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
end

@inline function __project(
b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T, params::NamedTuple) where {T1, T2, T}
# b : p x u x nb
# t : p x N x nb

if size(b, 2) == 1 || size(t, 2) == 1
return additional(b .* t) # p x N x nb => out_dims x N x nb
return additional(b .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
else
b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb
t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # p x 1 x N x nb

return additional(b_ .* t_) # p x u x N x nb => out_size x N x nb
return additional(b_ .* t_, params.ps, params.st) # p x u x N x nb => out_size x N x nb
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::T) where {T1, T2, N, T}
additional::T, params::NamedTuple) where {T1, T2, N, T}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]
@@ -67,5 +67,5 @@ end
t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return additional(b_ .* t_) # p x u_size x N x nb => out_size x N x nb
return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb
end
14 changes: 7 additions & 7 deletions test/deeponet_tests.jl
@@ -18,10 +18,10 @@
deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk)

ps, st = Lux.setup(rng, deeponet) |> dev
@inferred first(deeponet((u, y), ps, st))
@jet first(deeponet((u, y), ps, st))
@inferred first(deeponet(u, y, ps, st))
@jet first(deeponet(u, y, ps, st))

pred = first(deeponet((u, y), ps, st))
pred = first(deeponet(u, y, ps, st))
@test setup.out_size == size(pred)
end

@@ -43,10 +43,10 @@
branch=setup.branch, trunk=setup.trunk, additional=setup.additional)

ps, st = Lux.setup(rng, deeponet) |> dev
@inferred first(deeponet((u, y), ps, st))
@jet first(deeponet((u, y), ps, st))
@inferred first(deeponet(u, y, ps, st))
@jet first(deeponet(u, y, ps, st))

pred = first(deeponet((u, y), ps, st))
pred = first(deeponet(u, y, ps, st))
@test setup.out_size == size(pred)

__f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st)))
@@ -62,7 +62,7 @@
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))

ps, st = Lux.setup(rng, deeponet) |> dev
@test_throws ArgumentError deeponet((u, y), ps, st)
@test_throws ArgumentError deeponet(u, y, ps, st)
end
end
end
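And a quick sketch of the last-dimension check that the forward pass now enforces, mirroring the `@test_throws` case above; the layer sizes here are hypothetical:

```julia
using Lux, NeuralOperators, Random

# Branch ends with 16 features, trunk with 8; the forward-pass @argcheck
# should reject this, since Σᵢ bᵢⱼ tᵢₖ needs matching first dimensions.
bad = DeepONet(Chain(Dense(64 => 32), Dense(32 => 16)),
    Chain(Dense(1 => 8), Dense(8 => 8)))
ps, st = Lux.setup(Xoshiro(0), bad)

u = rand(Float32, 64, 5)
y = rand(Float32, 1, 10, 5)
bad(u, y, ps, st)  # expected to throw an ArgumentError
```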
1 change: 0 additions & 1 deletion test/fno_tests.jl
@@ -10,7 +10,6 @@

@testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups
fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted)
display(fno)
ps, st = Lux.setup(rng, fno) |> dev

x = rand(rng, Float32, setup.x_size...) |> aType