diff --git a/docs/Project.toml b/docs/Project.toml index 7aaae9b..204864c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -25,7 +25,7 @@ LuxCUDA = "0.3.3" MAT = "0.10.7" MLUtils = "0.4.4" NeuralOperators = "0.5" -Optimisers = "0.3.3" +Optimisers = "0.3.3, 0.4" Printf = "1.10" PythonCall = "0.9.23" Zygote = "0.6.71" diff --git a/docs/src/tutorials/xla_compilation.md b/docs/src/tutorials/xla_compilation.md index ed3b7b8..57b3620 100644 --- a/docs/src/tutorials/xla_compilation.md +++ b/docs/src/tutorials/xla_compilation.md @@ -3,8 +3,8 @@ ```@example xla_compilation using NeuralOperators, Lux, Random, Enzyme, Reactant -function sumabs2first(model, ps, st, (u, y)) - z, _ = model((u, y), ps, st) +function sumabs2first(model, ps, st, x) + z, _ = model(x, ps, st) return sum(abs2, z) end @@ -15,14 +15,14 @@ dev = reactant_device() ```@example xla_compilation deeponet = DeepONet() -ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev +ps, st = Lux.setup(Random.default_rng(), deeponet) |> dev; -u = rand(Float32, 64, 1024) |> dev -y = rand(Float32, 1, 128, 1024) |> dev +u = rand(Float32, 64, 1024) |> dev; +y = rand(Float32, 1, 128, 1024) |> dev; nothing # hide deeponet_compiled = @compile deeponet((u, y), ps, st) -deeponet_compiled((u, y), ps, st) +deeponet_compiled((u, y), ps, st)[1] ``` Computing the gradient of the DeepONet model. @@ -49,12 +49,12 @@ end ```@example xla_compilation fno = FourierNeuralOperator() -ps, st = Lux.setup(Random.default_rng(), fno) |> dev +ps, st = Lux.setup(Random.default_rng(), fno) |> dev; -x = rand(Float32, 2, 1024, 5) |> dev +x = rand(Float32, 2, 1024, 5) |> dev; fno_compiled = @compile fno(x, ps, st) -fno_compiled(x, ps, st) +fno_compiled(x, ps, st)[1] ``` Computing the gradient of the FourierNeuralOperator model. diff --git a/test/Project.toml b/test/Project.toml index 8890976..b863b98 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ LuxCore = "1" LuxLib = "1.2" LuxTestUtils = "1.1.2" MLDataDevices = "1" -Optimisers = "0.3.3" +Optimisers = "0.3.3, 0.4" Pkg = "1.10" Preferences = "1" Random = "1.10"