diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index 605332b..7979b22 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -7,8 +7,6 @@ steps: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true - commands: | - printf "[LuxTestUtils]\ntarget_modules = [\"NeuralOperators\", \"Lux\", \"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n[LuxCore]\ninstability_check = \"error\"\n" > LocalPreferences.toml agents: queue: "juliagpu" cuda: "*" @@ -29,8 +27,6 @@ steps: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true - commands: | - printf "[LuxTestUtils]\ntarget_modules = [\"NeuralOperators\", \"Lux\", \"LuxLib\"]\n[LuxLib]\ninstability_check = \"error\"\n[LuxCore]\ninstability_check = \"error\"\n" > LocalPreferences.toml env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a68b3d7..037da9c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,16 +35,6 @@ jobs: - windows-latest steps: - uses: actions/checkout@v4 - - uses: DamianReeves/write-file-action@master - with: - path: "LocalPreferences.toml" - contents: | - [LuxTestUtils] - target_modules = ["NeuralOperators", "Lux", "LuxLib"] - [LuxLib] - instability_check = "error" - [LuxCore] - instability_check = "error" - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} @@ -78,16 +68,6 @@ jobs: version: ["1"] steps: - uses: actions/checkout@v4 - - uses: DamianReeves/write-file-action@master - with: - path: "LocalPreferences.toml" - contents: | - [LuxTestUtils] - target_modules = ["NeuralOperators", "Lux", "LuxLib"] - [LuxLib] - instability_check = "error" - [LuxCore] - instability_check = "error" - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} diff --git a/.github/workflows/QualityCheck.yml b/.github/workflows/QualityCheck.yml index 0dac8cb..1f204df 100644 --- a/.github/workflows/QualityCheck.yml +++ b/.github/workflows/QualityCheck.yml @@ -16,4 +16,4 @@ jobs: - name: Checkout Actions Repository uses: actions/checkout@v4 - name: Check spelling - uses: crate-ci/typos@v1.23.2 + uses: crate-ci/typos@v1.23.5 diff --git a/Project.toml b/Project.toml index 013a3ea..9fc4c05 100644 --- a/Project.toml +++ b/Project.toml @@ -10,47 +10,22 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [compat] -Aqua = "0.8.7" ArgCheck = "2.3.0" ChainRulesCore = "1.24.0" ConcreteStructs = "0.2.3" -Documenter = "1.4.1" -ExplicitImports = "1.9.0" FFTW = "1.8.0" -Lux = "0.5.56" -LuxCore = "0.1.15" -LuxDeviceUtils = "0.1.24" -LuxTestUtils = "0.1.15" -NNlib = "0.9.17" -Optimisers = "0.3.3" -Pkg = "1.10" +Lux = "0.5.62" +LuxCore = "0.1.21" +LuxLib = "0.3.40" +NNlib = "0.9.21" Random = "1.10" -ReTestItems = "1.24.0" Reexport = "1.2.2" -StableRNGs = "1.0.2" -Test = "1.10" -WeightInitializers = "0.1.7, 1" -Zygote = "0.6.70" +WeightInitializers = "1" julia = "1.10" - -[extras] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -ExplicitImports = 
"7d51a73a-1435-4ff3-83d9-f097790105c7" -LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" -Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" -StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[targets] -test = ["Aqua", "Documenter", "ExplicitImports", "LuxTestUtils", "Optimisers", "Pkg", "ReTestItems", "StableRNGs", "Test", "Zygote"] diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 6afd764..89ad5d0 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -6,8 +6,8 @@ using ConcreteStructs: @concrete using FFTW: FFTW, irfft, rfft using Lux using LuxCore: LuxCore, AbstractExplicitLayer -using LuxDeviceUtils: get_device, LuxAMDGPUDevice -using NNlib: NNlib, ⊠, batched_adjoint +using LuxLib: batched_matmul +using NNlib: NNlib, batched_adjoint using Random: Random, AbstractRNG using Reexport: @reexport diff --git a/src/deeponet.jl b/src/deeponet.jl index 05bf97f..948ce33 100644 --- a/src/deeponet.jl +++ b/src/deeponet.jl @@ -1,16 +1,16 @@ """ - DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), - branch_activation = identity, trunk_activation = identity) + DeepONet(branch, trunk; additional) -Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and -`trunk` are same. +Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the +nets output should have the same first dimension. -## Keyword arguments: +## Arguments + + - `branch`: `Lux` network to be used as branch net. + - `trunk`: `Lux` network to be used as trunk net. + +## Keyword Arguments - - `branch`: Tuple of integers containing the number of nodes in each layer for branch net - - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net - - `branch_activation`: activation function for branch net - - `trunk_activation`: activation function for trunk net - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations for embeddings, defaults to `nothing` @@ -23,7 +23,11 @@ operators", doi: https://arxiv.org/abs/1910.03193 ## Example ```jldoctest -julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16)); +julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)); + +julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)); + +julia> deeponet = DeepONet(branch_net, trunk_net); julia> ps, st = Lux.setup(Xoshiro(), deeponet); @@ -35,37 +39,28 @@ julia> size(first(deeponet((u, y), ps, st))) (10, 5) ``` """ -function DeepONet(; - branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity, - trunk_activation=identity, additional=nothing) - - # checks for last dimension size - @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \ - nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ - work." - - branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation) - for i in 1:(length(branch) - 1)]...) - - trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation) - for i in 1:(length(trunk) - 1)]...) 
- - return DeepONet(branch_net, trunk_net; additional) +@concrete struct DeepONet{L1, L2, L3} <: + Lux.AbstractExplicitContainerLayer{(:branch, :trunk, :additional)} + branch::L1 + trunk::L2 + additional::L3 end -""" - DeepONet(branch, trunk) +DeepONet(branch, trunk) = DeepONet(branch, trunk, NoOpLayer()) -Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the -nets output should have the same first dimension. - -## Arguments +""" + DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), + branch_activation = identity, trunk_activation = identity) - - `branch`: `Lux` network to be used as branch net. - - `trunk`: `Lux` network to be used as trunk net. +Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and +`trunk` are same. -## Keyword Arguments +## Keyword arguments: + - `branch`: Tuple of integers containing the number of nodes in each layer for branch net + - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net + - `branch_activation`: activation function for branch net + - `trunk_activation`: activation function for trunk net - `additional`: `Lux` network to pass the output of DeepONet, to include additional operations for embeddings, defaults to `nothing` @@ -78,11 +73,7 @@ operators", doi: https://arxiv.org/abs/1910.03193 ## Example ```jldoctest -julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)); - -julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)); - -julia> deeponet = DeepONet(branch_net, trunk_net); +julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16)); julia> ps, st = Lux.setup(Xoshiro(), deeponet); @@ -94,15 +85,34 @@ julia> size(first(deeponet((u, y), ps, st))) (10, 5) ``` """ -function DeepONet(branch::L1, trunk::L2; additional=nothing) where {L1, L2} - return @compact(; branch, trunk, additional, dispatch=:DeepONet) do (u, y) - t = trunk(y) # p x N x nb - b = branch(u) # p x u_size... x nb +function DeepONet(; + branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), branch_activation=identity, + trunk_activation=identity, additional=NoOpLayer()) + + # checks for last dimension size + @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \ + nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ + work." + + branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation) + for i in 1:(length(branch) - 1)]...) + + trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation) + for i in 1:(length(trunk) - 1)]...) + + return DeepONet(branch_net, trunk_net, additional) +end + +function (deeponet::DeepONet)( + u::T1, y::T2, ps, st::NamedTuple) where {T1 <: AbstractArray, T2 <: AbstractArray} + b, st_b = deeponet.branch(u, ps.branch, st.branch) + t, st_t = deeponet.trunk(y, ps.trunk, st.trunk) - @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \ - amount of nodes in the last layer. Otherwise \ - Σᵢ bᵢⱼ tᵢₖ won't work." + @argcheck size(b, 1)==size(t, 1) "Branch and Trunk net must share the same amount of \ + nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ + work." 
- @return __project(b, t, additional) - end + out, st_a = __project( + b, t, deeponet.additional, (;ps=ps.additional, st=st.additional)) + return out, (branch=st_b, trunk=st_t, additional=st_a) end diff --git a/src/fno.jl b/src/fno.jl index 26f67d0..a25c152 100644 --- a/src/fno.jl +++ b/src/fno.jl @@ -37,8 +37,15 @@ julia> size(first(fno(u, ps, st))) (1, 1024, 5) ``` """ -function FourierNeuralOperator( - σ=gelu; chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), +@concrete struct FourierNeuralOperator{L1, L2, L3} <: + Lux.AbstractExplicitContainerLayer{(:lifting, :mapping, :project)} + lifting::L1 + mapping::L2 + project::L3 +end + +function FourierNeuralOperator(; + σ=gelu, chs::Dims{C}=(2, 64, 64, 64, 64, 64, 128, 1), modes::Dims{M}=(16,), permuted::Val{perm}=Val(false), kwargs...) where {C, M, perm} @argcheck length(chs) ≥ 5 @@ -52,9 +59,15 @@ function FourierNeuralOperator( project = perm ? Chain(Conv(kernel_size, map₂, σ), Conv(kernel_size, map₃)) : Chain(Dense(map₂, σ), Dense(map₃)) - return Chain(; lifting, - mapping=Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...) - for i in 2:(C - 3)]...), - project, - name="FourierNeuralOperator") + mapping = Chain([SpectralKernel(chs[i] => chs[i + 1], modes, σ; permuted, kwargs...) + for i in 2:(C - 3)]...) + + return FourierNeuralOperator(lifting, mapping, project) +end + +function (fno::FourierNeuralOperator)(x::T, ps, st::NamedTuple) where {T} + lift, st_lift = fno.lifting(x, ps.lifting, st.lifting) + mapping, st_mapping = fno.mapping(lift, ps.mapping, st.mapping) + project, st_project = fno.project(mapping, ps.project, st.project) + return project, (lifting=st_lift, mapping=st_mapping, project=st_project) end diff --git a/src/functional.jl b/src/functional.jl index 6cdf001..d4942ea 100644 --- a/src/functional.jl +++ b/src/functional.jl @@ -12,8 +12,8 @@ end x_size = size(x_tr) x_flat = reshape(x_tr, :, x_size[N - 1], x_size[N]) - x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m - x_weighted = permutedims(__batched_mul(weights, x_flat_t), (3, 1, 2)) # m x o x b + x_flat_t = permutedims(x_flat, (2, 3, 1)) # i x b x m + x_weighted = permutedims(batched_matmul(weights, x_flat_t), (3, 1, 2)) # m x o x b return reshape(x_weighted, x_size[1:(N - 2)]..., size(x_weighted)[2:3]...) 
end diff --git a/src/utils.jl b/src/utils.jl index 5f37e01..1f8409c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,36 +1,24 @@ -# Temporarily capture certain calls like AMDGPU for ComplexFloats -@inline __batched_mul(x, y) = __batched_mul(x, y, get_device((x, y))) -@inline function __batched_mul( - x::AbstractArray{<:Number, 3}, y::AbstractArray{<:Number, 3}, _) - return x ⊠ y -end -@inline function __batched_mul( - x::AbstractArray{<:Complex, 3}, y::AbstractArray{<:Complex, 3}, ::LuxAMDGPUDevice) - # FIXME: This is not good for performance but that is okay for now - return stack(*, eachslice(x; dims=3), eachslice(y; dims=3)) -end - @inline function __project(b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, - additional::Nothing) where {T1, T2} + additional::NoOpLayer, ::NamedTuple) where {T1, T2} # b : p x nb # t : p x N x nb b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb - return dropdims(sum(b_ .* t; dims=1); dims=1) # N x nb + return dropdims(sum(b_ .* t; dims=1); dims=1), () # N x nb end @inline function __project(b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, - additional::Nothing) where {T1, T2} + additional::NoOpLayer, ::NamedTuple) where {T1, T2} # b : p x u x nb # t : p x N x nb if size(b, 2) == 1 || size(t, 2) == 1 - return sum(b .* t; dims=1) # 1 x N x nb + return sum(b .* t; dims=1), () # 1 x N x nb else - return __batched_mul(batched_adjoint(b), t) # u x N x b + return __batched_mul(batched_adjoint(b), t), () # u x N x b end end @inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, - additional::Nothing) where {T1, T2, N} + additional::NoOpLayer, ::NamedTuple) where {T1, T2, N} # b : p x u_size x nb # t : p x N x nb u_size = size(b)[2:(end - 1)] @@ -41,34 +29,34 @@ end t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...) # p x (1,1,1...) x N x nb - return dropdims(sum(b_ .* t_; dims=1); dims=1) # u_size x N x nb + return dropdims(sum(b_ .* t_; dims=1); dims=1), () # u_size x N x nb end @inline function __project( - b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T} + b::AbstractArray{T1, 2}, t::AbstractArray{T2, 3}, additional::T, params::NamedTuple) where {T1, T2, T} # b : p x nb # t : p x N x nb b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb - return additional(b_ .* t) # p x N x nb => out_dims x N x nb + return additional(b_ .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb end @inline function __project( - b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T) where {T1, T2, T} + b::AbstractArray{T1, 3}, t::AbstractArray{T2, 3}, additional::T, params::NamedTuple) where {T1, T2, T} # b : p x u x nb # t : p x N x nb if size(b, 2) == 1 || size(t, 2) == 1 - return additional(b .* t) # p x N x nb => out_dims x N x nb + return additional(b .* t, params.ps, params.st) # p x N x nb => out_dims x N x nb else b_ = reshape(b, size(b)[1:2]..., 1, size(b, 3)) # p x u x 1 x nb t_ = reshape(t, size(t, 1), 1, size(t)[2:end]...) # p x 1 x N x nb - return additional(b_ .* t_) # p x u x N x nb => out_size x N x nb + return additional(b_ .* t_, params.ps, params.st) # p x u x N x nb => out_size x N x nb end end @inline function __project(b::AbstractArray{T1, N}, t::AbstractArray{T2, 3}, - additional::T) where {T1, T2, N, T} + additional::T, params::NamedTuple) where {T1, T2, N, T} # b : p x u_size x nb # t : p x N x nb u_size = size(b)[2:(end - 1)] @@ -79,5 +67,5 @@ end t_ = reshape(t, size(t, 1), ones(eltype(u_size), length(u_size))..., size(t)[2:end]...) 
# p x (1,1,1...) x N x nb - return additional(b_ .* t_) # p x u_size x N x nb => out_size x N x nb + return additional(b_ .* t_, params.ps, params.st) # p x u_size x N x nb => out_size x N x nb end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..4aa9c6a --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,41 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Aqua = "0.8.7" +Documenter = "1.5.0" +ExplicitImports = "1.9.0" +Hwloc = "3.2.0" +InteractiveUtils = "<0.0.1, 1" +Lux = "0.5.62" +LuxCore = "0.1.21" +LuxLib = "0.3.40" +LuxTestUtils = "1.1.2" +MLDataDevices = "1.0.0" +Optimisers = "0.3.3" +Pkg = "1.10" +Preferences = "1" +Random = "1.10" +ReTestItems = "1.24.0" +Reexport = "1.2.2" +StableRNGs = "1.0.2" +Test = "1.10" +Zygote = "0.6.70" diff --git a/test/deeponet_tests.jl b/test/deeponet_tests.jl index 80a3985..83944d2 100644 --- a/test/deeponet_tests.jl +++ b/test/deeponet_tests.jl @@ -18,10 +18,10 @@ deeponet = DeepONet(; branch=setup.branch, trunk=setup.trunk) ps, st = Lux.setup(rng, deeponet) |> dev - @inferred first(deeponet((u, y), ps, st)) - @jet first(deeponet((u, y), ps, st)) + @inferred first(deeponet(u, y, ps, st)) + @jet first(deeponet(u, y, ps, st)) - pred = first(deeponet((u, y), ps, st)) + pred = first(deeponet(u, y, ps, st)) @test setup.out_size == size(pred) end @@ -43,11 +43,15 @@ branch=setup.branch, trunk=setup.trunk, additional=setup.additional) ps, st = Lux.setup(rng, deeponet) |> dev - @inferred first(deeponet((u, y), ps, st)) - @jet first(deeponet((u, y), ps, st)) + @inferred first(deeponet(u, y, ps, st)) + @jet first(deeponet(u, y, ps, st)) - pred = first(deeponet((u, y), ps, st)) + pred = first(deeponet(u, y, ps, st)) @test setup.out_size == size(pred) + + __f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st))) + test_gradients( + __f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()]) end @testset "Embedding layer mismatch" begin @@ -58,7 +62,7 @@ Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))) ps, st = Lux.setup(rng, deeponet) |> dev - @test_throws ArgumentError deeponet((u, y), ps, st) + @test_throws ArgumentError deeponet(u, y, ps, st) end end end diff --git a/test/fno_tests.jl b/test/fno_tests.jl index b9b8dac..e6d4cd8 100644 --- a/test/fno_tests.jl +++ b/test/fno_tests.jl @@ -10,7 +10,6 @@ @testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted) - display(fno) ps, st = Lux.setup(rng, fno) |> dev x = rand(rng, Float32, setup.x_size...) 
|> aType @@ -27,6 +26,10 @@ l2, l1 = train!(fno, ps, st, data; epochs=10) l2 < l1 end broken=broken + + __f = (x, ps) -> sum(abs2, first(fno(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, + skip_backends=[AutoEnzyme(), AutoTracker(), AutoReverseDiff()]) end end end diff --git a/test/layers_tests.jl b/test/layers_tests.jl index 293182b..b4a1016 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -36,6 +36,10 @@ l2, l1 = train!(m, ps, st, data; epochs=10) l2 < l1 end broken=broken + + __f = (x, ps) -> sum(abs2, first(m(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, + skip_backends=[AutoEnzyme(), AutoTracker(), AutoReverseDiff()]) end end end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 0e80dbd..f79e1ec 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -9,7 +9,8 @@ end @testitem "Aqua: Quality Assurance" tags=[:qa] begin using Aqua - Aqua.test_all(NeuralOperators) + Aqua.test_all(NeuralOperators; ambiguities=false) + Aqua.test_ambiguities(NeuralOperators; recursive=false) end @testitem "Explicit Imports: Quality Assurance" tags=[:qa] begin diff --git a/test/runtests.jl b/test/runtests.jl index 765fe75..1987473 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,9 @@ -using ReTestItems, Pkg, ReTestItems, Test +using Preferences + +Preferences.set_preferences!("LuxLib", "instability_check" => "error") +Preferences.set_preferences!("LuxCore", "instability_check" => "error") + +using ReTestItems, Pkg, Test, InteractiveUtils, Hwloc, NeuralOperators const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) @@ -14,6 +19,13 @@ if !isempty(EXTRA_PKGS) Pkg.instantiate() end +const RETESTITEMS_NWORKERS = parse( + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) +const RETESTITEMS_NWORKER_THREADS = parse(Int, + get(ENV, "RETESTITEMS_NWORKER_THREADS", + string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) + @testset "NeuralOperators.jl Tests" begin - ReTestItems.runtests(@__DIR__) + ReTestItems.runtests(NeuralOperators; nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index b5d00bc..6ce77fe 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,8 +1,10 @@ @testsetup module SharedTestSetup import Reexport: @reexport -@reexport using Lux, Zygote, Optimisers, Random, StableRNGs -using LuxTestUtils: @jet, @test_gradients +@reexport using Lux, Zygote, Optimisers, Random, StableRNGs, LuxTestUtils +using MLDataDevices + +LuxTestUtils.jet_target_modules!(["NeuralOperators", "Lux", "LuxLib"]) const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) @@ -17,18 +19,18 @@ end cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" function cuda_testing() return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && - LuxDeviceUtils.functional(LuxCUDADevice) + MLDataDevices.functional(CUDADevice) end function amdgpu_testing() return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && - LuxDeviceUtils.functional(LuxAMDGPUDevice) + MLDataDevices.functional(AMDGPUDevice) end const MODES = begin modes = [] - cpu_testing() && push!(modes, ("CPU", Array, LuxCPUDevice(), false)) - cuda_testing() && push!(modes, ("CUDA", CuArray, LuxCUDADevice(), true)) - amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true)) + cpu_testing() && push!(modes, ("CPU", Array, CPUDevice(), false)) + cuda_testing() && 
push!(modes, ("CUDA", CuArray, CUDADevice(), true)) + amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, AMDGPUDevice(), true)) modes end @@ -37,9 +39,9 @@ train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...) function train!(loss, backend, model, ps, st, data; epochs=10) l1 = loss(model, ps, st, first(data)) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.01f0)) for _ in 1:epochs, (x, y) in data - _, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate) + _, _, _, tstate = Training.single_train_step!(backend, loss, (x, y), tstate) end l2 = loss(model, ps, st, first(data)) @@ -47,7 +49,7 @@ function train!(loss, backend, model, ps, st, data; epochs=10) return l2, l1 end -export @jet, @test_gradients, check_approx +export check_approx export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, train! end