tests for GNNlib (#466)
CarloLucibello authored Jul 29, 2024
1 parent 9338ed7 commit 11515eb
Showing 15 changed files with 293 additions and 254 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/test_GNNlib.yml
@@ -0,0 +1,48 @@
name: GNNlib
on:
  pull_request:
    branches:
      - master
  push:
    branches:
      - master
jobs:
  test:
    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        version:
          - '1.10' # Replace this with the minimum Julia version that your package supports.
          # - '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
          # - 'pre'
        os:
          - ubuntu-latest
        arch:
          - x64

    steps:
      - uses: actions/checkout@v4
      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.version }}
          arch: ${{ matrix.arch }}
      - uses: julia-actions/cache@v2
      - uses: julia-actions/julia-buildpkg@v1
      - name: Install Julia dependencies and run tests
        shell: julia --project=monorepo {0}
        run: |
          using Pkg
          # dev monorepo versions
          pkg"registry up"
          Pkg.update()
          pkg"dev ./GNNGraphs ./GNNlib"
          Pkg.test("GNNlib"; coverage=true)
      - uses: julia-actions/julia-processcoverage@v1
        with:
          # directories: ./GNNlib/src, ./GNNlib/ext
          directories: ./GNNlib/src
      - uses: codecov/codecov-action@v4
        with:
          files: lcov.info
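The `shell: julia --project=monorepo {0}` line makes the step's `run:` script itself a Julia program: GitHub Actions writes the script to a temp file and substitutes its path for `{0}`. A sketch of the equivalent local flow, assuming it is run from the repository root:

using Pkg
Pkg.activate("monorepo")             # scratch project, mirroring --project=monorepo
pkg"registry up"                      # same commands as the CI step above
Pkg.update()
pkg"dev ./GNNGraphs ./GNNlib"         # use the in-repo GNNGraphs/GNNlib sources
Pkg.test("GNNlib"; coverage = true)
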
7 changes: 5 additions & 2 deletions GNNlib/Project.toml
@@ -20,8 +20,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 GNNlibCUDAExt = "CUDA"
 
 [compat]
-ChainRulesCore = "1.24"
 CUDA = "4, 5"
+ChainRulesCore = "1.24"
 DataStructures = "0.18"
 GNNGraphs = "1.0"
 LinearAlgebra = "1"
@@ -32,7 +32,10 @@ Statistics = "1"
 julia = "1.10"
 
 [extras]
+ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test"]
+test = ["Test", "ReTestItems", "Reexport", "SparseArrays"]
16 changes: 8 additions & 8 deletions GNNlib/src/layers/pool.jl
@@ -1,12 +1,12 @@
 
 
-function global_pool(aggr, g::GNNGraph, x::AbstractArray)
-    return reduce_nodes(aggr, g, x)
+function global_pool(l, g::GNNGraph, x::AbstractArray)
+    return reduce_nodes(l.aggr, g, x)
 end
 
-function global_attention_pool(fgate, ffeat, g::GNNGraph, x::AbstractArray)
-    α = softmax_nodes(g, fgate(x))
-    feats = α .* ffeat(x)
+function global_attention_pool(l, g::GNNGraph, x::AbstractArray)
+    α = softmax_nodes(g, l.fgate(x))
+    feats = α .* l.ffeat(x)
     u = reduce_nodes(+, g, feats)
     return u
 end
@@ -26,11 +26,11 @@ end
 
 topk_index(y::Adjoint, k::Int) = topk_index(y', k)
 
-function set2set_pool(lstm, num_iters, g::GNNGraph, x::AbstractMatrix)
+function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
     n_in = size(x, 1)
     qstar = zeros_like(x, (2*n_in, g.num_graphs))
-    for t in 1:num_iters
-        q = lstm(qstar) # [n_in, n_graphs]
+    for t in 1:l.num_iters
+        q = l.lstm(qstar) # [n_in, n_graphs]
         qn = broadcast_nodes(g, q) # [n_in, n_nodes]
         α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
         r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
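The refactor above threads a single layer object `l` through each pooling function instead of passing its fields positionally (`aggr`, `fgate`/`ffeat`, `lstm`/`num_iters`). A minimal usage sketch: the NamedTuple is a hypothetical stand-in for whatever layer struct a caller defines, since only field access is required.

using GNNlib, GNNGraphs

g = rand_graph(10, 30)           # 10 nodes, 30 edges
x = rand(Float32, 4, 10)         # 4 features per node

l = (; aggr = +)                 # anything with an `aggr` field works
u = GNNlib.global_pool(l, g, x)  # -> 4 x g.num_graphs matrix
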
140 changes: 140 additions & 0 deletions GNNlib/test/msgpass_tests.jl
@@ -0,0 +1,140 @@
@testitem "msgpass" setup=[SharedTestSetup] begin
#TODO test all graph types
GRAPH_T = :coo
in_channel = 10
out_channel = 5
num_V = 6
num_E = 14
T = Float32

adj = [0 1 0 0 0 0
1 0 0 1 1 1
0 0 0 0 0 1
0 1 0 0 1 0
0 1 0 1 0 1
0 1 1 0 1 0]

X = rand(T, in_channel, num_V)
E = rand(T, in_channel, num_E)

g = GNNGraph(adj, graph_type = GRAPH_T)

@testset "propagate" begin
function message(xi, xj, e)
@test xi === nothing
@test e === nothing
ones(T, out_channel, size(xj, 2))
end

m = propagate(message, g, +, xj = X)

@test size(m) == (out_channel, num_V)

@testset "isolated nodes" begin
x1 = rand(1, 6)
g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6)
y1 = propagate((xi, xj, e) -> xj, g, +, xj = x1)
@test size(y1) == (1, 6)
end
end

@testset "apply_edges" begin
m = apply_edges(g, e = E) do xi, xj, e
@test xi === nothing
@test xj === nothing
ones(out_channel, size(e, 2))
end

@test m == ones(out_channel, num_E)

# With NamedTuple input
m = apply_edges(g, xj = (; a = X, b = 2X), e = E) do xi, xj, e
@test xi === nothing
@test xj.b == 2 * xj.a
@test size(xj.a, 2) == size(xj.b, 2) == size(e, 2)
ones(out_channel, size(e, 2))
end

# NamedTuple output
m = apply_edges(g, e = E) do xi, xj, e
@test xi === nothing
@test xj === nothing
(; a = ones(out_channel, size(e, 2)))
end

@test m.a == ones(out_channel, num_E)

@testset "sizecheck" begin
x = rand(3, g.num_nodes - 1)
@test_throws AssertionError apply_edges(copy_xj, g, xj = x)
@test_throws AssertionError apply_edges(copy_xj, g, xi = x)

x = (a = rand(3, g.num_nodes), b = rand(3, g.num_nodes + 1))
@test_throws AssertionError apply_edges(copy_xj, g, xj = x)
@test_throws AssertionError apply_edges(copy_xj, g, xi = x)

e = rand(3, g.num_edges - 1)
@test_throws AssertionError apply_edges(copy_xj, g, e = e)
end
end

@testset "copy_xj" begin
n = 128
A = sprand(n, n, 0.1)
Adj = map(x -> x > 0 ? 1 : 0, A)
X = rand(10, n)

g = GNNGraph(A, ndata = X, graph_type = GRAPH_T)

function spmm_copyxj_fused(g)
propagate(copy_xj,
g, +; xj = g.ndata.x)
end

function spmm_copyxj_unfused(g)
propagate((xi, xj, e) -> xj,
g, +; xj = g.ndata.x)
end

@test spmm_copyxj_unfused(g) X * Adj
@test spmm_copyxj_fused(g) X * Adj
end

@testset "e_mul_xj and w_mul_xj for weighted conv" begin
n = 128
A = sprand(n, n, 0.1)
Adj = map(x -> x > 0 ? 1 : 0, A)
X = rand(10, n)

g = GNNGraph(A, ndata = X, edata = A.nzval, graph_type = GRAPH_T)

function spmm_unfused(g)
propagate((xi, xj, e) -> reshape(e, 1, :) .* xj,
g, +; xj = g.ndata.x, e = g.edata.e)
end
function spmm_fused(g)
propagate(e_mul_xj,
g, +; xj = g.ndata.x, e = g.edata.e)
end

function spmm_fused2(g)
propagate(w_mul_xj,
g, +; xj = g.ndata.x)
end

@test spmm_unfused(g) X * A
@test spmm_fused(g) X * A
@test spmm_fused2(g) X * A
end

@testset "aggregate_neighbors" begin
@testset "sizecheck" begin
m = rand(2, g.num_edges - 1)
@test_throws AssertionError aggregate_neighbors(g, +, m)

m = (a = rand(2, g.num_edges + 1), b = nothing)
@test_throws AssertionError aggregate_neighbors(g, +, m)
end
end

end
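The `copy_xj` tests above rely on the identity that message passing with `copy_xj` and `+` aggregation is a sparse matrix product: node `i` receives `sum_{j->i} x_j`, which column-wise is `X * A`. A dense toy check of that identity, separate from the test file and with made-up sizes:

# adjacency of a 2-node graph with edges 1->2 and 2->1
A = [0 1; 1 0]
X = [1.0 2.0; 3.0 4.0]       # one feature column per node
M = X * A                    # column i = sum of the neighbor columns of node i
@assert M == [2.0 1.0; 4.0 3.0]
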
6 changes: 6 additions & 0 deletions GNNlib/test/runtests.jl
@@ -0,0 +1,6 @@
using GNNlib
using Test
using ReTestItems
using Random, Statistics

runtests(GNNlib)
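(`runtests` here is ReTestItems' entry point; assuming the package follows ReTestItems conventions, it discovers the `@testitem` blocks in the `*_tests.jl` files added in this commit and makes the `SharedTestSetup` module below available to them.)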
12 changes: 12 additions & 0 deletions GNNlib/test/shared_testsetup.jl
@@ -0,0 +1,12 @@
@testsetup module SharedTestSetup

import Reexport: @reexport

@reexport using GNNlib
@reexport using GNNGraphs
@reexport using NNlib
@reexport using MLUtils
@reexport using SparseArrays
@reexport using Test, Random, Statistics

end
68 changes: 68 additions & 0 deletions GNNlib/test/utils_tests.jl
@@ -0,0 +1,68 @@
@testitem "utils" setup=[SharedTestSetup] begin
# TODO test all graph types
GRAPH_T = :coo
De, Dx = 3, 2
g = MLUtils.batch([rand_graph(10, 60, bidirected=true,
ndata = rand(Dx, 10),
edata = rand(De, 30),
graph_type = GRAPH_T) for i in 1:5])
x = g.ndata.x
e = g.edata.e

@testset "reduce_nodes" begin
r = reduce_nodes(mean, g, x)
@test size(r) == (Dx, g.num_graphs)
@test r[:, 2] mean(getgraph(g, 2).ndata.x, dims = 2)

r2 = reduce_nodes(mean, graph_indicator(g), x)
@test r2 == r
end

@testset "reduce_edges" begin
r = reduce_edges(mean, g, e)
@test size(r) == (De, g.num_graphs)
@test r[:, 2] mean(getgraph(g, 2).edata.e, dims = 2)
end

@testset "softmax_nodes" begin
r = softmax_nodes(g, x)
@test size(r) == size(x)
@test r[:, 1:10] softmax(getgraph(g, 1).ndata.x, dims = 2)
end

@testset "softmax_edges" begin
r = softmax_edges(g, e)
@test size(r) == size(e)
@test r[:, 1:60] softmax(getgraph(g, 1).edata.e, dims = 2)
end

@testset "broadcast_nodes" begin
z = rand(4, g.num_graphs)
r = broadcast_nodes(g, z)
@test size(r) == (4, g.num_nodes)
@test r[:, 1] z[:, 1]
@test r[:, 10] z[:, 1]
@test r[:, 11] z[:, 2]
end

@testset "broadcast_edges" begin
z = rand(4, g.num_graphs)
r = broadcast_edges(g, z)
@test size(r) == (4, g.num_edges)
@test r[:, 1] z[:, 1]
@test r[:, 60] z[:, 1]
@test r[:, 61] z[:, 2]
end

@testset "softmax_edge_neighbors" begin
s = [1, 2, 3, 4]
t = [5, 5, 6, 6]
g2 = GNNGraph(s, t)
e2 = randn(Float32, 3, g2.num_edges)
z = softmax_edge_neighbors(g2, e2)
@test size(z) == size(e2)
@test z[:, 1:2] NNlib.softmax(e2[:, 1:2], dims = 2)
@test z[:, 3:4] NNlib.softmax(e2[:, 3:4], dims = 2)
end
end
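For reference, the `graph_indicator`-based form of `reduce_nodes` checked above groups feature columns by graph id. A tiny worked example with made-up data (assumes `GNNlib` and `Statistics` are loaded, as in the shared test setup):

gi = [1, 1, 2]                  # nodes 1-2 belong to graph 1, node 3 to graph 2
x  = Float32[1 3 5;
             2 4 6]
r  = reduce_nodes(mean, gi, x)  # one column per graph
@assert r ≈ Float32[2 5; 3 6]   # columns 1-2 averaged; column 3 unchanged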

2 changes: 0 additions & 2 deletions Project.toml
@@ -5,7 +5,6 @@ version = "0.6.20"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
@@ -27,7 +26,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 [compat]
 CUDA = "4, 5"
 ChainRulesCore = "1"
-DataStructures = "0.18"
 Flux = "0.14"
 Functors = "0.4.1"
 GNNGraphs = "1.0"
1 change: 0 additions & 1 deletion src/GraphNeuralNetworks.jl
@@ -9,7 +9,6 @@ using NNlib
 using NNlib: scatter, gather
 using ChainRulesCore
 using Reexport
-using DataStructures: nlargest
 using MLUtils: zeros_like
 
 using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
1 change: 1 addition & 0 deletions src/deprecations.jl
@@ -2,3 +2,4 @@
 # V1.0 deprecations
 # TODO for some reason this is not working
 # @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight)
+# @deprecate (l::GNNLayer)(gs::AbstractVector{<:GNNGraph}, args...; kws...) l(MLUtils.batch(gs), args...; kws...)
6 changes: 0 additions & 6 deletions src/layers/basic.jl
@@ -11,12 +11,6 @@ abstract type GNNLayer end
 # To be specialized by layers also needing edge features as input (e.g. NNConv).
 (l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
 
-function (l::GNNLayer)(g::AbstractVector{<:GNNGraph}, args...; kws...)
-    @warn "Passing an array of graphs to a `GNNLayer` is discouraged.
-    Explicitly call `Flux.batch(graphs)` first instead." maxlog=1
-    return l(batch(g), args...; kws...)
-end
-
 """
     WithGraph(model, g::GNNGraph; traingraph=false)
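With the vector-of-graphs convenience method deleted (its intended replacement deprecation is stubbed out in src/deprecations.jl above), callers now batch explicitly. A migration sketch, assuming a layer `l`, a vector of graphs `gs`, and features `x`:

using MLUtils: batch

# before: l(gs, x)  relied on the removed method, which batched internally
# after:  batch first, then call the layer on the single batched graph
y = l(batch(gs), x)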