Skip to content

Commit

Permalink
Fix pooling layer
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Dec 24, 2024
1 parent a8ec4b5 commit 2f9f0d1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions GNNLux/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ pool = GlobalPool(mean)
g = GNNGraph(erdos_renyi(10, 4))
X = rand(32, 10)
pool(g, X) # => 32x1 matrix
pool(g, X, ps, st) # => 32x1 matrix
g = MLUtils.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5])
X = rand(32, 50)
pool(g, X) # => 32x5 matrix
pool(g, X, ps, st) # => 32x5 matrix
```
"""
struct GlobalPool{F} <: GNNLayer
Expand All @@ -39,4 +39,4 @@ end

(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))

0 comments on commit 2f9f0d1

Please sign in to comment.