Structs for NeuralOperators #23

Merged
16 commits merged on Aug 20, 2024
8 changes: 3 additions & 5 deletions Project.toml
@@ -13,19 +13,17 @@
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[compat]
ArgCheck = "2.3.0"
ChainRulesCore = "1.24.0"
ConcreteStructs = "0.2.3"
FFTW = "1.8.0"
Lux = "0.5.62"
LuxCore = "0.1.21"
LuxLib = "0.3.40"
Lux = "0.5.64"
LuxCore = "0.1.24"
LuxLib = "0.3.42"
NNlib = "0.9.21"
Random = "1.10"
Reexport = "1.2.2"
WeightInitializers = "1"
julia = "1.10"
5 changes: 1 addition & 4 deletions src/NeuralOperators.jl
@@ -5,16 +5,13 @@
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using FFTW: FFTW, irfft, rfft
using Lux
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using LuxLib: batched_matmul
using NNlib: NNlib, batched_adjoint
using Random: Random, AbstractRNG
using Reexport: @reexport

const CRC = ChainRulesCore

@reexport using Lux

include("utils.jl")
include("transform.jl")

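Since `@reexport using Lux` is removed above, downstream code presumably needs to load Lux explicitly alongside this package. A minimal sketch of that assumption:

```julia
using Lux             # no longer reexported by NeuralOperators, so load it explicitly
using NeuralOperators
```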
105 changes: 56 additions & 49 deletions src/deeponet.jl
@@ -1,16 +1,16 @@
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)
DeepONet(branch, trunk, additional)

Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
`trunk` are same.
Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that the outputs
of both nets have the same first dimension.

## Keyword arguments:
## Arguments

- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.

## Keyword Arguments

- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
for embeddings, defaults to `nothing`

@@ -23,7 +23,11 @@
operators", doi: https://arxiv.org/abs/1910.03193
## Example

```jldoctest
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> deeponet = DeepONet(branch_net, trunk_net);

julia> ps, st = Lux.setup(Xoshiro(), deeponet);

Expand All @@ -35,37 +39,27 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=nothing)

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net; additional)
@concrete struct DeepONet <: AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
branch
trunk
additional
end

"""
DeepONet(branch, trunk)
DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer())

Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
nets output should have the same first dimension.

## Arguments
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)

- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.
Constructs a DeepONet composed of Dense layers. Make sure the last layers of `branch` and
`trunk` have the same number of nodes.

## Keyword Arguments
## Keyword arguments:

- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network to pass the output of DeepONet, to include additional operations
for embeddings, defaults to `nothing`

@@ -78,11 +72,7 @@
operators", doi: https://arxiv.org/abs/1910.03193
## Example

```jldoctest
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> deeponet = DeepONet(branch_net, trunk_net);
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));

julia> ps, st = Lux.setup(Xoshiro(), deeponet);

@@ -94,15 +84,32 @@
julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb
b = branch(u) # p x u_size... x nb
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=NoOpLayer())

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net, additional)
end

function (deeponet::DeepONet)(x, ps, st::NamedTuple)
b, st_b = deeponet.branch(x[1], ps.branch, st.branch)
t, st_t = deeponet.trunk(x[2], ps.trunk, st.trunk)

@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
amount of nodes in the last layer. Otherwise \
Σᵢ bᵢⱼ tᵢₖ won't work."
@argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

@return __project(b, t, additional)
end
out, st_a = __project(b, t, deeponet.additional, (; ps=ps.additional, st=st.additional))
return out, (branch=st_b, trunk=st_t, additional=st_a)
end
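A usage sketch for the struct-based `DeepONet` defined above, mirroring the doctests. The input sizes and the `Dense(16 => 4)` choice for `additional` are illustrative assumptions, not part of this PR; only the constructor and call signatures come from the diff.

```julia
using Lux, NeuralOperators, Random

# Keyword constructor with an explicit `additional` post-processing network (assumed example).
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
    additional=Dense(16 => 4))

ps, st = Lux.setup(Xoshiro(0), deeponet)

u = rand(Float32, 64, 5)     # branch input: 64 sensors, batch of 5 (assumed layout)
y = rand(Float32, 1, 10, 5)  # trunk input: 1 coordinate, 10 query points, batch of 5

out, st_new = deeponet((u, y), ps, st)
size(out)  # (4, 10, 5) here; (10, 5) with the default NoOpLayer, as in the doctest
# st_new carries one state per sub-layer: (; branch, trunk, additional)
```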
29 changes: 21 additions & 8 deletions src/fno.jl
@@ -27,7 +27,7 @@
kernels, and two `Dense` layers to project data back to the scalar field of interest.
## Example

```jldoctest
julia> fno = FourierNeuralOperator(gelu; chs=(2, 64, 64, 128, 1), modes=(16,));
julia> fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,));

julia> ps, st = Lux.setup(Xoshiro(), fno);

@@ -37,8 +37,15 @@
julia> size(first(fno(u, ps, st)))
(1, 1024, 5)
```
"""
function FourierNeuralOperator(
σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
@concrete struct FourierNeuralOperator <:
AbstractExplicitContainerLayer{(:lifting, :mapping, :project)}
lifting
mapping
project
end

function FourierNeuralOperator(;
σ=gelu, chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm}
@argcheck length(chs) ≥ 5

@@ -52,9 +59,15 @@
function FourierNeuralOperator(
project = perm ? Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) :
Chain(Dense(map₂, σ), Dense(map₃))

return Chain(; lifting,
mapping=Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...),
project,
name="FourierNeuralOperator")
mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...)

return FourierNeuralOperator(lifting, mapping, project)
end

function (fno::FourierNeuralOperator)(x::AbstractArray, ps, st::NamedTuple)
lift, st_lift = fno.lifting(x, ps.lifting, st.lifting)
mapping, st_mapping = fno.mapping(lift, ps.mapping, st.mapping)
project, st_project = fno.project(mapping, ps.project, st.project)
return project, (lifting=st_lift, mapping=st_mapping, project=st_project)
end
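A usage sketch for the refactored `FourierNeuralOperator`, using the doctest's channel and mode settings. The `(channels, grid points, batch)` input layout is an assumption based on the non-permuted default.

```julia
using Lux, NeuralOperators, Random

fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,))
ps, st = Lux.setup(Xoshiro(0), fno)

u = rand(Float32, 2, 1024, 5)  # input of size (channels, grid points, batch); layout assumed
y, st_new = fno(u, ps, st)
size(y)  # (1, 1024, 5), matching the doctest above
# st_new is a NamedTuple with `lifting`, `mapping`, and `project` sub-states.
```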
28 changes: 20 additions & 8 deletions src/layers.jl
@@ -116,18 +116,30 @@

```
"""
@concrete struct OperatorKernel <: AbstractExplicitContainerLayer{(:lin, :conv)}
lin
conv
activation <: Function
end

OperatorKernel(lin, conv) = OperatorKernel(lin, conv, identity)

Codecov / codecov/patch warning: added line src/layers.jl#L125 was not covered by tests.

function OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
act::A=identity; allow_fast_activation::Bool=false, permuted::Val{perm}=Val(false),
kwargs...) where {N, TR <: AbstractTransform{<:Number}, perm, A}
act = allow_fast_activation ? NNlib.fast_act(act) : act
l₁ = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
l₂ = OperatorConv(ch, modes, transform; permuted, kwargs...)

return @compact(; l₁, l₂, activation=act, dispatch=:OperatorKernel) do x::AbstractArray
l₁x = l₁(x)
l₂x = l₂(x)
@return @. activation(l₁x + l₂x)
end
lin = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
conv = OperatorConv(ch, modes, transform; permuted, kwargs...)

return OperatorKernel(lin, conv, act)
end

function (op::OperatorKernel)(x::AbstractArray, ps, st::NamedTuple)
x_conv, st_conv = op.conv(x, ps.conv, st.conv)
x_lin, st_lin = op.lin(x, ps.lin, st.lin)

out = fast_activation!!(op.activation, x_conv .+ x_lin)
return out, (lin=st_lin, conv=st_conv)
end
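A usage sketch for the refactored kernel layer. `SpectralKernel` (the helper the FNO constructor above calls) builds an `OperatorKernel` over a Fourier transform; the channel counts, grid size, and data layout below are illustrative assumptions.

```julia
using Lux, NeuralOperators, Random

layer = SpectralKernel(2 => 5, (16,), gelu; permuted=Val(false))  # signature as used in fno.jl
ps, st = Lux.setup(Xoshiro(0), layer)

x = rand(Float32, 2, 128, 4)  # (channels, grid, batch) in the non-permuted layout (assumed)
y, st_new = layer(x, ps, st)  # y has 5 output channels; st_new is (; lin, conv)
```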

"""
38 changes: 19 additions & 19 deletions src/utils.jl
@@ -1,24 +1,24 @@
@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
@inline function __project(
b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb
return dropdims(sum(b_ .* t; dims=1); dims=1), () # N x nb
end

@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2}
@inline function __project(
b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
# b : p x u x nb
# t : p x N x nb
if size(b, 2) == 1 || size(t, 2) == 1
return sum(b .* t; dims=1) # 1 x N x nb
return sum(b .* t; dims=1), () # 1 x N x nb
else
return batched_matmul(batched_adjoint(b), t) # u x N x b
return batched_matmul(batched_adjoint(b), t), () # u x N x b
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::Nothing) where {T1, T2, N}
@inline function __project(
b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2, N}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]
@@ -29,34 +29,34 @@
t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return dropdims(sum(b_ .* t_; dims=1); dims=1) # u_size x N x nb
return dropdims(sum(b_ .* t_; dims=1); dims=1), () # u_size x N x nb
end

@inline function __project(
b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
additional::T, params) where {T1, T2, T}
# b : p x nb
# t : p x N x nb
b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
return additional(b_ .* t) # p x N x nb => out_dims x N x nb
return additional(b_ .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
end

@inline function __project(
b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
additional::T, params) where {T1, T2, T}
# b : p x u x nb
# t : p x N x nb

if size(b, 2) == 1 || size(t, 2) == 1
return additional(b .* t) # p x N x nb => out_dims x N x nb
return additional(b .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
else
b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb
t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # p x 1 x N x nb

return additional(b_ .* t_) # p x u x N x nb => out_size x N x nb
return additional(b_ .* t_, params.ps, params.st) # p x u x N x nb => out_size x N x nb
end
end

@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
additional::T) where {T1, T2, N, T}
additional::T, params) where {T1, T2, N, T}
# b : p x u_size x nb
# t : p x N x nb
u_size = size(b)[2:(end - 1)]
@@ -67,5 +67,5 @@
t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...)
# p x (1,1,1...) x N x nb

return additional(b_ .* t_) # p x u_size x N x nb => out_size x N x nb
return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb

Codecov / codecov/patch warning: added line src/utils.jl#L70 was not covered by tests.
end
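In the `NoOpLayer` case, `__project` computes the Σᵢ bᵢⱼ tᵢₖ contraction referenced in the argument checks. A naive reference version for the simplest shapes, purely illustrative:

```julia
# b : p × nb (branch output), t : p × N × nb (trunk output)
# out[n, k] = Σᵢ b[i, k] * t[i, n, k], i.e. dropdims(sum(b_ .* t; dims=1); dims=1)
function project_reference(b::AbstractMatrix, t::AbstractArray{<:Any, 3})
    p, nb = size(b)
    N = size(t, 2)
    out = zeros(promote_type(eltype(b), eltype(t)), N, nb)
    for k in 1:nb, n in 1:N, i in 1:p
        out[n, k] += b[i, k] * t[i, n, k]
    end
    return out
end
```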
6 changes: 3 additions & 3 deletions test/Project.toml
@@ -25,9 +25,9 @@
Documenter = "1.5.0"
ExplicitImports = "1.9.0"
Hwloc = "3.2.0"
InteractiveUtils = "<0.0.1, 1"
Lux = "0.5.62"
LuxCore = "0.1.21"
LuxLib = "0.3.40"
Lux = "0.5.64"
LuxCore = "0.1.24"
LuxLib = "0.3.42"
LuxTestUtils = "1.1.2"
MLDataDevices = "1.0.0"
Optimisers = "0.3.3"
3 changes: 1 addition & 2 deletions test/fno_tests.jl
@@ -22,11 +22,10 @@
@test size(first(fno(x, ps, st))) == setup.y_size

data = [(x, y)]
broken = mode == "AMDGPU"
@test begin
l2, l1 = train!(fno, ps, st, data; epochs=10)
l2 < l1
end broken=broken
end

__f = (x, ps) -> sum(abs2, first(fno(x, ps, st)))
test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
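For reference, the removed lines used Julia's `@test expr broken=flag` form, which records an expected failure instead of a hard error. A minimal standalone illustration (not from this PR):

```julia
using Test

broken = true       # e.g. a backend known to fail, like the old AMDGPU flag
@test begin
    1 + 1 == 3      # placeholder condition
end broken=broken   # reported as Broken when it fails, Error if it unexpectedly passes
```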