Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Aug 4, 2024
1 parent 6b1af1b commit 9b76cf5
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,12 @@ function Base.show(io::IO, l::GINConv)
print(io, ")")
end

@concrete struct MEGNetConv{TE, TV, A} <: GNNLayer
@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
in_dims::Int
out_dims::Int
ϕe::TE
ϕv::TV
aggr::A
num_features::NamedTuple
end

MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr)
Expand All @@ -646,15 +647,13 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
ϕv = Chain(Dense(nin + nout, nout, relu),
Dense(nout, nout))

num_features = (in = nin, out = nout)

return MEGNetConv(ϕe, ϕv; aggr, num_features)
return MEGNetConv(nin, nout, ϕe, ϕv; aggr)
end


LuxCore.outputsize(l::MegNetConv) = (l.num_features.out,)

(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) # check

function (l::MegNetConv)(g, x, e, ps, st)
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
Expand Down

0 comments on commit 9b76cf5

Please sign in to comment.