Skip to content

Commit

Permalink
feat: implement for return sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 19, 2025
1 parent 4ae54ef commit 0bbc7d9
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions ext/LuxReactantExt/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@ function (r::Lux.Recurrence{False})(x::AnyTracedRArray, ps, st::NamedTuple)
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
idxs = ntuple(Returns(Colon()), ndims(x) - 1)
(out, carry), st = r.cell(x[idxs..., 1], ps, st)
T = size(x, ndims(x))
@trace for i in 2:T
@trace for i in 2:size(x, ndims(x))
(out, carry), st = r.cell((x[idxs..., i], carry), ps, st)
end
return out, st
elseif r.ordering isa Lux.BatchLastIndex
idxs = ntuple(Returns(Colon()), ndims(x) - 2)
(out, carry), st = r.cell(x[idxs..., 1, :], ps, st)
T = size(x, ndims(x) - 1)
@trace for i in 2:T
@trace for i in 2:size(x, ndims(x) - 1)
(out, carry), st = r.cell((x[idxs..., i, :], carry), ps, st)
end
return out, st
Expand All @@ -27,4 +25,28 @@ function (r::Lux.Recurrence{False})(x::AnyTracedRArray, ps, st::NamedTuple)
end
end

# TODO: We need to implement the return sequence version as well
function (r::Lux.Recurrence{True})(x::AnyTracedRArray, ps, st::NamedTuple)
if r.ordering isa Lux.TimeLastIndex ||
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
idxs = ntuple(Returns(Colon()), ndims(x) - 1)
(out, carry), st = r.cell(x[idxs..., 1], ps, st)
sequence = similar(out, size(out)..., size(x, ndims(x)))
sequence[idxs..., 1] .= out
@trace for i in 2:size(x, ndims(x))
(out, carry), st = r.cell((x[idxs..., i], carry), ps, st)
sequence[idxs..., i] = out
end
elseif r.ordering isa Lux.BatchLastIndex
idxs = ntuple(Returns(Colon()), ndims(x) - 2)
(out, carry), st = r.cell(x[idxs..., 1, :], ps, st)
sequence = similar(out, size(out)..., size(x, ndims(x) - 1))
sequence[idxs..., :, 1] .= out
@trace for i in 2:size(x, ndims(x) - 1)
(out, carry), st = r.cell((x[idxs..., i, :], carry), ps, st)
sequence[idxs..., :, i] = out
end
else
error("Unknown ordering: $(r.ordering)")
end
return (out, eachslice(sequence; dims=ndims(sequence))), st
end

0 comments on commit 0bbc7d9

Please sign in to comment.