Commit 70674a2
added tests
rbSparky committed Aug 19, 2024
1 parent 2585cb6 commit 70674a2
Showing 4 changed files with 115 additions and 1 deletion.
14 changes: 14 additions & 0 deletions GNNLux/Project.toml
@@ -4,15 +4,29 @@ authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ConcreteStructs = "0.2.3"
1 change: 0 additions & 1 deletion GNNLux/src/layers/conv.jl
@@ -659,7 +659,6 @@ end
function (l::NNConv)(g, x, edge_weight, ps, st)
    nn = StatefulLuxLayer{true}(l.nn, ps, st)

-   # what would be the order of args here?
    m = (; nn, l.aggr, ps.weight, bias = _getbias(ps),
        l.add_self_loops, l.use_edge_weight, l.σ)
    y = GNNlib.nn_conv(m, g, x, edge_weight)
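
For context on the surrounding pattern (not part of this commit's change): the Lux front end packs the layer's parameters and hyperparameters into the named tuple m and hands it to the shared functional core in GNNlib, so the Flux and Lux back ends reuse the same message-passing implementation. A stripped-down, runnable illustration of that idea; ToyLayer and toy_forward are invented names, not the actual GNNlib API:

# Toy illustration of the "pack a named tuple, call a shared function" pattern.
# Names here are illustrative only; the real core lives in GNNlib.nn_conv.
struct ToyLayer
    weight::Matrix{Float32}
    σ::Function
end

toy_forward(m, x) = m.σ.(m.weight * x)    # shared functional core

l = ToyLayer(randn(Float32, 4, 3), tanh)
m = (; l.weight, l.σ)                     # named tuple mirrors the layer's fields
y = toy_forward(m, randn(Float32, 3, 5))  # 4×5 output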
7 changes: 7 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
@@ -93,4 +93,11 @@
    l = GINConv(nn, 0.5)
    test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true)
end

@testset "NNConv" begin
    edim = 10
    nn = Dense(edim, out_dims * in_dims)
    l = NNConv(in_dims => out_dims, nn, tanh, bias = true, aggr = +)
    test_lux_layer(rng, l, g, x, sizey=(out_dims, g.num_nodes), container=true)
end
end
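
A note on why the test builds Dense(edim, out_dims * in_dims): NNConv uses the inner network to map each edge's feature vector to an out_dims × in_dims weight matrix. A minimal sketch of invoking the layer directly with edge features, following the five-argument method (l::NNConv)(g, x, edge_weight, ps, st) shown in conv.jl above; fixture names g, x, rng, in_dims, out_dims are assumed to match these tests:

# Sketch: calling the Lux NNConv directly with per-edge features e.
edim = 10
e = randn(Float32, edim, g.num_edges)     # one feature vector per edge
nn = Dense(edim, out_dims * in_dims)      # edge features -> flattened weight matrix
l = NNConv(in_dims => out_dims, nn, tanh, bias = true, aggr = +)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
y, st′ = l(g, x, e, ps, st)               # edge features passed as the third argument
@assert size(y) == (out_dims, g.num_nodes)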
94 changes: 94 additions & 0 deletions GNNLux/test/layers/temp.jl
@@ -0,0 +1,94 @@


@testset "GINConv" begin
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end

@testset "SGConv" begin
l = SGConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end



function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
outputsize=nothing, sizey=nothing, container=false,
atol=1.0f-2, rtol=1.0f-2)


ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(l) == LuxCore.statelength(st)

y, st′ = l(g, x, ps, st)
@test eltype(y) == eltype(x)
if outputsize !== nothing
@test LuxCore.outputsize(l) == outputsize
end
if sizey !== nothing
@test size(y) == sizey
elseif outputsize !== nothing
@test size(y) == (outputsize..., g.num_nodes)
end

loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
using StableRNGs

"""
MEGNetConv{Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, typeof(mean)}(Chain(Dense(9 => 5, relu), Dense(5 => 5)), Chain(Dense(8 => 5, relu), Dense(5 => 5)), Statistics.mean)
"""

g = rand_graph(10, 40, seed=1234)
in_dims = 3
out_dims = 5
x = randn(Float32, in_dims, 10)
rng = StableRNG(1234)
l = MEGNetConv(in_dims => out_dims)
l
l isa GNNContainerLayer
test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true)


ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
edata = rand(T, in_channel, g.num_edges)

(x_new, e_new), st_new = l(g, x, ps, st)

@test size(x_new) == (out_dims, g.num_nodes)
@test size(e_new) == (out_dims, g.num_edges)


nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)



hin = 6
hout = 7
hidden = 8
l = EGNNConv(hin => hout, hidden)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
h = randn(rng, Float32, hin, g.num_nodes)
(hnew, xnew), stnew = l(g, h, x, ps, st)
@test size(hnew) == (hout, g.num_nodes)
@test size(xnew) == (in_dims, g.num_nodes)


l = MEGNetConv(in_dims => out_dims)
l
l isa GNNContainerLayer
test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true)


ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

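Once promoted out of this scratch file, the MEGNetConv and EGNNConv checks above would presumably follow the testset style of conv_tests.jl; a sketch under the same fixtures, with testset names chosen here for illustration:

@testset "MEGNetConv" begin
    l = MEGNetConv(in_dims => out_dims)
    @test l isa GNNContainerLayer
    test_lux_layer(rng, l, g, x,
                   sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)),
                   container=true)
end

@testset "EGNNConv" begin
    hin, hout, hidden = 6, 7, 8
    l = EGNNConv(hin => hout, hidden)
    ps = LuxCore.initialparameters(rng, l)
    st = LuxCore.initialstates(rng, l)
    h = randn(rng, Float32, hin, g.num_nodes)
    (hnew, xnew), _ = l(g, h, x, ps, st)
    @test size(hnew) == (hout, g.num_nodes)
    @test size(xnew) == (in_dims, g.num_nodes)
end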