Skip to content

Commit

Permalink
chore: minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 8, 2024
1 parent 2ef10d5 commit 55fc1fe
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ LuxCUDA = "0.3.3"
MAT = "0.10.7"
MLUtils = "0.4.4"
NeuralOperators = "0.5"
Optimisers = "0.3.3"
Optimisers = "0.3.3, 0.4"
Printf = "1.10"
PythonCall = "0.9.23"
Zygote = "0.6.71"
18 changes: 9 additions & 9 deletions docs/src/tutorials/xla_compilation.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
```@example xla_compilation
using NeuralOperators, Lux, Random, Enzyme, Reactant
function sumabs2first(model, ps, st, (u, y))
z, _ = model((u, y), ps, st)
function sumabs2first(model, ps, st, x)
z, _ = model(x, ps, st)
return sum(abs2, z)
end
Expand All @@ -15,14 +15,14 @@ dev = reactant_device()

```@example xla_compilation
deeponet = DeepONet()
ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev
ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev;
u = rand(Float32, 64, 1024) |> dev
y = rand(Float32, 1, 128, 1024) |> dev
u = rand(Float32, 64, 1024) |> dev;
y = rand(Float32, 1, 128, 1024) |> dev;
nothing # hide
deeponet_compiled = @compile deeponet((u, y), ps, st)
deeponet_compiled((u, y), ps, st)
deeponet_compiled((u, y), ps, st)[1]
```

Computing the gradient of the DeepONet model.
Expand All @@ -49,12 +49,12 @@ end

```@example xla_compilation
fno = FourierNeuralOperator()
ps, st = Lux.setup(Random.default_rng(), fno) |> dev
ps, st = Lux.setup(Random.default_rng(), fno) |> dev;
x = rand(Float32, 2, 1024, 5) |> dev
x = rand(Float32, 2, 1024, 5) |> dev;
fno_compiled = @compile fno(x, ps, st)
fno_compiled(x, ps, st)
fno_compiled(x, ps, st)[1]
```

Computing the gradient of the FourierNeuralOperator model.
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ LuxCore = "1"
LuxLib = "1.2"
LuxTestUtils = "1.1.2"
MLDataDevices = "1"
Optimisers = "0.3.3"
Optimisers = "0.3.3, 0.4"
Pkg = "1.10"
Preferences = "1"
Random = "1.10"
Expand Down

0 comments on commit 55fc1fe

Please sign in to comment.