Skip to content

Commit 9aeaef5

Browse files
committed
test: enable Enzyme testing partly
1 parent 4af3a5d commit 9aeaef5

6 files changed

+19
-14
lines changed

.JuliaFormatter.toml

-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@ format_docstrings = true
77
separate_kwargs_with_semicolon = true
88
always_for_in = true
99
annotate_untyped_fields_with_any = false
10-
join_lines_based_on_source = false

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
Manifest.toml
2+
Manifest-v*.toml
23
.vscode
34
wip
45
examples

test/deeponet_tests.jl

+9-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
(u_size=(64, 3, 5), y_size=(4, 10, 5), out_size=(3, 10, 5),
1111
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Vector"),
1212
(u_size=(64, 4, 3, 3, 5), y_size=(4, 10, 5), out_size=(4, 3, 3, 10, 5),
13-
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor")]
13+
branch=(64, 32, 32, 16), trunk=(4, 8, 8, 16), name="Tensor")
14+
]
1415

1516
@testset "$(setup.name)" for setup in setups
1617
u = rand(Float32, setup.u_size...) |> aType
@@ -34,7 +35,8 @@
3435
additional=Dense(16 => 4), name="Scalar II"),
3536
(u_size=(64, 3, 5), y_size=(8, 10, 5), out_size=(4, 3, 10, 5),
3637
branch=(64, 32, 32, 16), trunk=(8, 8, 8, 16),
37-
additional=Dense(16 => 4), name="Vector")]
38+
additional=Dense(16 => 4), name="Vector")
39+
]
3840

3941
@testset "Additional layer: $(setup.name)" for setup in setups
4042
u = rand(Float32, setup.u_size...) |> aType
@@ -50,16 +52,17 @@
5052
@test setup.out_size == size(pred)
5153

5254
__f = (u, y, ps) -> sum(abs2, first(deeponet((u, y), ps, st)))
53-
test_gradients(
54-
__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
55+
@test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3)
5556
end
5657

5758
@testset "Embedding layer mismatch" begin
5859
u = rand(Float32, 64, 5) |> aType
5960
y = rand(Float32, 1, 10, 5) |> aType
6061

61-
deeponet = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
62-
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
62+
deeponet = DeepONet(
63+
Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
64+
Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16))
65+
)
6366

6467
ps, st = Lux.setup(rng, deeponet) |> dev
6568
@test_throws ArgumentError deeponet((u, y), ps, st)

test/layers_tests.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
(; m=(10, 10), permuted=Val(false),
1010
x_size=(1, 22, 22, 5), y_size=(64, 22, 22, 5)),
1111
(; m=(10, 10), permuted=Val(true),
12-
x_size=(22, 22, 1, 5), y_size=(22, 22, 64, 5))]
12+
x_size=(22, 22, 1, 5), y_size=(22, 22, 64, 5))
13+
]
1314

1415
@testset "$(op) $(length(setup.m))D: permuted = $(setup.permuted)" for setup in setups,
1516
op in opconv
@@ -37,8 +38,8 @@
3738
end
3839

3940
__f = (x, ps) -> sum(abs2, first(m(x, ps, st)))
40-
test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
41-
skip_backends=[AutoEnzyme(), AutoTracker(), AutoReverseDiff()])
41+
@test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3,
42+
skip_backends=[AutoTracker(), AutoEnzyme(), AutoReverseDiff()])
4243
end
4344
end
4445
end

test/nomad_tests.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
(u_size=(1, 5), y_size=(1, 5), out_size=(1, 5),
77
approximator=(1, 16, 16, 15), decoder=(16, 8, 4, 1), name="Scalar"),
88
(u_size=(8, 5), y_size=(2, 5), out_size=(8, 5),
9-
approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector")]
9+
approximator=(8, 32, 32, 16), decoder=(18, 16, 8, 8), name="Vector")
10+
]
1011

1112
@testset "$(setup.name)" for setup in setups
1213
u = rand(Float32, setup.u_size...) |> aType
@@ -21,8 +22,7 @@
2122
@test setup.out_size == size(pred)
2223

2324
__f = (u, y, ps) -> sum(abs2, first(nomad((u, y), ps, st)))
24-
test_gradients(
25-
__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()])
25+
@test_gradients(__f, u, y, ps; atol=1.0f-3, rtol=1.0f-3)
2626
end
2727
end
2828
end

test/utils_tests.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
additional=Dense(16 => 4), name="additional : Vector"),
2222
(b_size=(16, 4, 3, 3, 5), t_size=(16, 10, 5), out_size=(3, 4, 3, 4, 10, 5),
2323
additional=Chain(Dense(16 => 4), ReshapeLayer((3, 4, 3, 4, 10))),
24-
name="additional : Tensor")]
24+
name="additional : Tensor")
25+
]
2526

2627
@testset "project : $(setup.name)" for setup in setups
2728
b = rand(Float32, setup.b_size...) |> aType

0 commit comments

Comments
 (0)