diff --git a/.gitignore b/.gitignore index 7181205b6..91820619c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ Manifest.toml .vscode LocalPreferences.toml .DS_Store -docs/src/democards/gridtheme.css \ No newline at end of file +docs/src/democards/gridtheme.css +test.jl \ No newline at end of file diff --git a/GNNGraphs/src/abstracttypes.jl b/GNNGraphs/src/abstracttypes.jl index b8959b807..73146160f 100644 --- a/GNNGraphs/src/abstracttypes.jl +++ b/GNNGraphs/src/abstracttypes.jl @@ -1,5 +1,5 @@ -const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V} +const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}} const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}} const ADJMAT_T = AbstractMatrix const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T diff --git a/GNNGraphs/src/convert.jl b/GNNGraphs/src/convert.jl index 1e103db8b..3789309cb 100644 --- a/GNNGraphs/src/convert.jl +++ b/GNNGraphs/src/convert.jl @@ -4,27 +4,24 @@ function to_coo(data::EDict; num_nodes = nothing, kws...) graph = EDict{COO_T}() _num_nodes = NDict{Int}() num_edges = EDict{Int}() - if !isempty(data) - for k in keys(data) - d = data[k] - @assert d isa Tuple - if length(d) == 2 - d = (d..., nothing) - end - if num_nodes !== nothing - n1 = get(num_nodes, k[1], nothing) - n2 = get(num_nodes, k[3], nothing) - else - n1 = nothing - n2 = nothing - end - g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) - graph[k] = g - num_edges[k] = nedges - _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) - _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) + for k in keys(data) + d = data[k] + @assert d isa Tuple + if length(d) == 2 + d = (d..., nothing) end - graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types + if num_nodes !== nothing + n1 = get(num_nodes, k[1], nothing) + n2 = get(num_nodes, k[3], nothing) + else + n1 = nothing + n2 = nothing + end + g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) + graph[k] = g + num_edges[k] = nedges + _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) + _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) end return graph, _num_nodes, num_edges end diff --git a/GNNGraphs/src/generate.jl b/GNNGraphs/src/generate.jl index 4e6738279..6005ac023 100644 --- a/GNNGraphs/src/generate.jl +++ b/GNNGraphs/src/generate.jl @@ -1,5 +1,5 @@ """ - rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...) + rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...) Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges. @@ -10,7 +10,7 @@ In any case, the output graph will contain no self-loops or multi-edges. A vector can be passed as `edge_weight`. Its length has to be equal to `m` in the directed case, and `m÷2` in the bidirected one. -Use a `seed > 0` for reproducibility. +Pass a random number generator as the first argument to make the generation reproducible. Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor. @@ -36,25 +36,42 @@ GNNGraph: # Each edge has a reverse julia> edge_index(g) ([1, 3, 3, 4], [3, 4, 1, 3]) - ``` """ -function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_weight = nothing, kws...) +function rand_graph(n::Integer, m::Integer; seed=-1, kws...) + if seed != -1 + Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_graph) + rng = MersenneTwister(seed) + else + rng = Random.default_rng() + end + return rand_graph(rng, n, m; kws...) +end + +function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; + bidirected::Bool = true, + edge_weight::Union{AbstractVector, Nothing} = nothing, kws...) if bidirected - @assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m." + @assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m." + s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false) + s, t = vcat(s, t), vcat(t, s) + if edge_weight !== nothing + edge_weight = vcat(edge_weight, edge_weight) + end + else + s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false) end - m2 = bidirected ? m ÷ 2 : m - return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); edge_weight, kws...) + return GNNGraph((s, t, edge_weight); num_nodes=n, kws...) end """ - rand_heterograph(n, m; seed=-1, bidirected=false, kws...) + rand_heterograph([rng,] n, m; bidirected=false, kws...) -Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges +Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs specifing node/edge types and their numbers. -Use a `seed > 0` for reproducibility. +Pass a random number generator as a first argument to make the generation reproducible. Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge. Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)` @@ -76,9 +93,19 @@ function rand_heterograph end # for generic iterators of pairs rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...) +rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...) -function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws...) - rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG +function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...) + if seed != -1 + Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph) + rng = MersenneTwister(seed) + else + rng = Random.default_rng() + end + return rand_heterograph(rng, n, m; kws...) +end + +function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...) if bidirected return _rand_bidirected_heterograph(rng, n, m; kws...) end @@ -86,7 +113,7 @@ function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws return GNNHeteroGraph(graphs; num_nodes = n, kws...) end -function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...) +function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...) for k in keys(m) if reverse(k) ∈ keys(m) @assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs." @@ -104,43 +131,60 @@ function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...) return GNNHeteroGraph(graphs; num_nodes = n, kws...) end -function _rand_edges(rng, (n1, n2), m) - idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false) - s, t = edge_decoding(idx, n1, n2) - val = nothing - return s, t, val -end """ - rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...]) - rand_bipartite_heterograph((n1, n2), m; ...) - rand_bipartite_heterograph((n1, n2), (m1, m2); ...) + rand_bipartite_heterograph([rng,] + (n1, n2), (m12, m21); + bidirected = true, + node_t = (:A, :B), + edge_t = :to, + kws...) -Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges -specified by `n1`, `n2` and `m1` and `m2` respectively. +Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph. +The graph will have two types of nodes, and edges will only connect nodes of different types. -See [`rand_heterograph`](@ref) for a more general version. +The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type. +The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2` +and vice versa. -# Keyword arguments +The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments, +which default to `(:A, :B)` and `:to` respectively. -- `bidirected`: whether to generate a bidirected graph. Default is `true`. -- `seed`: random seed. Default is `-1` (no seed). -- `node_t`: node types. If `bipartite=true`, this should be a tuple of two node types, otherwise it should be a single node type. -- `edge_t`: edge types. If `bipartite=true`, this should be a tuple of two edge types, otherwise it should be a single edge type. -""" -function rand_bipartite_heterograph end +If `bidirected=true` (default), the reverse edge of each edge will be present. In this case +`m12 == m21` is required. + +A random number generator can be passed as the first argument to make the generation reproducible. + +Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor. + +See [`rand_heterograph`](@ref) for a more general version. + +# Examples -rand_bipartite_heterograph(n1::Int, n2::Int, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...) +```julia-repl +julia> g = rand_bipartite_heterograph((10, 15), 20) +GNNHeteroGraph: + num_nodes: (:A => 10, :B => 15) + num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20) -rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...) +julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false) +GNNHeteroGraph: + num_nodes: Dict(:item => 15, :user => 10) + num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20) +``` +""" +rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...) -function rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, (m1, m2)::NTuple{2,Int}; bidirected=true, - node_t = (:A, :B), edge_t = :to, kws...) - if edge_t isa Symbol - edge_t = (edge_t, edge_t) +function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true, + node_t = (:A, :B), edge_t::Symbol = :to, kws...) + if m isa Integer + m12 = m21 = m + else + m12, m21 = m end - return rand_heterograph(Dict(node_t[1] => n1, node_t[2] => n2), - Dict((node_t[1], edge_t[1], node_t[2]) => m1, (node_t[2], edge_t[2], node_t[1]) => m2); + + return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2), + Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21); bidirected, kws...) end diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 8e8c98d13..325a20f5c 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -57,7 +57,8 @@ then all new self loops will have no weight. If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same. This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type. """ -function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V} +function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) get(g.graph, edge_t, (nothing, nothing, nothing))[3] end @@ -69,13 +70,17 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where n = get(g.num_nodes, src_t, 0) if haskey(g.graph, edge_t) - x = g.graph[edge_t] - s, t = x[1:2] + s, t = g.graph[edge_t][1:2] nodes = convert(typeof(s), [1:n;]) s = [s; nodes] t = [t; nodes] else - nodes = convert(T, [1:n;]) + if !isempty(g.graph) + T = typeof(first(values(g.graph))[1]) + nodes = convert(T, [1:n;]) + else + nodes = [1:n;] + end s = nodes t = nodes end @@ -518,7 +523,6 @@ end Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`. The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. These new edges are added without creating self-loops. -Optionally, a random `seed` can be provided to ensure reproducible perturbations. The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges. The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges. diff --git a/GNNGraphs/src/utils.jl b/GNNGraphs/src/utils.jl index 4bba304ef..7cdc3e543 100644 --- a/GNNGraphs/src/utils.jl +++ b/GNNGraphs/src/utils.jl @@ -205,17 +205,13 @@ end numnonzeros(a::AbstractSparseMatrix) = nnz(a) numnonzeros(a::AbstractMatrix) = count(!=(0), a) -# each edge is represented by a number in -# 1:N^2 -function edge_encoding(s, t, n; directed = true) - if directed - # directed edges and self-loops allowed - idx = (s .- 1) .* n .+ t +## Map edges into a contiguous range of integers +function edge_encoding(s, t, n; directed = true, self_loops = true) + if directed && self_loops maxid = n^2 - else - # Undirected edges and self-loops allowed + idx = (s .- 1) .* n .+ t + elseif !directed && self_loops maxid = n * (n + 1) ÷ 2 - mask = s .> t snew = copy(s) tnew = copy(t) @@ -228,18 +224,34 @@ function edge_encoding(s, t, n; directed = true) # = ∑_{i',i' s) + elseif !directed && !self_loops + @assert all(s .!= t) + maxid = n * (n - 1) ÷ 2 + mask = s .> t + snew = copy(s) + tnew = copy(t) + snew[mask] .= t[mask] + tnew[mask] .= s[mask] + s, t = snew, tnew + + # idx(s,t) = ∑_{s',1<= s'= s) + elseif !directed && !self_loops + # Considering t = s + 1 in + # idx = @. (s - 1) * n - s * (s - 1) ÷ 2 + (t - s) + # and inverting for s we have + s = @. floor(Int, 1/2 + n - 1/2 * sqrt(9 - 4n + 4n^2 - 8*idx)) + # now we can find t + t = @. idx - (s - 1) * n + s * (s - 1) ÷ 2 + s end return s, t end -# each edge is represented by a number in -# 1:n1*n2 +# for bipartite graphs function edge_decoding(idx, n1, n2) @assert all(1 .<= idx .<= n1 * n2) s = (idx .- 1) .÷ n2 .+ 1 @@ -265,6 +287,29 @@ function edge_decoding(idx, n1, n2) return s, t end +function _rand_edges(rng, n::Int, m::Int; directed = true, self_loops = true) + idmax = if directed && self_loops + n^2 + elseif !directed && self_loops + n * (n + 1) ÷ 2 + elseif directed && !self_loops + n * (n - 1) + elseif !directed && !self_loops + n * (n - 1) ÷ 2 + end + idx = StatsBase.sample(rng, 1:idmax, m, replace = false) + s, t = edge_decoding(idx, n; directed, self_loops) + val = nothing + return s, t, val +end + +function _rand_edges(rng, (n1, n2), m) + idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false) + s, t = edge_decoding(idx, n1, n2) + val = nothing + return s, t, val +end + binarize(x) = map(>(0), x) @non_differentiable binarize(x...) diff --git a/GNNGraphs/test/generate.jl b/GNNGraphs/test/generate.jl index 867fec399..c26b651c3 100644 --- a/GNNGraphs/test/generate.jl +++ b/GNNGraphs/test/generate.jl @@ -16,19 +16,23 @@ @test g.edata.e[:, (m2 + 1):end] == e end - g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T) + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) @test g.num_nodes == n @test g.num_edges == m - g2 = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T) + rng = MersenneTwister(17) + g2 = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) @test edge_index(g2) == edge_index(g) ew = rand(m2) - g = rand_graph(n, m, bidirected = true, seed = 17, graph_type = GRAPH_T, edge_weight = ew) + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = true, graph_type = GRAPH_T, edge_weight = ew) @test get_edge_weight(g) == [ew; ew] broken=(GRAPH_T != :coo) ew = rand(m) - g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T, edge_weight = ew) + rng = MersenneTwister(17) + g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T, edge_weight = ew) @test get_edge_weight(g) == ew broken=(GRAPH_T != :coo) end @@ -77,7 +81,7 @@ end end @testset "rand_bipartite_heterograph" begin - g = rand_bipartite_heterograph(10, 15, 20) + g = rand_bipartite_heterograph((10, 15), (20, 20)) @test g.num_nodes == Dict(:A => 10, :B => 15) @test g.num_edges == Dict((:A, :to, :B) => 20, (:B, :to, :A) => 20) sA, tB = edge_index(g, (:A, :to, :B)) diff --git a/GNNGraphs/test/gnnheterograph.jl b/GNNGraphs/test/gnnheterograph.jl index 6764b7814..f3c29b80f 100644 --- a/GNNGraphs/test/gnnheterograph.jl +++ b/GNNGraphs/test/gnnheterograph.jl @@ -123,7 +123,7 @@ end @testset "get/set node features" begin d, n = 3, 5 - g = rand_bipartite_heterograph(n, 2*n, 15) + g = rand_bipartite_heterograph((n, 2*n), 15) g[:A].x = rand(Float32, d, n) g[:B].y = rand(Float32, d, 2*n) @@ -133,7 +133,7 @@ end @testset "add_edges" begin d, n = 3, 5 - g = rand_bipartite_heterograph(n, 2 * n, 15) + g = rand_bipartite_heterograph((n, 2 * n), 15) s, t = [1, 2, 3], [3, 2, 1] ## Keep the same ntypes - construct with args g1 = add_edges(g, (:A, :rel1, :B), s, t) diff --git a/GNNGraphs/test/utils.jl b/GNNGraphs/test/utils.jl index db65b6357..31a1c7373 100644 --- a/GNNGraphs/test/utils.jl +++ b/GNNGraphs/test/utils.jl @@ -47,10 +47,55 @@ tnew[mask] .= s1[mask] @test sdec == snew @test tdec == tnew + + @testset "directed=false, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + g = GNNGraph(s, t) + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) + @test idxmax == n * (n - 1) ÷ 2 + @test idx == 1:idxmax + + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) + @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] + @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] + end + + @testset "directed=false, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=false, self_loops=false) + @test idxmax == n * (n - 1) ÷ 2 + @test idx == 1:idxmax + + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=false, self_loops=false) + @test snew == [1, 1, 1, 1, 2, 2, 2, 3, 3, 4] + @test tnew == [2, 3, 4, 5, 3, 4, 5, 4, 5, 5] + end + + @testset "directed=true, self_loops=false" begin + n = 5 + edges = [(1,2), (3,1), (1,4), (1,5), (2,3), (2,4), (2,5), (3,4), (3,5), (4,5)] + s = [e[1] for e in edges] + t = [e[2] for e in edges] + + idx, idxmax = GNNGraphs.edge_encoding(s, t, n, directed=true, self_loops=false) + @test idxmax == n^2 - n + @test idx == [1, 9, 3, 4, 6, 7, 8, 11, 12, 16] + snew, tnew = GNNGraphs.edge_decoding(idx, n, directed=true, self_loops=false) + @test snew == s + @test tnew == t + end end -@testset "color_refinment" begin - g = rand_graph(10, 20, seed=17, graph_type = GRAPH_T) +@testset "color_refinement" begin + rng = MersenneTwister(17) + g = rand_graph(rng, 10, 20, graph_type = GRAPH_T) x0 = ones(Int, 10) x, ncolors, niters = color_refinement(g, x0) @test ncolors == 8 @@ -59,4 +104,4 @@ end x2, _, _ = color_refinement(g) @test x2 == x -end \ No newline at end of file +end \ No newline at end of file diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl index 9f59f3b10..ac937d128 100644 --- a/GNNLux/test/layers/basic_tests.jl +++ b/GNNLux/test/layers/basic_tests.jl @@ -1,6 +1,6 @@ @testitem "layers/basic" setup=[SharedTestSetup] begin rng = StableRNG(17) - g = rand_graph(10, 40, seed=17) + g = rand_graph(rng, 10, 40) x = randn(rng, Float32, 3, 10) @testset "GNNLayer" begin diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 9f010f39e..ab06c9445 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,12 +1,12 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) - g = rand_graph(10, 40, seed=1234) + g = rand_graph(rng, 10, 40) in_dims = 3 out_dims = 5 x = randn(rng, Float32, in_dims, 10) @testset "GCNConv" begin - l = GCNConv(in_dims => out_dims, relu) + l = GCNConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @@ -16,7 +16,7 @@ end @testset "GraphConv" begin - l = GraphConv(in_dims => out_dims, relu) + l = GraphConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end @@ -26,7 +26,7 @@ end @testset "EdgeConv" begin - nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims)) + nn = Chain(Dense(2*in_dims => 2, tanh), Dense(2 => out_dims)) l = EdgeConv(nn, aggr = +) test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true) end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 9a3b6ee9f..2428865ae 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -21,7 +21,7 @@ @testset "constructor with names" begin m = GNNChain(GCNConv(din => d), LayerNorm(d), - x -> relu.(x), + x -> tanh.(x), Dense(d, dout)) m2 = GNNChain(enc = m, @@ -34,7 +34,7 @@ @testset "constructor with vector" begin m = GNNChain(GCNConv(din => d), LayerNorm(d), - x -> relu.(x), + x -> tanh.(x), Dense(d, dout)) m2 = GNNChain([m.layers...]) @test m2(g, x) == m(g, x) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 4c4827d2e..b96baa880 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -100,7 +100,7 @@ end test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) end - l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean) + l = GraphConv(in_channel => out_channel, tanh, bias = false, aggr = mean) for g in test_graphs test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 29f36ba63..d9eaf0c7f 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -1,6 +1,6 @@ @testset "HeteroGraphConv" begin d, n = 3, 5 - g = rand_bipartite_heterograph(n, 2*n, 15) + g = rand_bipartite_heterograph((n, 2*n), 15) hg = rand_bipartite_heterograph((2,3), 6) model = HeteroGraphConv([(:A,:to,:B) => GraphConv(d => d), @@ -30,8 +30,8 @@ end @testset "Constructor from pairs" begin - layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, relu), - (:B, :to, :A) => GraphConv(64 => 32, relu)); + layer = HeteroGraphConv((:A, :to, :B) => GraphConv(64 => 32, tanh), + (:B, :to, :A) => GraphConv(64 => 32, tanh)); @test length(layer.etypes) == 2 end @@ -95,8 +95,8 @@ @testset "CGConv" begin x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, relu), - (:B, :to, :A) => CGConv(4 => 2, relu)); + layers = HeteroGraphConv( (:A, :to, :B) => CGConv(4 => 2, tanh), + (:B, :to, :A) => CGConv(4 => 2, tanh)); y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end @@ -111,8 +111,8 @@ @testset "SAGEConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +), - (:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +)); + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, tanh, bias = false, aggr = +), + (:B, :to, :A) => SAGEConv(4 => 2, tanh, bias = false, aggr = +)); y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end @@ -152,8 +152,8 @@ @testset "GCNConv" begin g = rand_bipartite_heterograph((2,3), 6) x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, relu), - (:B, :to, :A) => GCNConv(4 => 2, relu)); + layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), + (:B, :to, :A) => GCNConv(4 => 2, tanh)); y = layers(g, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl index b55aff808..45c8acf04 100644 --- a/test/layers/temporalconv.jl +++ b/test/layers/temporalconv.jl @@ -133,7 +133,7 @@ end end @testset "ResGatedGraphConv" begin - resgatedconv = ResGatedGraphConv(in_channel => out_channel, relu) + resgatedconv = ResGatedGraphConv(in_channel => out_channel, tanh) @test length(resgatedconv(tg, tg.ndata.x)) == S @test size(resgatedconv(tg, tg.ndata.x)[1]) == (out_channel, N) @test length(Flux.gradient(x ->sum(sum(resgatedconv(tg, x))), tg.ndata.x)[1]) == S @@ -147,7 +147,7 @@ end end @testset "GraphConv" begin - graphconv = GraphConv(in_channel => out_channel,relu) + graphconv = GraphConv(in_channel => out_channel, tanh) @test length(graphconv(tg, tg.ndata.x)) == S @test size(graphconv(tg, tg.ndata.x)[1]) == (out_channel, N) @test length(Flux.gradient(x ->sum(sum(graphconv(tg, x))), tg.ndata.x)[1]) == S diff --git a/test/runtests.jl b/test/runtests.jl index e41c7c1ae..b32f8541c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,9 @@ tests = [ !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") # @testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse) -for graph_type in (:coo, :dense, :sparse) +# for graph_type in (:coo, :dense, :sparse) +for graph_type in (:dense,) + @info "Testing graph format :$graph_type" global GRAPH_T = graph_type global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse)