Skip to content

Commit

Permalink
test: more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 11, 2025
1 parent 4b0fa0c commit 6413484
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
4 changes: 3 additions & 1 deletion ext/LuxReactantExt/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ function (e::Lux.Embedding)(x::TracedRNumber{<:Reactant.ReactantInt}, ps, st::Na
end

# Recurrent Layers
# TODO: Once we can elimiate dead-args in while loop we should remove this case and only

Check warning on line 7 in ext/LuxReactantExt/layers.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"elimiate" should be "eliminate".
# use the later function for maintainence purposes.

Check warning on line 8 in ext/LuxReactantExt/layers.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"maintainence" should be "maintenance".
function (r::Lux.Recurrence{False})(x::AnyTracedRArray, ps, st::NamedTuple)
if r.ordering isa Lux.TimeLastIndex ||
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
Expand All @@ -27,7 +29,7 @@ end

function (r::Lux.Recurrence{True})(x::AnyTracedRArray, ps, st::NamedTuple)
if r.ordering isa Lux.TimeLastIndex ||
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
idxs = ntuple(Returns(Colon()), ndims(x) - 1)
(out, carry), st = r.cell(x[idxs..., 1], ps, st)
sequence = similar(out, size(out)..., size(x, ndims(x)))
Expand Down
38 changes: 22 additions & 16 deletions test/reactant/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,28 @@ end
dev = reactant_device(; force=true)

@testset for cell in (RNNCell, LSTMCell, GRUCell)
model = Recurrence(cell(4 => 4))
ps, st = Lux.setup(rng, model)
ps_ra, st_ra = (ps, st) |> dev
x = rand(Float32, 4, 16, 12)
x_ra = x |> dev

y_ra, _ = @jit model(x_ra, ps_ra, st_ra)
y, _ = model(x, ps, st)

@test y_ray atol=1e-2 rtol=1e-2

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra∂x atol=1e-2 rtol=1e-2
@test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2)
@testset for ordering in (BatchLastIndex(), TimeLastIndex())
model = Recurrence(cell(4 => 4); ordering)
ps, st = Lux.setup(rng, model)
ps_ra, st_ra = (ps, st) |> dev
if ordering isa TimeLastIndex
x = rand(Float32, 4, 12, 16)
else
x = rand(Float32, 4, 16, 12)
end
x_ra = x |> dev

y_ra, _ = @jit model(x_ra, ps_ra, st_ra)
y, _ = model(x, ps, st)

@test y_ray atol=1e-2 rtol=1e-2

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra∂x atol=1e-2 rtol=1e-2
@test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2)
end
end
end
end
Expand Down

0 comments on commit 6413484

Please sign in to comment.