From 87c062b2bccf0ebf4348c3480bb2c88b5eddc348 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 7 Aug 2024 00:27:43 +0200 Subject: [PATCH] fix tests --- GNNGraphs/src/convert.jl | 37 +++++++++++++++++-------------------- GNNGraphs/src/generate.jl | 2 +- GNNGraphs/src/transform.jl | 13 +++++++++---- GNNGraphs/test/utils.jl | 4 ++-- 4 files changed, 29 insertions(+), 27 deletions(-) diff --git a/GNNGraphs/src/convert.jl b/GNNGraphs/src/convert.jl index 1e103db8b..3789309cb 100644 --- a/GNNGraphs/src/convert.jl +++ b/GNNGraphs/src/convert.jl @@ -4,27 +4,24 @@ function to_coo(data::EDict; num_nodes = nothing, kws...) graph = EDict{COO_T}() _num_nodes = NDict{Int}() num_edges = EDict{Int}() - if !isempty(data) - for k in keys(data) - d = data[k] - @assert d isa Tuple - if length(d) == 2 - d = (d..., nothing) - end - if num_nodes !== nothing - n1 = get(num_nodes, k[1], nothing) - n2 = get(num_nodes, k[3], nothing) - else - n1 = nothing - n2 = nothing - end - g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) - graph[k] = g - num_edges[k] = nedges - _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) - _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) + for k in keys(data) + d = data[k] + @assert d isa Tuple + if length(d) == 2 + d = (d..., nothing) end - graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types + if num_nodes !== nothing + n1 = get(num_nodes, k[1], nothing) + n2 = get(num_nodes, k[3], nothing) + else + n1 = nothing + n2 = nothing + end + g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) + graph[k] = g + num_edges[k] = nedges + _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) + _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) end return graph, _num_nodes, num_edges end diff --git a/GNNGraphs/src/generate.jl b/GNNGraphs/src/generate.jl index 872150103..8ee24a2cd 100644 --- a/GNNGraphs/src/generate.jl +++ b/GNNGraphs/src/generate.jl @@ -59,7 +59,7 @@ function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; bidirected = true, else s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false) end - return GNNGraph((s, t, edge_weight); kws...) + return GNNGraph((s, t, edge_weight); num_nodes=n, kws...) end """ diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 3710c545f..325a20f5c 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -57,7 +57,8 @@ then all new self loops will have no weight. If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same. This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type. """ -function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V} +function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) + function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) get(g.graph, edge_t, (nothing, nothing, nothing))[3] end @@ -69,13 +70,17 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where n = get(g.num_nodes, src_t, 0) if haskey(g.graph, edge_t) - x = g.graph[edge_t] - s, t = x[1:2] + s, t = g.graph[edge_t][1:2] nodes = convert(typeof(s), [1:n;]) s = [s; nodes] t = [t; nodes] else - nodes = convert(T, [1:n;]) + if !isempty(g.graph) + T = typeof(first(values(g.graph))[1]) + nodes = convert(T, [1:n;]) + else + nodes = [1:n;] + end s = nodes t = nodes end diff --git a/GNNGraphs/test/utils.jl b/GNNGraphs/test/utils.jl index 0bc27d5ad..31a1c7373 100644 --- a/GNNGraphs/test/utils.jl +++ b/GNNGraphs/test/utils.jl @@ -93,7 +93,7 @@ end end -@testset "color_refinment" begin +@testset "color_refinement" begin rng = MersenneTwister(17) g = rand_graph(rng, 10, 20, graph_type = GRAPH_T) x0 = ones(Int, 10) @@ -104,4 +104,4 @@ end x2, _, _ = color_refinement(g) @test x2 == x -end \ No newline at end of file +end \ No newline at end of file