structs for Neural Operators #19

Closed · wants to merge 16 commits
4 changes: 0 additions & 4 deletions .buildkite/testing.yml
@@ -7,8 +7,6 @@ steps:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
commands: |
printf "[LuxTestUtils]\ntarget_modules = [\"NeuralOperators\", \"Lux\", \"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n[LuxCore]\ninstability_check = \"error\"\n" > LocalPreferences.toml
agents:
queue: "juliagpu"
cuda: "*"
@@ -29,8 +27,6 @@ steps:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
commands: |
printf "[LuxTestUtils]\ntarget_modules = [\"NeuralOperators\", \"Lux\", \"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n[LuxCore]\ninstability_check = \"error\"\n" > LocalPreferences.toml
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
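For reference, the `printf` one-liners removed above (and the CI workflow step removed below) both wrote the same preferences file; the generated `LocalPreferences.toml` was:

```toml
[LuxTestUtils]
target_modules = ["NeuralOperators", "Lux", "LuxLib"]
[LuxLib]
instability_check = "error"
[LuxCore]
instability_check = "error"
```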
20 changes: 0 additions & 20 deletions .github/workflows/CI.yml
@@ -35,16 +35,6 @@ jobs:
- windows-latest
steps:
- uses: actions/checkout@v4
- uses: DamianReeves/write-file-action@master
with:
path: "LocalPreferences.toml"
contents: |
[LuxTestUtils]
target_modules = ["NeuralOperators", "Lux", "LuxLib"]
[LuxLib]
instability_check = "error"
[LuxCore]
instability_check = "error"
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
@@ -78,16 +68,6 @@ jobs:
version: ["1"]
steps:
- uses: actions/checkout@v4
- uses: DamianReeves/write-file-action@master
with:
path: "LocalPreferences.toml"
contents: |
[LuxTestUtils]
target_modules = ["NeuralOperators", "Lux", "LuxLib"]
[LuxLib]
instability_check = "error"
[LuxCore]
instability_check = "error"
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
2 changes: 1 addition & 1 deletion .github/workflows/QualityCheck.yml
@@ -16,4 +16,4 @@ jobs:
- name: Checkout Actions Repository
uses: actions/checkout@v4
- name: Check spelling
uses: crate-ci/[email protected].2
uses: crate-ci/[email protected].5
37 changes: 6 additions & 31 deletions Project.toml
@@ -10,47 +10,22 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[compat]
Aqua = "0.8.7"
ArgCheck = "2.3.0"
ChainRulesCore = "1.24.0"
ConcreteStructs = "0.2.3"
Documenter = "1.4.1"
ExplicitImports = "1.9.0"
FFTW = "1.8.0"
Lux = "0.5.56"
LuxCore = "0.1.15"
LuxDeviceUtils = "0.1.24"
LuxTestUtils = "0.1.15"
NNlib = "0.9.17"
Optimisers = "0.3.3"
Pkg = "1.10"
Lux = "0.5.62"
LuxCore = "0.1.21"
LuxLib = "0.3.40"
NNlib = "0.9.21"
Random = "1.10"
ReTestItems = "1.24.0"
Reexport = "1.2.2"
StableRNGs = "1.0.2"
Test = "1.10"
WeightInitializers = "0.1.7, 1"
Zygote = "0.6.70"
WeightInitializers = "1"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Documenter", "ExplicitImports", "LuxTestUtils", "Optimisers", "Pkg", "ReTestItems", "StableRNGs", "Test", "Zygote"]
4 changes: 2 additions & 2 deletions src/NeuralOperators.jl
@@ -6,8 +6,8 @@ using ConcreteStructs: @concrete
using FFTW: FFTW, irfft, rfft
using Lux
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxDeviceUtils: get_device, LuxAMDGPUDevice
using NNlib: NNlib, ⊠, batched_adjoint
using LuxLib: batched_matmul
using NNlib: NNlib, batched_adjoint
using Random: Random, AbstractRNG
using Reexport: @reexport

108 changes: 59 additions & 49 deletions src/deeponet.jl
@@ -1,16 +1,16 @@
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)
DeepONet(branch, trunk; additional)

Constructs a DeepONet composed of Dense layers. Make sure the last nodes of `branch` and
`trunk` are the same.
Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that the outputs
of both nets share the same first dimension.

## Keyword arguments:
## Arguments

- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.

## Keyword Arguments

- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network through which to pass the output of the DeepONet, to apply
additional operations such as embeddings; defaults to `nothing`

@@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example

```jldoctest
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> deeponet = DeepONet(branch_net, trunk_net);

julia> ps, st = Lux.setup(Xoshiro(), deeponet);

@@ -35,37 +39,28 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=nothing)

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same number of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net; additional)
@concrete struct DeepONet{L1, L2, L3} <:
Lux.AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
branch::L1
trunk::L2
additional::L3
end

"""
DeepONet(branch, trunk)
DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer())

Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that the outputs
of both nets share the same first dimension.

## Arguments
"""
DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
branch_activation = identity, trunk_activation = identity)

- `branch`: `Lux` network to be used as branch net.
- `trunk`: `Lux` network to be used as trunk net.
Constructs a DeepONet composed of Dense layers. Make sure the last nodes of `branch` and
`trunk` are the same.

## Keyword Arguments
## Keyword arguments:

- `branch`: Tuple of integers containing the number of nodes in each layer for branch net
- `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
- `branch_activation`: activation function for branch net
- `trunk_activation`: activation function for trunk net
- `additional`: `Lux` network through which to pass the output of the DeepONet, to apply
additional operations such as embeddings; defaults to `nothing`

@@ -78,11 +73,7 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example

```jldoctest
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));

julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));

julia> deeponet = DeepONet(branch_net, trunk_net);
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));

julia> ps, st = Lux.setup(Xoshiro(), deeponet);

@@ -94,15 +85,34 @@ julia> size(first(deeponet((u, y), ps, st)))
(10, 5)
```
"""
function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb
b = branch(u) # p x u_size... x nb
function DeepONet(;
branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
trunk_activation=identity, additional=NoOpLayer())

# checks for last dimension size
@argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same number of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
for i in 1:(length(branch) - 1)]...)

trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
for i in 1:(length(trunk) - 1)]...)

return DeepONet(branch_net, trunk_net, additional)
end

function (deeponet::DeepONet)(
u::T1, y::T2, ps, st::NamedTuple) where {T1 <: AbstractArray, T2 <: AbstractArray}
b, st_b = deeponet.branch(u, ps.branch, st.branch)
t, st_t = deeponet.trunk(y, ps.trunk, st.trunk)

@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
number of nodes in the last layer. Otherwise \
Σᵢ bᵢⱼ tᵢₖ won't work."
@argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same number of \
nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
work."

@return __project(b, t, additional)
end
out, st_a = __project(
b, t, deeponet.additional, (; ps=ps.additional, st=st.additional))
return out, (branch=st_b, trunk=st_t, additional=st_a)
end
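The `__project` helper called above implements the Σᵢ bᵢⱼ tᵢₖ contraction that the error messages refer to. Below is a minimal sketch of that contraction for the simplest case, where the branch output `b` is a `p × nb` matrix and the trunk output `t` is a `p × N × nb` array; `project_sketch` is illustrative only and is not the actual `__project`, which also routes the result through the `additional` layer and handles higher-dimensional branch outputs:

```julia
# Sketch only: the Σᵢ bᵢⱼ tᵢₖ contraction over the shared first dimension p.
function project_sketch(b::AbstractMatrix, t::AbstractArray{T, 3}) where {T}
    # b :: (p, nb), t :: (p, N, nb)
    bt = reshape(b, size(b, 1), 1, size(b, 2)) .* t  # (p, N, nb)
    return dropdims(sum(bt; dims=1); dims=1)         # (N, nb)
end
```

In the docstring example above, `b` is `16 × 5` and `t` is `16 × 10 × 5`, which yields the reported `(10, 5)` output.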
27 changes: 20 additions & 7 deletions src/fno.jl
@@ -37,8 +37,15 @@ julia> size(first(fno(u, ps, st)))
(1, 1024, 5)
```
"""
function FourierNeuralOperator(
σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
@concrete struct FourierNeuralOperator{L1, L2, L3} <:
Lux.AbstractExplicitContainerLayer{(:lifting, :mapping, :project)}
lifting::L1
mapping::L2
project::L3
end

function FourierNeuralOperator(;
σ=gelu, chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm}
@argcheck length(chs) ≥ 5

@@ -52,9 +59,15 @@ function FourierNeuralOperator(
project = perm ? Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) :
Chain(Dense(map₂, σ), Dense(map₃))

return Chain(; lifting,
mapping=Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...),
project,
name="FourierNeuralOperator")
mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
for i in 2:(C - 3)]...)

return FourierNeuralOperator(lifting, mapping, project)
end

function (fno::FourierNeuralOperator)(x::T, ps, st::NamedTuple) where {T}
lift, st_lift = fno.lifting(x, ps.lifting, st.lifting)
mapping, st_mapping = fno.mapping(lift, ps.mapping, st.mapping)
project, st_project = fno.project(mapping, ps.project, st.project)
return project, (lifting=st_lift, mapping=st_mapping, project=st_project)
end
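With this PR the activation `σ` moves from a positional to a keyword argument on `FourierNeuralOperator`. A minimal usage sketch of the new constructor, mirroring the docstring example above (the RNG seed is illustrative):

```julia
using NeuralOperators, Lux, Random

fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 64, 64, 64, 128, 1), modes=(16,))
ps, st = Lux.setup(Xoshiro(0), fno)

u = rand(Float32, 2, 1024, 5)  # (input channels, spatial points, batch)
y, st_new = fno(u, ps, st)
size(y)                        # (1, 1024, 5), as in the docstring above
```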
4 changes: 2 additions & 2 deletions src/functional.jl
@@ -12,8 +12,8 @@ end
x_size = size(x_tr)
x_flat = reshape(x_tr, :, x_size[N - 1], x_size[N])

x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m
x_weighted = permutedims(__batched_mul(weights, x_flat_t), (3, 1, 2)) # m x o x b
x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m
x_weighted = permutedims(batched_matmul(weights, x_flat_t), (3, 1, 2)) # m x o x b

return reshape(x_weighted, x_size[1:(N - 2)]..., size(x_weighted)[2:3]...)
end
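Swapping NNlib's `⊠` for `LuxLib.batched_matmul` leaves the shape bookkeeping unchanged; a standalone sketch of the batched contraction with illustrative dimensions:

```julia
using LuxLib: batched_matmul

m, i, o, b = 16, 4, 8, 3           # modes, in channels, out channels, batch (illustrative)
x_flat_t = rand(Float32, i, b, m)  # i x b x m, matching the comment above
weights = rand(Float32, o, i, m)   # one o × i weight matrix per retained mode
x_weighted = permutedims(batched_matmul(weights, x_flat_t), (3, 1, 2))  # m x o x b
@assert size(x_weighted) == (m, o, b)
```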