Skip to content

Commit

Permalink
temporary changes to run tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Aug 7, 2024
1 parent d5cfb7b commit 3296f2e
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 11 deletions.
4 changes: 2 additions & 2 deletions GNNGraphs/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const ACUMatrix{T} = Union{CuMatrix{T}, CUDA.CUSPARSE.CuSparseMatrix{T}}
ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets

include("test_utils.jl")

"""
tests = [
"chainrules",
"datastore",
Expand All @@ -39,7 +39,7 @@ tests = [
"mldatasets",
"ext/SimpleWeightedGraphs"
]

"""
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")

for graph_type in (:coo, :dense, :sparse)
Expand Down
2 changes: 1 addition & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export AGNNConv,
GINConv,
# GMMConv,
GraphConv,
#MEGNetConv,
MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
Expand Down
13 changes: 8 additions & 5 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,18 +650,21 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
return MEGNetConv(nin, nout, ϕe, ϕv; aggr)
end

LuxCore.outputsize(l::MegNetConv) = (l.num_features.out,)

function (l::MegNetConv)(g, x, e, ps, st)
function (l::MEGNetConv)(g, x, e, ps, st)
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv))
m = (; ϕe, ϕv, l.residual, l.num_features)
return GNNlib.megnet_conv(m, g, x, e), st
end

function Base.show(io::IO, l::MegNetConv)

LuxCore.outputsize(l::MEGNetConv) = (l.out_dims,)

(l::MEGNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function Base.show(io::IO, l::MEGNetConv)
nin = l.in_dims
nout = l.out_dims
print(io, "MegNetConv(", nin, " => ", nout)
print(io, "MEGNetConv(", nin, " => ", nout)
print(io, ")")
end
20 changes: 18 additions & 2 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
in_dims = 3
out_dims = 5
x = randn(rng, Float32, in_dims, 10)

"""
@testset "GCNConv" begin
l = GCNConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
Expand Down Expand Up @@ -53,7 +53,22 @@
@test size(hnew) == (hout, g.num_nodes)
@test size(xnew) == (in_dims, g.num_nodes)
end

"""
@testset "MEGNetConv" begin
in_dims = 6
out_dims = 8

l = MEGNetConv(in_dims => out_dims)

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

(x_new, e_new), st_new = l(g, x, ps, st)

@test size(x_new) == (out_dims, g.num_nodes)
@test size(e_new) == (out_dims, g.num_edges)
end
"""
@testset "GATConv" begin
x = randn(rng, Float32, 6, 10)
Expand Down Expand Up @@ -93,4 +108,5 @@
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end
"""
end
2 changes: 1 addition & 1 deletion GNNlib/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ using Test
using ReTestItems
using Random, Statistics

runtests(GNNlib)
#runtests(GNNlib)

0 comments on commit 3296f2e

Please sign in to comment.