test: use TestExtras #1099

Draft · wants to merge 5 commits into base: main
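This PR swaps the `@test @inferred(f(x)) isa Any` idiom for `@constinferred` from TestExtras across the LuxLib and MLDataDevices test suites, and adds TestExtras to the corresponding test Project.toml files. A minimal sketch of the pattern, under the assumption that `TestExtras.@constinferred` works like `Test.@inferred` but also exercises constant propagation and records its own pass/fail result (`loss` below is a hypothetical stand-in for the kernels tested in the diff):

```julia
using Test, TestExtras

# Hypothetical stand-in for the LuxLib kernels exercised in these tests.
loss(f, x) = sum(abs2, f.(x))

x = randn(Float32, 4)

# Old idiom: @inferred throws if inference fails, so `isa Any` is trivially
# true and only exists to give @test something to evaluate.
@test @inferred(loss(abs, x)) isa Any

# New idiom: a single macro that checks inference (with constant propagation)
# and reports its own test result.
@constinferred loss(abs, x)
```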
2 changes: 2 additions & 0 deletions lib/LuxLib/test/Project.toml
@@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

@@ -61,5 +62,6 @@ Static = "0.8.4, 1"
StaticArrays = "1.9.7"
Statistics = "1.10"
Test = "1.10"
TestExtras = "0.3.1"
Tracker = "0.2.36"
Zygote = "0.6.70"
12 changes: 6 additions & 6 deletions lib/LuxLib/test/common_ops/activation_tests.jl
@@ -30,18 +30,18 @@
@test eltype(y2) == T
@test eltype(y3) == T

@test @inferred(apply_act(f, x)) isa Any
@test @inferred(apply_act_fast(f, x)) isa Any
@test @inferred(apply_act_fast2(f, x)) isa Any
@constinferred apply_act(f, x)
@constinferred apply_act_fast(f, x)
@constinferred apply_act_fast2(f, x)

@jet apply_act_fast(f, x)
@jet apply_act_fast2(f, x)

@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
@constinferred Zygote.gradient(apply_act, f, x)
if f !== lisht
@test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
@constinferred Zygote.gradient(apply_act_fast, f, x)
end
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any
@constinferred Zygote.gradient(apply_act_fast2, f, x)

@test_gradients(apply_act, f, x; atol, rtol)
@test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()])
38 changes: 14 additions & 24 deletions lib/LuxLib/test/common_ops/bias_act_tests.jl
@@ -5,12 +5,6 @@
bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b))
bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b))

struct __Fix1{F, A}
f::F
act::A
end
(f::__Fix1)(x, b) = f.f(f.act, x, b)

@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "$act, $T, $sz" for act in [
identity, relu, sigmoid, sigmoid_fast, softplus,
@@ -27,38 +21,34 @@
y2 = bias_act_loss2(act, x, b)
y3 = bias_act_loss3(act, x, b)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test y1≈y2 atol=atol rtol=rtol
@test y1≈y3 atol=atol rtol=rtol
@test eltype(y1) == T
@test eltype(y2) == T
@test eltype(y3) == T

@test @inferred(bias_act_loss1(act, x, b)) isa Any
@test @inferred(bias_act_loss2(act, x, b)) isa Any
@test @inferred(bias_act_loss3(act, x, b)) isa Any
@constinferred bias_act_loss1(act, x, b)
@constinferred bias_act_loss2(act, x, b)
@constinferred bias_act_loss3(act, x, b)

@jet bias_act_loss2(act, x, b)
@jet bias_act_loss3(act, x, b)

if act !== lisht && T != Float16
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
if act !== lisht
@constinferred Zygote.gradient(bias_act_loss2, act, x, b)
@constinferred Zygote.gradient(bias_act_loss3, act, x, b)
end

@test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
soft_fail=fp16 ? [AutoFiniteDiff()] : [])
@test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol,
soft_fail=fp16 ? [AutoFiniteDiff()] : [])
@test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol,
soft_fail=fp16 ? [AutoFiniteDiff()] : [])
@test_gradients(bias_act_loss1, act, x, b; atol, rtol)
@test_gradients(bias_act_loss2, act, x, b; atol, rtol)
@test_gradients(bias_act_loss3, act, x, b; atol, rtol)

∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b)
∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b)
∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b)
_, ∂x1, ∂b1 = Zygote.gradient(bias_act_loss1, act, x, b)
_, ∂x2, ∂b2 = Zygote.gradient(bias_act_loss2, act, x, b)
_, ∂x3, ∂b3 = Zygote.gradient(bias_act_loss3, act, x, b)

@test ∂x1≈∂x2 atol=atol rtol=rtol
@test ∂x1≈∂x3 atol=atol rtol=rtol
11 changes: 5 additions & 6 deletions lib/LuxLib/test/common_ops/conv_tests.jl
@@ -47,16 +47,15 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

if mode != "amdgpu" && activation !== anonact
@test @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims
)) isa Any
@test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
Review comment (Contributor):

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
@test @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims)) isa Any

else
try
@inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims))
@test true
@test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
Review comment (Contributor):

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any
@test @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims)) isa Any

catch e
e isa ErrorException || rethrow()
@test_broken false
@test_broken @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims
))
end
end

8 changes: 4 additions & 4 deletions lib/LuxLib/test/common_ops/dense_tests.jl
@@ -117,23 +117,23 @@ end
end

@testitem "Fused Dense: StaticArrays" tags=[:dense] begin
using StaticArrays, NNlib
using StaticArrays, NNlib, TestExtras

x = @SArray rand(2, 4)
weight = @SArray rand(3, 2)
bias = @SArray rand(3)

@test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray
@constinferred fused_dense_bias_activation(relu, weight, x, bias)
end

@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin
using JLArrays, NNlib
using JLArrays, NNlib, TestExtras

x = JLArray(rand(Float32, 2, 4))
weight = JLArray(rand(Float32, 3, 2))
bias = JLArray(rand(Float32, 3))

@test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray
@constinferred fused_dense_bias_activation(relu, weight, x, bias)
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp
end

25 changes: 11 additions & 14 deletions lib/LuxLib/test/common_ops/dropout_tests.jl
@@ -10,7 +10,7 @@

x = randn(rng, T, x_shape) |> aType

@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
@constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims)

@@ -21,10 +21,10 @@
@test rng != rng_

@jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims)))
@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
@constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

__f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims)))
@test @inferred(Zygote.gradient(__f, x)) isa Any
@constinferred Zygote.gradient(__f, x)

@test_gradients(sumabs2first,
dropout, rng, x, T(0.5), Val(true), T(2), dims; atol=1.0f-3, rtol=1.0f-3)
@@ -54,8 +54,7 @@ end
mask = rand(T, x_shape) |> aType

# Update mask
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any
@constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)
@@ -69,7 +68,7 @@

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any
@constinferred Zygote.gradient(__f, x, mask)

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true),
@@ -79,8 +78,7 @@ end
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))

# Try using mask if possible (possible!!)
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any
@constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)
@@ -94,7 +92,7 @@ end

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any
@constinferred Zygote.gradient(__f, x, mask)

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask),
@@ -107,8 +105,7 @@ end
mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

# Testing Mode
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any
@constinferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)
@@ -135,7 +132,7 @@ end

x = randn(rng, T, x_shape) |> aType

@test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any
@constinferred alpha_dropout(rng, x, T(0.5), Val(true))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true))

@@ -146,13 +143,13 @@ end
@test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2

__f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true))))
@test @inferred(Zygote.gradient(__f, x)) isa Any
@constinferred Zygote.gradient(__f, x)

@test_gradients(sumabs2first,
alpha_dropout, rng, x, T(0.5), Val(true); atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
@test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
@constinferred alpha_dropout(rng, x, T(0.5), Val(false))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false))

6 changes: 2 additions & 4 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
@@ -1,5 +1,5 @@
@testsetup module BatchNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, TestExtras

function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool)
x = gen_f(T, sz) |> aType
@@ -89,10 +89,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
end

if anonact !== act
lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm(
x, sc, b, rm, rv, tr, act, ϵ)))
@test @inferred(Zygote.gradient(
lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
sumabs2first, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
end
end

4 changes: 2 additions & 2 deletions lib/LuxLib/test/normalization/groupnorm_tests.jl
@@ -62,8 +62,8 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
@jet groupnorm(x, scale, bias, groups, act, epsilon)

if anonact !== act
lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any
@test @inferred(Zygote.gradient(
sumabs2groupnorm, x, scale, bias, groups, act, epsilon)) isa Any
end

@test y isa aType{T, length(sz)}
5 changes: 2 additions & 3 deletions lib/LuxLib/test/normalization/instancenorm_tests.jl
@@ -51,10 +51,9 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
@jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm(
x, sc, b, rm, rv, Val(true), act, m, ϵ)))
@test @inferred(Zygote.gradient(
lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any
sumabs2instancenorm, x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa
Any
end

@test y isa aType{T, length(sz)}
6 changes: 3 additions & 3 deletions lib/LuxLib/test/normalization/layernorm_tests.jl
@@ -1,5 +1,5 @@
@testsetup module LayerNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics, TestExtras
using LuxTestUtils: check_approx

function setup_layernorm(gen_f, aType, T, x_size, affine_shape, expand_dims::Bool=true)
@@ -60,8 +60,8 @@ function run_layernorm_testing_core(
soft_fail=[AutoFiniteDiff()])

if anonact !== act
lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any
@test @inferred(Zygote.gradient(
sumabs2layernorm, x, scale, bias, act, dims, epsilon)) isa Any
end
end

2 changes: 1 addition & 1 deletion lib/LuxLib/test/shared_testsetup.jl
@@ -2,7 +2,7 @@
import Reexport: @reexport

using LuxLib, MLDataDevices
@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib
@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib, TestExtras

LuxTestUtils.jet_target_modules!(["LuxLib"])

2 changes: 2 additions & 0 deletions lib/MLDataDevices/test/Project.toml
@@ -17,6 +17,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

@@ -39,5 +40,6 @@ ReverseDiff = "1.15"
SafeTestsets = "0.1"
SparseArrays = "1.10"
Test = "1.10"
TestExtras = "0.3.1"
Tracker = "0.2.36"
Zygote = "0.6.69"
4 changes: 2 additions & 2 deletions lib/MLDataDevices/test/amdgpu_tests.jl
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Test
using MLDataDevices, Random, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
@@ -122,7 +122,7 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)
end
end

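The MLDataDevices tests above also use the annotated form `@constinferred T expr`, which appears to mirror `@inferred T expr` by additionally checking the inferred result against an expected type. A small sketch of that form; `MyCPUDevice` and `device_val` below are hypothetical stand-ins for the device types and `return_val` helpers in the diff:

```julia
using TestExtras

struct MyCPUDevice end  # hypothetical stand-in for an MLDataDevices device type

# Wrapping the device type in Val is only inferable when the device is a
# compile-time constant, which is what the annotated check asserts.
device_val(ps) = Val(MyCPUDevice)

ps = (; weight=rand(Float32, 2), bias=rand(Float32, 2))

@constinferred Val{MyCPUDevice} device_val(ps)
```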
4 changes: 2 additions & 2 deletions lib/MLDataDevices/test/cuda_tests.jl
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Functors, Test
using MLDataDevices, Random, Functors, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
@@ -144,7 +144,7 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)

return_val2(x) = Val(get_device(x))
@test_throws ErrorException @inferred(return_val2(ps))
6 changes: 3 additions & 3 deletions lib/MLDataDevices/test/metal_tests.jl
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Test
using MLDataDevices, Random, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
@@ -108,10 +108,10 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)

return_val2(x) = Val(get_device(x))
@test @inferred(return_val2(ps)) isa Val{get_device(x)}
@constinferred Val{get_device(x)} return_val2(ps)
end
end
