Merge pull request #8 from LuxDL/ap/fixes
Inefficient matmul version of DeepONet
avik-pal authored Jun 26, 2024
2 parents 67d9007 + 9e4a069 commit 44c39bb
Showing 7 changed files with 98 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
@@ -52,4 +52,4 @@ steps:
env:
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
SECRET_CODECOV_TOKEN: "Tg/DGJmBhzxJQBcGajfE2McAOuNVa6zpMZGw0rYTTTGpE7dsBg8cDuj5D9tmLZYdNXJxlkSrjQkjQiPelqECIlMieRveDJ/S3bnA1meJk5p8/PIzwJzQiMCrXpX+xbhcHPn9aQoMmloqP/u6eJ7ToYineDiGbtvQnofVvH0cTgEj/xD15Dflo3K9m/w5/vdvaRbSrxIMc1Z7md/m2XSJJHyLD2Zkir2YWk2cZpyq/S7mA0zL2Yeur27tkzsjSPN/Y+vS5+LLdr5yxo9OVTCAJAZDVsBJGf1Ynd8y4T7usfK+fa41Se48ZpKA/VZtSSZQKdTHM0JcVpqe+Z5L9zbGGg==;U2FsdGVkX18VAT6PhLvJvEVkHs4vFg/vBLTECZAdWznsrPEISjpgl00GTYqrxMw30trS4RDWRSdY1TRYAC85QQ=="
SECRET_CODECOV_TOKEN: "vn5M+4wSwUFje6fl6UB/Q/rTmLHu3OlCCMgoPOXPQHYpLZTLz2hOHsV44MadAnxw8MsNVxLKZlXBKqP3IydU9gUfV7QUBtnvbUmIvgUHbr+r0bVaIVVhw6cnd0s8/b+561nU483eRJd35bjYDOlO+V5eDxkbdh/0bzLefXNXy5+ALxsBYzsp75Sx/9nuREfRqWwU6S45mne2ZlwCDpZlFvBDXQ2ICKYXpA45MpxhW9RuqfpQdi6sSR6I/HdHkV2cuJO99dqqh8xfUy6vWPC/+HUVrn9ETsrXtayX1MX3McKj869htGICpR8vqd311HTONYVprH2AN1bJqr5MOIZ8Xg==;U2FsdGVkX1+W55pTI7zq+NwYrbK6Cgqe+Gp8wMCmXY+W10aXTB0bS6zshiDYSQ1Y3piT91xFyNhS+9AsajY0yQ=="
8 changes: 2 additions & 6 deletions Project.toml
@@ -31,8 +31,7 @@ ConcreteStructs = "0.2.3"
Documenter = "1.4.1"
ExplicitImports = "1.6.0"
FFTW = "1.8.0"
Lux = "0.5.53"
LuxAMDGPU = "0.2.3"
Lux = "0.5.56"
LuxCUDA = "0.3.2"
LuxCore = "0.1.15"
LuxTestUtils = "0.1.15"
@@ -43,7 +42,6 @@ Random = "1.10"
ReTestItems = "1.24.0"
Reexport = "1.2.2"
StableRNGs = "1.0.2"
Statistics = "1.10"
Test = "1.10"
WeightInitializers = "0.1.7"
Zygote = "0.6.70"
@@ -53,15 +51,13 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Documenter", "ExplicitImports", "LuxAMDGPU", "LuxCUDA", "LuxTestUtils", "Optimisers", "ReTestItems", "StableRNGs", "Statistics", "Test", "Zygote"]
test = ["Aqua", "Documenter", "ExplicitImports", "AMDGPU", "LuxCUDA", "LuxTestUtils", "Optimisers", "ReTestItems", "StableRNGs", "Test", "Zygote"]
94 changes: 46 additions & 48 deletions src/deeponet.jl
@@ -21,27 +21,31 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example
```jldoctest
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
# output
Branch net :
(
Chain(
julia> deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
@compact(
branch = Chain(
layer_1 = Dense(64 => 32), # 2_080 parameters
layer_2 = Dense(32 => 32), # 1_056 parameters
layer_3 = Dense(32 => 16), # 528 parameters
),
)
Trunk net :
(
Chain(
trunk = Chain(
layer_1 = Dense(1 => 8), # 16 parameters
layer_2 = Dense(8 => 8), # 72 parameters
layer_3 = Dense(8 => 16), # 144 parameters
),
)
) do (u, y)
t = trunk(y)
b = branch(u)
@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
@argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
b_ = if ndims(t) == ndims(b)
b
else
reshape(b, size(b, 1), 1, (size(b))[2:end]...)
end
return dropdims(sum(t .* b_; dims = 1); dims = 1)
end # Total: 3_896 parameters,
# plus 0 states.
```
"""
function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
@@ -81,54 +85,48 @@ operators", doi: https://arxiv.org/abs/1910.03193
## Example
```jldoctest
branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
don_ = DeepONet(branch_net, trunk_net)
julia> branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
# output
julia> trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
Branch net :
(
Chain(
julia> deeponet = DeepONet(branch_net, trunk_net)
@compact(
branch = Chain(
layer_1 = Dense(64 => 32), # 2_080 parameters
layer_2 = Dense(32 => 32), # 1_056 parameters
layer_3 = Dense(32 => 16), # 528 parameters
),
)
Trunk net :
(
Chain(
trunk = Chain(
layer_1 = Dense(1 => 8), # 16 parameters
layer_2 = Dense(8 => 8), # 72 parameters
layer_3 = Dense(8 => 16), # 144 parameters
),
)
) do (u, y)
t = trunk(y)
b = branch(u)
@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
@argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."
b_ = if ndims(t) == ndims(b)
b
else
reshape(b, size(b, 1), 1, (size(b))[2:end]...)
end
return dropdims(sum(t .* b_; dims = 1); dims = 1)
end # Total: 3_896 parameters,
# plus 0 states.
```
"""
function DeepONet(branch::L1, trunk::L2) where {L1, L2}
return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M}
t = trunk(y) # p x N x nb
b = branch(u) # p x nb

# checks for last dimension size
@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same amount \
of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \
won't work."

tᵀ = permutedims(t, (2, 1, 3)) # N x p x nb
b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2)) # p x 1 x nb
G = batched_mul(tᵀ, b_) # N x 1 X nb
@return dropdims(G; dims=2)
end
end
return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y)
t = trunk(y) # p x N x nb...
b = branch(u) # p x nb...

function Base.show(io::IO, model::Lux.CompactLuxLayer{:DeepONet})
Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch)
print(io, "\n \n")
Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk)
end
@argcheck ndims(t) == ndims(b) + 1 || ndims(t) == ndims(b)
@argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same \
amount of nodes in the last layer. Otherwise \
Σᵢ bᵢⱼ tᵢₖ won't work."

function Base.show(io::IO, ::MIME"text/plain", x::CompactLuxLayer{:DeepONet})
show(io, x)
b_ = ndims(t) == ndims(b) ? b : reshape(b, size(b, 1), 1, size(b)[2:end]...)
@return dropdims(sum(t .* b_; dims=1); dims=1)
end
end
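
For reference (not part of the diff): a minimal standalone sketch of the reduction the new `do (u, y)` block computes, Gⱼₖ = Σᵢ bᵢₖ tᵢⱼₖ, checked against the `batched_mul` formulation removed above for the 3-D trunk / 2-D branch case. The sizes `p`, `N`, `nb` are illustrative, and `batched_mul` is assumed to come from NNlib.

```julia
using NNlib: batched_mul  # assumed source of batched_mul, only for the comparison

p, N, nb = 16, 10, 5
t = rand(Float32, p, N, nb)   # trunk output:  p x N x nb
b = rand(Float32, p, nb)      # branch output: p x nb

# New broadcast-and-sum reduction: Gⱼₖ = Σᵢ bᵢₖ tᵢⱼₖ
b_ = reshape(b, size(b, 1), 1, size(b)[2:end]...)           # p x 1 x nb
G_sum = dropdims(sum(t .* b_; dims=1); dims=1)              # N x nb

# batched_mul formulation removed by this commit
tᵀ = permutedims(t, (2, 1, 3))                              # N x p x nb
b_mat = permutedims(reshape(b, size(b)..., 1), (1, 3, 2))   # p x 1 x nb
G_mul = dropdims(batched_mul(tᵀ, b_mat); dims=2)            # N x nb

G_sum ≈ G_mul  # true (up to floating-point error)
```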
45 changes: 29 additions & 16 deletions test/deeponet_tests.jl
@@ -1,35 +1,48 @@
@testitem "DeepONet" setup=[SharedTestSetup] begin
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng_ = get_stable_rng()
rng = StableRNG(12345)

u = rand(64, 5) |> aType # sensor_points x nb
y = rand(1, 10, 5) |> aType # ndims x N x nb
u = rand(Float32, 64, 5) |> aType # sensor_points x nb
y = rand(Float32, 1, 10, 5) |> aType # ndims x N x nb
out_size = (10, 5)

don_ = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))

ps, st = Lux.setup(rng_, don_) |> dev
ps, st = Lux.setup(rng, deeponet) |> dev

@inferred don_((u, y), ps, st)
@jet don_((u, y), ps, st)
@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)

pred = first(don_((u, y), ps, st))
pred = first(deeponet((u, y), ps, st))
@test size(pred) == out_size

don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))

ps, st = Lux.setup(rng_, don_) |> dev
ps, st = Lux.setup(rng, deeponet) |> dev

@inferred don_((u, y), ps, st)
@jet don_((u, y), ps, st)
@inferred deeponet((u, y), ps, st)
@jet deeponet((u, y), ps, st)

pred = first(don_((u, y), ps, st))
pred = first(deeponet((u, y), ps, st))
@test size(pred) == out_size

don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
ps, st = Lux.setup(rng_, don_) |> dev
@test_throws ArgumentError don_((u, y), ps, st)
ps, st = Lux.setup(rng, deeponet) |> dev
@test_throws ArgumentError deeponet((u, y), ps, st)

@testset "higher-dim input #7" begin
u = ones(Float32, 10, 10, 10) |> aType
v = ones(Float32, 1, 10, 10) |> aType
deeponet = DeepONet(; branch=(10, 10, 10), trunk=(1, 10, 10))
ps, st = Lux.setup(rng, deeponet) |> dev

y, st_ = deeponet((u, v), ps, st)
@test size(y) == (10, 10)

@inferred deeponet((u, v), ps, st)
@jet deeponet((u, v), ps, st)
end
end
end
2 changes: 1 addition & 1 deletion test/fno_tests.jl
@@ -1,6 +1,6 @@
@testitem "Fourier Neural Operator" setup=[SharedTestSetup] begin
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = get_stable_rng()
rng = StableRNG(12345)

setups = [
(modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1),
2 changes: 1 addition & 1 deletion test/layers_tests.jl
@@ -1,6 +1,6 @@
@testitem "SpectralConv & SpectralKernel" setup=[SharedTestSetup] begin
@testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
rng = get_stable_rng()
rng = StableRNG(12345)

opconv = [SpectralConv, SpectralKernel]
setups = [
43 changes: 18 additions & 25 deletions test/shared_testsetup.jl
@@ -1,30 +1,28 @@
@testsetup module SharedTestSetup
import Reexport: @reexport

@reexport using Lux, LuxCUDA, LuxAMDGPU, Zygote, Optimisers, Random, StableRNGs, Statistics
@reexport using Lux, LuxCUDA, AMDGPU, Zygote, Optimisers, Random, StableRNGs
using LuxTestUtils: @jet, @test_gradients

CUDA.allowscalar(false)

const BACKEND_GROUP = get(ENV, "BACKEND_GROUP", "All")

cpu_testing() = BACKEND_GROUP == "All" || BACKEND_GROUP == "CPU"
cuda_testing() = (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") && LuxCUDA.functional()
function cuda_testing()
return (BACKEND_GROUP == "All" || BACKEND_GROUP == "CUDA") &&
LuxDeviceUtils.functional(LuxCUDADevice)
end
function amdgpu_testing()
return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") && LuxAMDGPU.functional()
return (BACKEND_GROUP == "All" || BACKEND_GROUP == "AMDGPU") &&
LuxDeviceUtils.functional(LuxAMDGPUDevice)
end

const MODES = begin
# Mode, Array Type, Device Function, GPU?
cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)
amdgpu_mode = ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true)

modes = []
cpu_testing() && push!(modes, cpu_mode)
cuda_testing() && push!(modes, cuda_mode)
amdgpu_testing() && push!(modes, amdgpu_mode)

cpu_testing() && push!(modes, ("CPU", Array, LuxCPUDevice(), false))
cuda_testing() && push!(modes, ("CUDA", CuArray, LuxCUDADevice(), true))
amdgpu_testing() && push!(modes, ("AMDGPU", ROCArray, LuxAMDGPUDevice(), true))
modes
end

@@ -36,28 +34,23 @@ function get_default_rng(mode::String)
return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
end

get_stable_rng(seed=12345) = StableRNG(seed)
train!(args...; kwargs...) = train!(MSELoss(), AutoZygote(), args...; kwargs...)

default_loss_function(model, ps, x, y) = mean(abs2, y .- model(x, ps))
function train!(loss, backend, model, ps, st, data; epochs=10)
l1 = loss(model, ps, st, first(data))

train!(args...; kwargs...) = train!(default_loss_function, args...; kwargs...)

function train!(loss, model, ps, st, data; epochs=10)
m = StatefulLuxLayer{true}(model, ps, st)

l1 = loss(m, ps, first(data)...)
st_opt = Optimisers.setup(Adam(0.01f0), ps)
tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.01f0))
for _ in 1:epochs, (x, y) in data
_, gs, _, _ = Zygote.gradient(loss, m, ps, x, y)
Optimisers.update!(st_opt, ps, gs)
_, _, _, tstate = Lux.Experimental.single_train_step!(backend, loss, (x, y), tstate)
end
l2 = loss(m, ps, first(data)...)

l2 = loss(model, ps, st, first(data))

return l2, l1
end

export @jet, @test_gradients, check_approx
export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
get_stable_rng, train!
train!

end
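
For reference (not part of the diff): a hedged usage sketch of the refactored `train!` helper above. The `deeponet`, `u`, `y`, and `target` names and shapes are illustrative assumptions; `train!(model, ps, st, data)` forwards to `train!(MSELoss(), AutoZygote(), model, ps, st, data)` as defined in the test setup.

```julia
using Lux, Optimisers, Random, Zygote
using NeuralOperators  # assumed package name providing DeepONet

deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))
ps, st = Lux.setup(Random.default_rng(), deeponet)

u = rand(Float32, 64, 5)       # sensor_points x nb
y = rand(Float32, 1, 10, 5)    # ndims x N x nb
target = rand(Float32, 10, 5)  # N x nb

data = [((u, y), target)]

# Per the helper's return order, this is (loss after training, loss before training)
l_after, l_before = train!(deeponet, ps, st, data; epochs=10)
```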
