From 528fca0ed22c677b53ff889ad6cdb74ccec521c7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 21:16:41 -0500 Subject: [PATCH] docs: rename to reactant compilation --- .github/workflows/CI.yml | 2 - Project.toml | 2 +- docs/pages.jl | 2 +- .../{xla_compilation.md => reactant.md} | 38 ++++++------------- 4 files changed, 14 insertions(+), 30 deletions(-) rename docs/src/tutorials/{xla_compilation.md => reactant.md} (61%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d14f91d..b8f2f0d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,6 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -60,7 +59,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia ${{ matrix.version }} runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/Project.toml b/Project.toml index 15d269a..195fd47 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ Lux = "1.2.1" LuxCore = "1.1" LuxLib = "1.3.7" MLDataDevices = "1.5" -NNlib = "0.9.21" +NNlib = "0.9.24" Random = "1.10" Static = "1.1.1" WeightInitializers = "1" diff --git a/docs/pages.jl b/docs/pages.jl index 87a4d3c..e0fbea1 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -6,7 +6,7 @@ pages = [ "NOMAD" => "models/nomad.md" ], "Tutorials" => [ - "XLA Compilation" => "tutorials/xla_compilation.md", + "XLA Compilation" => "tutorials/reactant.md", "Burgers Equation" => "tutorials/burgers.md" ], "API Reference" => "api.md" diff --git a/docs/src/tutorials/xla_compilation.md b/docs/src/tutorials/reactant.md similarity index 61% rename from docs/src/tutorials/xla_compilation.md rename to docs/src/tutorials/reactant.md index 57b3620..a90836b 100644 --- a/docs/src/tutorials/xla_compilation.md +++ b/docs/src/tutorials/reactant.md @@ -1,6 +1,6 @@ # Compiling NeuralOperators.jl using Reactant.jl -```@example xla_compilation +```@example reactant using NeuralOperators, Lux, Random, Enzyme, Reactant function sumabs2first(model, ps, st, x) @@ -13,12 +13,12 @@ dev = reactant_device() ## Compiling DeepONet -```@example xla_compilation +```@example reactant deeponet = DeepONet() ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev; -u = rand(Float32, 64, 1024) |> dev; -y = rand(Float32, 1, 128, 1024) |> dev; +u = rand(Float32, 64, 32) |> dev; +y = rand(Float32, 1, 128, 32) |> dev; nothing # hide deeponet_compiled = @compile deeponet((u, y), ps, st) @@ -27,18 +27,11 @@ deeponet_compiled((u, y), ps, st)[1] Computing the gradient of the DeepONet model. -```@example xla_compilation +```@example reactant function ∇deeponet(model, ps, st, (u, y)) - dps = Enzyme.make_zero(ps) - Enzyme.autodiff( - Enzyme.Reverse, - sumabs2first, - Const(model), - Duplicated(ps, dps), - Const(st), - Const((u, y)) + return Enzyme.gradient( + Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const((u, y)) ) - return dps end ∇deeponet_compiled = @compile ∇deeponet(deeponet, ps, st, (u, y)) @@ -47,11 +40,11 @@ end ## Compiling FourierNeuralOperator -```@example xla_compilation +```@example reactant fno = FourierNeuralOperator() ps, st = Lux.setup(Random.default_rng(), fno) |> dev; -x = rand(Float32, 2, 1024, 5) |> dev; +x = rand(Float32, 2, 32, 5) |> dev; fno_compiled = @compile fno(x, ps, st) fno_compiled(x, ps, st)[1] @@ -59,18 +52,11 @@ fno_compiled(x, ps, st)[1] Computing the gradient of the FourierNeuralOperator model. -```@example xla_compilation +```@example reactant function ∇fno(model, ps, st, x) - dps = Enzyme.make_zero(ps) - Enzyme.autodiff( - Enzyme.Reverse, - sumabs2first, - Const(model), - Duplicated(ps, dps), - Const(st), - Const(x) + return Enzyme.gradient( + Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const(x) ) - return dps end ∇fno_compiled = @compile ∇fno(fno, ps, st, x)