diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index b8266b3b2..279d2beb3 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -629,11 +629,12 @@ function Base.show(io::IO, l::GINConv) print(io, ")") end -@concrete struct MEGNetConv{TE, TV, A} <: GNNLayer +@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)} + in_dims::Int + out_dims::Int ϕe::TE ϕv::TV aggr::A - num_features::NamedTuple end MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr) @@ -646,15 +647,13 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) ϕv = Chain(Dense(nin + nout, nout, relu), Dense(nout, nout)) - num_features = (in = nin, out = nout) - - return MEGNetConv(ϕe, ϕv; aggr, num_features) + return MEGNetConv(nin, nout, ϕe, ϕv; aggr) end LuxCore.outputsize(l::MegNetConv) = (l.num_features.out,) -(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) +(l::MegNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) # check function (l::MegNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))