Skip to content

Commit

Permalink
fix: update to latest Reactant changes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 20, 2024
1 parent c655138 commit c8c6d4d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ CairoMakie = "0.12.11"
CondaPkg = "0.2.23"
DataDeps = "0.7.13"
Documenter = "1.7.0"
Enzyme = "0.13.24"
Lux = "1.2.1"
LuxCUDA = "0.3.3"
MAT = "0.10.7"
Expand All @@ -28,4 +29,5 @@ NeuralOperators = "0.5"
Optimisers = "0.3.3, 0.4"
Printf = "1.10"
PythonCall = "0.9.23"
Reactant = "0.2.11"
Zygote = "0.6.71"
19 changes: 13 additions & 6 deletions ext/NeuralOperatorsReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,29 @@ module NeuralOperatorsReactantExt
using FFTW: FFTW
using NeuralOperators: NeuralOperators, FourierTransform
using NNlib: NNlib
using Reactant: Reactant, TracedRArray
using Reactant: Reactant, TracedRArray, AnyTracedRArray

# XXX: Reevaluate after https://github.com/EnzymeAD/Reactant.jl/issues/246 is fixed
function NeuralOperators.transform(ft::FourierTransform, x::TracedRArray{T, N}) where {T, N}
x_c = Reactant.promote_to(TracedRArray{Complex{T}, N}, x)
function NeuralOperators.transform(
ft::FourierTransform, x::AnyTracedRArray{T, N}) where {T, N}
x_c = Reactant.TracedUtils.promote_to(
TracedRArray{Complex{T}, N},
Reactant.TracedUtils.materialize_traced_array(x)
)
return FFTW.fft(x_c, 1:ndims(ft))
end

function NeuralOperators.inverse(
ft::FourierTransform, x::TracedRArray{T, N}, ::NTuple{N, Int64}) where {T, N}
ft::FourierTransform, x::AnyTracedRArray{T, N}, ::NTuple{N, Int64}) where {T, N}
return real(FFTW.ifft(x, 1:ndims(ft)))
end

function NeuralOperators.fast_pad_zeros(x::TracedRArray, pad_dims)
function NeuralOperators.fast_pad_zeros(x::AnyTracedRArray, pad_dims)
return NNlib.pad_zeros(
x, NeuralOperators.expand_pad_dims(pad_dims); dims=ntuple(identity, ndims(x) - 2))
Reactant.TracedUtils.materialize_traced_array(x),
NeuralOperators.expand_pad_dims(pad_dims);
dims=ntuple(identity, ndims(x) - 2)
)
end

end

0 comments on commit c8c6d4d

Please sign in to comment.