From 2f9f0d1a61ce20c0e9b27deca8f0e35331f24a21 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Tue, 24 Dec 2024 23:19:38 +0100 Subject: [PATCH] Fix pooling layer --- GNNLux/src/layers/pool.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index c5519f47c..7fcc044f6 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -25,12 +25,12 @@ pool = GlobalPool(mean) g = GNNGraph(erdos_renyi(10, 4)) X = rand(32, 10) -pool(g, X) # => 32x1 matrix +pool(g, X, ps, st) # => 32x1 matrix g = MLUtils.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5]) X = rand(32, 50) -pool(g, X) # => 32x5 matrix +pool(g, X, ps, st) # => 32x5 matrix ``` """ struct GlobalPool{F} <: GNNLayer @@ -39,4 +39,4 @@ end (l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st -(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) \ No newline at end of file +(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) \ No newline at end of file