diff --git a/Project.toml b/Project.toml
index 9fc4c05..8522a4b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,7 +13,6 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 
 [compat]
@@ -21,11 +20,10 @@ ArgCheck = "2.3.0"
 ChainRulesCore = "1.24.0"
 ConcreteStructs = "0.2.3"
 FFTW = "1.8.0"
-Lux = "0.5.62"
-LuxCore = "0.1.21"
-LuxLib = "0.3.40"
+Lux = "0.5.64"
+LuxCore = "0.1.24"
+LuxLib = "0.3.42"
 NNlib = "0.9.21"
 Random = "1.10"
-Reexport = "1.2.2"
 WeightInitializers = "1"
 julia = "1.10"
diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl
index 89ad5d0..6a0ca9f 100644
--- a/src/NeuralOperators.jl
+++ b/src/NeuralOperators.jl
@@ -5,16 +5,13 @@ using ChainRulesCore: ChainRulesCore, NoTangent
 using ConcreteStructs: @concrete
 using FFTW: FFTW, irfft, rfft
 using Lux
-using LuxCore: LuxCore, AbstractExplicitLayer
+using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
 using LuxLib: batched_matmul
 using NNlib: NNlib, batched_adjoint
 using Random: Random, AbstractRNG
-using Reexport: @reexport
 
 const CRC = ChainRulesCore
 
-@reexport using Lux
-
 include("utils.jl")
 include("transform.jl")
 
diff --git a/src/deeponet.jl b/src/deeponet.jl
index 05bf97f..8d2cd7f 100644
--- a/src/deeponet.jl
+++ b/src/deeponet.jl
@@ -1,16 +1,16 @@
 """
-    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
-        branch_activation = identity, trunk_activation = identity)
+    DeepONet(branch, trunk, additional)
 
-Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
-`trunk` are same.
+Constructs a DeepONet from the given `branch` and `trunk` architectures. Make sure that
+both networks return outputs with the same first dimension.
 
-## Keyword arguments:
+## Arguments
+
+  - `branch`: `Lux` network to be used as branch net.
+  - `trunk`: `Lux` network to be used as trunk net.
+
+## Keyword Arguments
 
-  - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
-  - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
-  - `branch_activation`: activation function for branch net
-  - `trunk_activation`: activation function for trunk net
   - `additional`: `Lux` network to pass the output of DeepONet, to include additional
     operations for embeddings, defaults to `nothing`
 
@@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
+julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
+
+julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
+
+julia> deeponet = DeepONet(branch_net, trunk_net);
 
 julia> ps, st = Lux.setup(Xoshiro(), deeponet);
 
@@ -35,37 +39,27 @@ julia> size(first(deeponet((u, y), ps, st)))
 (10, 5)
 ```
 """
-function DeepONet(;
-    branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
-    trunk_activation=identity, additional=nothing)
-
-    # checks for last dimension size
-    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
-                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
-                                       work."
-
-    branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
-                        for i in 1:(length(branch) - 1)]...)
-
-    trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
-                        for i in 1:(length(trunk) - 1)]...)
-
-    return DeepONet(branch_net, trunk_net; additional)
+@concrete struct DeepONet <: AbstractExplicitContainerLayer{(:branch, :trunk, :additional)}
+    branch
+    trunk
+    additional
 end
 
-"""
-    DeepONet(branch, trunk)
+DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer())
 
-Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the
-nets output should have the same first dimension.
-
-## Arguments
+"""
+    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
+        branch_activation = identity, trunk_activation = identity)
 
-  - `branch`: `Lux` network to be used as branch net.
-  - `trunk`: `Lux` network to be used as trunk net.
+Constructs a DeepONet composed of Dense layers. Make sure the last nodes of `branch` and
+`trunk` are the same.
 
-## Keyword Arguments
+## Keyword arguments:
 
+  - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
+  - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
+  - `branch_activation`: activation function for branch net
+  - `trunk_activation`: activation function for trunk net
   - `additional`: `Lux` network to pass the output of DeepONet, to include additional
     operations for embeddings, defaults to `nothing`
 
@@ -78,11 +72,7 @@ operators", doi: https://arxiv.org/abs/1910.03193
 ## Example
 
 ```jldoctest
-julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
-
-julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
-
-julia> deeponet = DeepONet(branch_net, trunk_net);
+julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16));
 
 julia> ps, st = Lux.setup(Xoshiro(), deeponet);
 
@@ -94,15 +84,32 @@ julia> size(first(deeponet((u, y), ps, st)))
 (10, 5)
 ```
 """
-function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2}
-    return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y)
-        t = trunk(y) # p x N x nb
-        b = branch(u) # p x u_size... x nb
+function DeepONet(;
+    branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity,
+    trunk_activation=identity, additional=NoOpLayer())
+
+    # checks for last dimension size
+    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
+                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
+                                       work."
+
+    branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
+                        for i in 1:(length(branch) - 1)]...)
+
+    trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
+                       for i in 1:(length(trunk) - 1)]...)
+
+    return DeepONet(branch_net, trunk_net, additional)
+end
+
+function (deeponet::DeepONet)(x, ps, st::NamedTuple)
+    b, st_b = deeponet.branch(x[1], ps.branch, st.branch)
+    t, st_t = deeponet.trunk(x[2], ps.trunk, st.trunk)
 
-        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
-                                          amount of nodes in the last layer. Otherwise \
-                                          Σᵢ bᵢⱼ tᵢₖ won't work."
+    @argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \
+                                      nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
+                                      work."
 
-        @return __project(b, t, additional)
-    end
+    out, st_a = __project(b, t, deeponet.additional, (; ps=ps.additional, st=st.additional))
+    return out, (branch=st_b, trunk=st_t, additional=st_a)
 end
diff --git a/src/fno.jl b/src/fno.jl
index 26f67d0..e5f7cfd 100644
--- a/src/fno.jl
+++ b/src/fno.jl
@@ -27,7 +27,7 @@ kernels, and two `Dense` layers to project data back to the scalar field of inte
 ## Example
 
 ```jldoctest
-julia> fno = FourierNeuralOperator(gelu; chs=(2, 64, 64, 128, 1), modes=(16,));
+julia> fno = FourierNeuralOperator(; σ=gelu, chs=(2, 64, 64, 128, 1), modes=(16,));
 
 julia> ps, st = Lux.setup(Xoshiro(), fno);
 
@@ -37,8 +37,15 @@ julia> size(first(fno(u, ps, st)))
 (1, 1024, 5)
 ```
 """
-function FourierNeuralOperator(
-    σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
+@concrete struct FourierNeuralOperator <:
+                 AbstractExplicitContainerLayer{(:lifting, :mapping, :project)}
+    lifting
+    mapping
+    project
+end
+
+function FourierNeuralOperator(;
+    σ=gelu, chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,),
     permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm}
 
     @argcheck length(chs) ≥ 5
@@ -52,9 +59,15 @@ function FourierNeuralOperator(
     project = perm ? Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) :
               Chain(Dense(map₂, σ), Dense(map₃))
 
-    return Chain(; lifting,
-        mapping=Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
-                       for i in 2:(C - 3)]...),
-        project,
-        name="FourierNeuralOperator")
+    mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...)
+                     for i in 2:(C - 3)]...)
+
+    return FourierNeuralOperator(lifting, mapping, project)
+end
+
+function (fno::FourierNeuralOperator)(x::AbstractArray, ps, st::NamedTuple)
+    lift, st_lift = fno.lifting(x, ps.lifting, st.lifting)
+    mapping, st_mapping = fno.mapping(lift, ps.mapping, st.mapping)
+    project, st_project = fno.project(mapping, ps.project, st.project)
+    return project, (lifting=st_lift, mapping=st_mapping, project=st_project)
 end
diff --git a/src/layers.jl b/src/layers.jl
index 0f23635..227e448 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -116,18 +116,30 @@ julia> OperatorKernel(2 => 5, (16,), FourierTransform{ComplexF64}; permuted=Val(
 
 ```
 """
+@concrete struct OperatorKernel <: AbstractExplicitContainerLayer{(:lin, :conv)}
+    lin
+    conv
+    activation <: Function
+end
+
+OperatorKernel(lin, conv) = OperatorKernel(lin, conv, identity)
+
 function OperatorKernel(ch::Pair{<:Integer, <:Integer}, modes::Dims{N}, transform::Type{TR},
         act::A=identity; allow_fast_activation::Bool=false, permuted::Val{perm}=Val(false),
         kwargs...) where {N, TR <: AbstractTransform{<:Number}, perm, A}
     act = allow_fast_activation ? NNlib.fast_act(act) : act
-    l₁ = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
-    l₂ = OperatorConv(ch, modes, transform; permuted, kwargs...)
-
-    return @compact(; l₁, l₂, activation=act, dispatch=:OperatorKernel) do x::AbstractArray
-        l₁x = l₁(x)
-        l₂x = l₂(x)
-        @return @. activation(l₁x + l₂x)
-    end
+    lin = perm ? Conv(map(_ -> 1, modes), ch) : Dense(ch)
+    conv = OperatorConv(ch, modes, transform; permuted, kwargs...)
+
+    return OperatorKernel(lin, conv, act)
+end
+
+function (op::OperatorKernel)(x::AbstractArray, ps, st::NamedTuple)
+    x_conv, st_conv = op.conv(x, ps.conv, st.conv)
+    x_lin, st_lin = op.lin(x, ps.lin, st.lin)
+
+    out = fast_activation!!(op.activation, x_conv .+ x_lin)
+    return out, (lin=st_lin, conv=st_conv)
 end
 
 """
diff --git a/src/utils.jl b/src/utils.jl
index c9f7605..1d4c278 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,24 +1,24 @@
-@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
-        additional::Nothing) where {T1, T2}
+@inline function __project(
+        b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
     # b : p x nb
     # t : p x N x nb
     b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
-    return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb
+    return dropdims(sum(b_ .* t; dims=1); dims=1), () # N x nb
 end
 
-@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
-        additional::Nothing) where {T1, T2}
+@inline function __project(
+        b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2}
     # b : p x u x nb
     # t : p x N x nb
     if size(b, 2) == 1 || size(t, 2) == 1
-        return sum(b .* t; dims=1) # 1 x N x nb
+        return sum(b .* t; dims=1), () # 1 x N x nb
     else
-        return batched_matmul(batched_adjoint(b), t) # u x N x b
+        return batched_matmul(batched_adjoint(b), t), () # u x N x b
     end
 end
 
-@inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
-        additional::Nothing) where {T1, T2, N}
+@inline function __project(
+        b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, ::NoOpLayer, _) where {T1, T2, N}
     # b : p x u_size x nb
     # t : p x N x nb
     u_size = size(b)[2:(end - 1)]
@@ -29,34 +29,34 @@ end
     t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))...,
         size(t)[2:end]...) # p x (1,1,1...) x N x nb
 
-    return dropdims(sum(b_ .* t_; dims=1); dims=1) # u_size x N x nb
+    return dropdims(sum(b_ .* t_; dims=1); dims=1), () # u_size x N x nb
 end
 
-@inline function __project(
-    b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
+@inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3},
+        additional::T, params) where {T1, T2, T}
     # b : p x nb
     # t : p x N x nb
    b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
 
-    return additional(b_ .* t) # p x N x nb => out_dims x N x nb
+    return additional(b_ .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
 end
 
-@inline function __project(
-    b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T}
+@inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3},
+        additional::T, params) where {T1, T2, T}
    # b : p x u x nb
    # t : p x N x nb
    if size(b, 2) == 1 || size(t, 2) == 1
-        return additional(b .* t) # p x N x nb => out_dims x N x nb
+        return additional(b .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb
    else
        b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb
        t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...)
        # p x 1 x N x nb
-        return additional(b_ .* t_) # p x u x N x nb => out_size x N x nb
+        return additional(b_ .* t_, params.ps, params.st) # p x u x N x nb => out_size x N x nb
    end
 end
 
 @inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3},
-        additional::T) where {T1, T2, N, T}
+        additional::T, params) where {T1, T2, N, T}
     # b : p x u_size x nb
     # t : p x N x nb
     u_size = size(b)[2:(end - 1)]
@@ -67,5 +67,5 @@ end
     t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))...,
         size(t)[2:end]...) # p x (1,1,1...) x N x nb
 
-    return additional(b_ .* t_) # p x u_size x N x nb => out_size x N x nb
+    return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb
 end
diff --git a/test/Project.toml b/test/Project.toml
index 4aa9c6a..62e12ab 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -25,9 +25,9 @@ Documenter = "1.5.0"
 ExplicitImports = "1.9.0"
 Hwloc = "3.2.0"
 InteractiveUtils = "<0.0.1, 1"
-Lux = "0.5.62"
-LuxCore = "0.1.21"
-LuxLib = "0.3.40"
+Lux = "0.5.64"
+LuxCore = "0.1.24"
+LuxLib = "0.3.42"
 LuxTestUtils = "1.1.2"
 MLDataDevices = "1.0.0"
 Optimisers = "0.3.3"
diff --git a/test/fno_tests.jl b/test/fno_tests.jl
index 04f72c7..7a1a523 100644
--- a/test/fno_tests.jl
+++ b/test/fno_tests.jl
@@ -22,11 +22,10 @@
         @test size(first(fno(x, ps, st))) == setup.y_size
 
         data = [(x, y)]
-        broken = mode == "AMDGPU"
         @test begin
             l2, l1 = train!(fno, ps, st, data; epochs=10)
             l2 < l1
-        end broken=broken
+        end
 
         __f = (x, ps) -> sum(abs2, first(fno(x, ps, st)))
         test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
diff --git a/test/layers_tests.jl b/test/layers_tests.jl
index b4a1016..8ac48ff 100644
--- a/test/layers_tests.jl
+++ b/test/layers_tests.jl
@@ -14,7 +14,7 @@
     @testset "$(op) $(length(setup.m))D: permuted = $(setup.permuted)" for setup in setups,
         op in opconv
 
-        p = Lux.__unwrap_val(setup.permuted)
+        p = Lux.Utils.unwrap_val(setup.permuted)
         in_chs = ifelse(p, setup.x_size[end - 1], first(setup.x_size))
         out_chs = ifelse(p, setup.y_size[end - 1], first(setup.y_size))
         ch = 64 => out_chs
@@ -31,11 +31,10 @@
         @jet m(x, ps, st)
 
         data = [(x, aType(rand(rng, Float32, setup.y_size...)))]
-        broken = mode == "AMDGPU"
         @test begin
             l2, l1 = train!(m, ps, st, data; epochs=10)
             l2 < l1
-        end broken=broken
+        end
 
         __f = (x, ps) -> sum(abs2, first(m(x, ps, st)))
         test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index f79e1ec..e9385d5 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -1,8 +1,8 @@
 @testitem "doctests: Quality Assurance" tags=[:qa] begin
     using Documenter, NeuralOperators
 
-    DocMeta.setdocmeta!(
-        NeuralOperators, :DocTestSetup, :(using NeuralOperators, Random); recursive=true)
+    DocMeta.setdocmeta!(NeuralOperators, :DocTestSetup,
+        :(using Lux, NeuralOperators, Random); recursive=true)
 
     doctest(NeuralOperators; manual=false)
 end
@@ -14,7 +14,7 @@ end
 end
 
 @testitem "Explicit Imports: Quality Assurance" tags=[:qa] begin
-    using ExplicitImports
+    using ExplicitImports, Lux
 
     # Skip our own packages
     @test check_no_implicit_imports(NeuralOperators; skip=(Base, Core, Lux)) === nothing
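
For reference, a minimal usage sketch of the reworked API after this patch (not part of the diff itself). It assumes only what the jldoctests above show: the two-argument `DeepONet(branch, trunk)` constructor, `Lux.setup`, and the doctest input shapes; variable names such as `st_new` are illustrative. Since `Reexport` is dropped, `Lux` must now be loaded explicitly alongside `NeuralOperators`.

```julia
# Sketch of the explicit container-layer workflow introduced by this patch.
using Lux, NeuralOperators, Random   # Lux is no longer re-exported by NeuralOperators

branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16))
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))
deeponet = DeepONet(branch_net, trunk_net)   # `additional` defaults to NoOpLayer()

ps, st = Lux.setup(Xoshiro(), deeponet)

u = rand(Float32, 64, 5)     # branch input: 64 sensor values x 5 samples
y = rand(Float32, 1, 10, 5)  # trunk input: 1 coordinate x 10 query points x 5 samples

# The call overload threads ps/st through the sub-layers and returns
# (output, (branch=..., trunk=..., additional=...)).
out, st_new = deeponet((u, y), ps, st)
size(out)  # (10, 5), matching the jldoctest
```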