docs: rename to reactant compilation
avik-pal committed Nov 10, 2024
1 parent 55fc1fe commit 528fca0
Showing 4 changed files with 14 additions and 30 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/CI.yml
@@ -21,7 +21,6 @@ concurrency:

 jobs:
   ci:
-    name: Julia ${{ matrix.version }} - ${{ matrix.os }}
     if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }}
     runs-on: ${{ matrix.os }}
     strategy:
@@ -60,7 +59,6 @@ jobs:

   downgrade:
     if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }}
-    name: Downgrade Julia ${{ matrix.version }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
2 changes: 1 addition & 1 deletion Project.toml
@@ -26,7 +26,7 @@ Lux = "1.2.1"
 LuxCore = "1.1"
 LuxLib = "1.3.7"
 MLDataDevices = "1.5"
-NNlib = "0.9.21"
+NNlib = "0.9.24"
 Random = "1.10"
 Static = "1.1.1"
 WeightInitializers = "1"
2 changes: 1 addition & 1 deletion docs/pages.jl
@@ -6,7 +6,7 @@ pages = [
         "NOMAD" => "models/nomad.md"
     ],
     "Tutorials" => [
-        "XLA Compilation" => "tutorials/xla_compilation.md",
+        "XLA Compilation" => "tutorials/reactant.md",
         "Burgers Equation" => "tutorials/burgers.md"
     ],
     "API Reference" => "api.md"
@@ -1,6 +1,6 @@
 # Compiling NeuralOperators.jl using Reactant.jl

-```@example xla_compilation
+```@example reactant
 using NeuralOperators, Lux, Random, Enzyme, Reactant

 function sumabs2first(model, ps, st, x)
@@ -13,12 +13,12 @@ dev = reactant_device()

 ## Compiling DeepONet

-```@example xla_compilation
+```@example reactant
 deeponet = DeepONet()
 ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev;
-u = rand(Float32, 64, 1024) |> dev;
-y = rand(Float32, 1, 128, 1024) |> dev;
+u = rand(Float32, 64, 32) |> dev;
+y = rand(Float32, 1, 128, 32) |> dev;
 nothing # hide
 deeponet_compiled = @compile deeponet((u, y), ps, st)
@@ -27,18 +27,11 @@ deeponet_compiled((u, y), ps, st)[1]

 Computing the gradient of the DeepONet model.

-```@example xla_compilation
+```@example reactant
 function ∇deeponet(model, ps, st, (u, y))
-    dps = Enzyme.make_zero(ps)
-    Enzyme.autodiff(
-        Enzyme.Reverse,
-        sumabs2first,
-        Const(model),
-        Duplicated(ps, dps),
-        Const(st),
-        Const((u, y))
+    return Enzyme.gradient(
+        Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const((u, y))
     )
-    return dps
 end
 ∇deeponet_compiled = @compile ∇deeponet(deeponet, ps, st, (u, y))
@@ -47,30 +40,23 @@ end

 ## Compiling FourierNeuralOperator

-```@example xla_compilation
+```@example reactant
 fno = FourierNeuralOperator()
 ps, st = Lux.setup(Random.default_rng(), fno) |> dev;
-x = rand(Float32, 2, 1024, 5) |> dev;
+x = rand(Float32, 2, 32, 5) |> dev;
 fno_compiled = @compile fno(x, ps, st)
 fno_compiled(x, ps, st)[1]
 ```

 Computing the gradient of the FourierNeuralOperator model.

-```@example xla_compilation
+```@example reactant
 function ∇fno(model, ps, st, x)
-    dps = Enzyme.make_zero(ps)
-    Enzyme.autodiff(
-        Enzyme.Reverse,
-        sumabs2first,
-        Const(model),
-        Duplicated(ps, dps),
-        Const(st),
-        Const(x)
+    return Enzyme.gradient(
+        Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const(x)
     )
-    return dps
 end
 ∇fno_compiled = @compile ∇fno(fno, ps, st, x)
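For reference, the substantive code change in this commit swaps Enzyme's low-level `autodiff` call (with a manually zeroed `Duplicated` shadow for `ps`) for the higher-level `Enzyme.gradient` API. A minimal sketch of the new pattern, assuming the tutorial's `sumabs2first` loss (its body is not shown in the diff, so the one below is an assumption) together with Reactant's `reactant_device` and `@compile` as used above:

```julia
using NeuralOperators, Lux, Random, Enzyme, Reactant

# Scalar loss: sum of squares of the model's first return value (the prediction).
# This body is an assumption -- the diff only shows the function's signature.
function sumabs2first(model, ps, st, x)
    return sum(abs2, first(model(x, ps, st)))
end

dev = reactant_device()

deeponet = DeepONet()
ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev
u = rand(Float32, 64, 32) |> dev      # branch input
y = rand(Float32, 1, 128, 32) |> dev  # trunk input

# New style: Const(...) marks non-differentiable arguments, so the gradient is
# taken only with respect to ps; no manual Enzyme.make_zero shadow is needed.
∇ps(model, ps, st, xy) = Enzyme.gradient(
    Enzyme.Reverse, Const(sumabs2first), Const(model), ps, Const(st), Const(xy)
)

∇ps_compiled = @compile ∇ps(deeponet, ps, st, (u, y))
```

Relative to the removed code, this also drops the trailing `return dps`, since `Enzyme.gradient` returns the computed gradients directly rather than accumulating them into a preallocated shadow.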
