diff --git a/LocalPreferences.toml b/LocalPreferences.toml
new file mode 100644
index 0000000..22c70b2
--- /dev/null
+++ b/LocalPreferences.toml
@@ -0,0 +1,2 @@
+[LuxTestUtils]
+target_modules = ["LuxNeuralOperators", "Lux", "LuxLib"]
diff --git a/examples/Burgers/main.jl b/examples/Burgers/main.jl
deleted file mode 100644
index b8a1884..0000000
--- a/examples/Burgers/main.jl
+++ /dev/null
@@ -1,51 +0,0 @@
-using CairoMakie
-
-# Load the common utilities via Revise.jl if available
-if @isdefined(includet)
-    includet("../common.jl")
-else
-    include("../common.jl")
-end
-
-function train_burgers(; seed=1234, dataset="Burgers_R10", model_type=:fno)
-    if model_type == :fno
-        model = FourierNeuralOperator(; chs=(2, 64, 64, 64, 64, 64, 128, 1), modes=(16,),
-            σ=gelu)
-    else
-        error("Unknown model type: $model_type")
-    end
-
-    trainloader, testloader = get_dataset(dataset; batchsize=512)
-
-    ps, st = Lux.setup(Xoshiro(seed), model)
-
-    opt = OptimiserChain(WeightDecay(1.0f-4), Adam(0.001f0))
-
-    model, ps, st = train!(model, ps, st, trainloader, testloader, opt; epochs=100)
-
-    return model, ps, st
-end
-
-model, ps, st = train_burgers()
-x_data, y_data = get_dataset("Burgers_R10"; no_dataloader=Val(true));
-st_ = Lux.testmode(st)
-pred = first(model(x_data, ps, st_))
-
-fig = with_theme(theme_latexfonts()) do
-    fig = Figure(; size=(800, 800))
-
-    for i in 1:2, j in 1:2
-        idx = (i - 1) * 2 + j
-        ax = Axis(fig[i, j]; xlabel=L"x", ylabel=L"u(x, t_{end})")
-
-        l1 = lines!(ax, x_data[:, 1, idx], y_data[:, 1, idx]; linewidth=3)
-        l2 = lines!(ax, x_data[:, 1, idx], pred[:, 1, idx]; linewidth=3, linsestyle=:dot,
-            color=:red)
-
-        if i == 1 && j == 1
-            axislegend(ax, [l1, l2], ["Ground Truth", "Prediction"])
-        end
-    end
-
-    return fig
-end
diff --git a/examples/Project.toml b/examples/Project.toml
deleted file mode 100644
index 96253ac..0000000
--- a/examples/Project.toml
+++ /dev/null
@@ -1,14 +0,0 @@
-[deps]
-BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
-CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
-DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
-Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
-Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
-LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
-LuxNeuralOperators = "c0ba2cc5-a80b-46ec-84b3-091eb317b01d"
-MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
-MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
-Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
-TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/examples/common.jl b/examples/common.jl
deleted file mode 100644
index 434571f..0000000
--- a/examples/common.jl
+++ /dev/null
@@ -1,117 +0,0 @@
-using Fetch, DataDeps, MAT, MLUtils
-using Lux, LuxNeuralOperators, LuxCUDA
-using Optimisers, Zygote
-using TimerOutputs, ProgressLogging, Random
-import BSON: @save, @load
-
-const gdev = gpu_device()
-const cdev = cpu_device()
-
-# Make DataLoader work with ProgressLogging
-Base.size(d::DataLoader) = (length(d),)
-
-function register_if_notpresent(regname::String, datadep)
-    name = datadep.name
-    name != regname && return
-    haskey(DataDeps.registry, name) && return
-    return register(datadep)
-end
-
-function register_dataset(dataset::String)
-    register_if_notpresent(dataset,
-        DataDep("Burgers_R10", """Burgers R10""",
-            "https://drive.google.com/file/d/16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe/view?usp=sharing";
-            fetch_method=gdownload, post_fetch_method=unpack))
-
-    register_if_notpresent(dataset,
-        DataDep("Burgers_V1000", """Burgers V1000""",
-            "https://drive.google.com/file/d/1G9IW_2shmfgprPYISYt_YS8xa87p4atu/view?usp=sharing";
-            fetch_method=gdownload, post_fetch_method=unpack))
-
-    register_if_notpresent(dataset,
-        DataDep("Burgers_V100", """Burgers V100""",
-            "https://drive.google.com/file/d/1nzT0-Tu-LS2SoMUCcmO1qyjQd6WC9OdJ/view?usp=sharing";
-            fetch_method=gdownload, post_fetch_method=unpack))
-
-    return
-end
-
-function get_dataset(dataset::String; return_eltype::Type{T}=Float32, batchsize::Int = 128,
-        ratio::AbstractFloat=0.9, no_dataloader::Val{DT} = Val(false)) where {T, DT}
-    register_dataset(dataset)
-    root = @datadep_str dataset
-
-    if dataset == "Burgers_R10"
-        n = 2048
-        Δsamples = 2^3
-        grid_size = div(2^13, Δsamples)
-
-        file = matopen(joinpath(root, "burgers_data_R10.mat"))
-        x_data = Matrix{T}(read(file, "a")[1:n, 1:Δsamples:end]')
-        y_data = Matrix{T}(read(file, "u")[1:n, 1:Δsamples:end]')
-        close(file)
-
-        x_loc_data = Array{T, 3}(undef, 2, grid_size, n)
-        x_loc_data[1, :, :] = reshape(repeat(LinRange(0, 1, grid_size), n), (grid_size, n))
-        x_loc_data[2, :, :] .= x_data
-
-        x, y = x_loc_data, reshape(y_data, 1, :, n)
-    else
-        error("Not Implemented Dataset: $(dataset)")
-    end
-
-    DT && return x, y
-
-    data_train, data_test = splitobs((x, y); at=ratio)
-
-    trainloader = DataLoader(data_train; batchsize, shuffle=true)
-    testloader = DataLoader(data_test; batchsize, shuffle=true)
-
-    return trainloader, testloader
-end
-
-@inline function l₂_loss(x, y)
-    feature_dims = 2:(ndims(y) - 1)
-
-    loss = sum(sqrt, sum(abs2, x .- y; dims = feature_dims))
-    y_norm = sum(sqrt, sum(abs2, y; dims = feature_dims))
-
-    return loss / y_norm
-end
-
-@inline l₂_loss(m, ps, x, y) = l₂_loss(m(x, ps), y)
-
-function train!(model, ps, st, trainloader, testloader, opt; epochs = 500)
-    ps = ps |> gdev
-    st = st |> gdev
-    st_opt = Optimisers.setup(opt, ps)
-
-    model2 = Lux.Experimental.StatefulLuxLayer(model, ps, st)
-
-    @progress "Epochs" for epoch in 1:epochs
-        @progress name="Training" for (i, (x, y)) in enumerate(trainloader)
-            x = x |> gdev
-            y = y |> gdev
-            l, gs = Zygote.withgradient(l₂_loss, model2, ps, x, y)
-            ∂ps = gs[2]
-            mod1(i, 10) == 1 && @info "Epoch: $epoch, Iter: $i, Loss: $l"
-            Optimisers.update!(st_opt, ps, ∂ps)
-        end
-
-        st_ = Lux.testmode(model2.st)
-        model_inf = Lux.Experimental.StatefulLuxLayer(model, ps, st_)
-
-        total_loss = 0.0
-        total_data = 0
-        @progress name="Inference" for (x, y) in testloader
-            x = x |> gdev
-            y = y |> gdev
-            total_loss += l₂_loss(model_inf, ps, x, y)
-            total_data += 1
-        end
-
-        @info "Epoch: $epoch, Loss: $(total_loss / total_data)"
-    end
-
-    return model, ps |> cdev, st |> cdev
-end
\ No newline at end of file
diff --git a/src/transform.jl b/src/transform.jl
index fbedda8..23d9cbb 100644
--- a/src/transform.jl
+++ b/src/transform.jl
@@ -28,7 +28,7 @@ end
 
 @inline truncate_modes(ft::FourierTransform, x_fft::AbstractArray) = low_pass(ft, x_fft)
 
-function inverse(ft::FourierTransform, x_fft::AbstractArray{T, N},
-        M::NTuple{N, Int64}) where {T, N}
+function inverse(
+        ft::FourierTransform, x_fft::AbstractArray{T, N}, M::NTuple{N, Int64}) where {T, N}
     return real(irfft(x_fft, first(M), 1:ndims(ft)))
 end
diff --git a/test/fno_tests.jl b/test/fno_tests.jl
index f1e8a2f..a122fa6 100644
--- a/test/fno_tests.jl
+++ b/test/fno_tests.jl
@@ -1,20 +1,23 @@
 @testitem "Fourier Neural Operator" setup=[SharedTestSetup] begin
     @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
-        rng = get_default_rng(mode)
+        rng = get_stable_rng()
 
         setups = [
-            (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), x_size=(2, 1024, 5),
-                y_size=(1, 1024, 5), permuted=Val(false)),
-            (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1), x_size=(1024, 2, 5),
-                y_size=(1024, 1, 5), permuted=Val(true))]
+            (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1),
+                x_size=(2, 1024, 5), y_size=(1, 1024, 5), permuted=Val(false)),
+            (modes=(16,), chs=(2, 64, 64, 64, 64, 64, 128, 1),
+                x_size=(1024, 2, 5), y_size=(1024, 1, 5), permuted=Val(true))]
 
         @testset "$(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups
             fno = FourierNeuralOperator(; setup.chs, setup.modes, setup.permuted)
 
-            x = rand(rng, Float32, setup.x_size...)
-            y = rand(rng, Float32, setup.y_size...)
+            x = rand(rng, Float32, setup.x_size...) |> aType
+            y = rand(rng, Float32, setup.y_size...) |> aType
 
-            ps, st = Lux.setup(rng, fno)
+            ps, st = Lux.setup(rng, fno) |> dev
+
+            @inferred fno(x, ps, st)
+            @jet fno(x, ps, st)
 
             @test size(first(fno(x, ps, st))) == setup.y_size
diff --git a/test/layers_tests.jl b/test/layers_tests.jl
index 39d04c1..db0db7a 100644
--- a/test/layers_tests.jl
+++ b/test/layers_tests.jl
@@ -1,6 +1,6 @@
 @testitem "SpectralConv & SpectralKernel" setup=[SharedTestSetup] begin
     @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
-        rng = get_default_rng(mode)
+        rng = get_stable_rng()
 
         opconv = [SpectralConv, SpectralKernel]
         setups = [
@@ -11,7 +11,7 @@
             (; m=(10, 10), permuted=Val(true), x_size=(22, 22, 1, 5),
                 y_size=(22, 22, 64, 5))]
 
-        @testset "$(op) $(length(setup.modes))D: permuted = $(setup.permuted)" for setup in setups,
+        @testset "$(op) $(length(setup.m))D: permuted = $(setup.permuted)" for setup in setups,
             op in opconv
 
             p = Lux.__unwrap_val(setup.permuted)
@@ -22,11 +22,12 @@
             l1 = p ? Conv(ntuple(_ -> 1, length(setup.m)), in_chs => first(ch)) : Dense(in_chs => first(ch))
             m = Chain(l1, op(ch, setup.m; setup.permuted))
-            ps, st = Lux.setup(rng, m)
+            ps, st = Lux.setup(rng, m) |> dev
 
-            x = rand(rng, Float32, setup.x_size...)
+            x = rand(rng, Float32, setup.x_size...) |> aType
 
             @test size(first(m(x, ps, st))) == setup.y_size
             @inferred m(x, ps, st)
+            @jet m(x, ps, st)
 
             data = [(x, rand(rng, Float32, setup.y_size...))]
             l2, l1 = train!(m, ps, st, data; epochs=10)
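
Note on the test changes (not part of the patch): the `[LuxTestUtils]` `target_modules` preference added in LocalPreferences.toml scopes LuxTestUtils' JET-powered `@jet` checks to the listed packages, so type-stability failures in unrelated upstream dependencies are not reported. Below is a minimal CPU-only sketch of the pattern the updated tests follow; `Xoshiro(0)` is a stand-in for the test suite's stable-RNG helper, and the `aType`/`dev` moves provided by `SharedTestSetup` are omitted here.

using Lux, LuxNeuralOperators, LuxTestUtils, Random, Test

rng = Xoshiro(0)  # fixed seed; the suite uses a stable RNG for reproducibility

# Same configuration as the first `setups` entry in test/fno_tests.jl
fno = FourierNeuralOperator(; chs=(2, 64, 64, 64, 64, 64, 128, 1), modes=(16,))
ps, st = Lux.setup(rng, fno)
x = rand(rng, Float32, 2, 1024, 5)  # channel-first input, batch of 5

@inferred fno(x, ps, st)  # output type must be statically inferable
@jet fno(x, ps, st)       # JET analysis, filtered by the target_modules preference
@test size(first(fno(x, ps, st))) == (1, 1024, 5)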