Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Aug 6, 2024
1 parent 1f80cb6 commit 87c062b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 27 deletions.
37 changes: 17 additions & 20 deletions GNNGraphs/src/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,24 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
graph = EDict{COO_T}()
_num_nodes = NDict{Int}()
num_edges = EDict{Int}()
if !isempty(data)
for k in keys(data)
d = data[k]
@assert d isa Tuple
if length(d) == 2
d = (d..., nothing)
end
if num_nodes !== nothing
n1 = get(num_nodes, k[1], nothing)
n2 = get(num_nodes, k[3], nothing)
else
n1 = nothing
n2 = nothing
end
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
graph[k] = g
num_edges[k] = nedges
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
for k in keys(data)
d = data[k]
@assert d isa Tuple
if length(d) == 2
d = (d..., nothing)
end
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
if num_nodes !== nothing
n1 = get(num_nodes, k[1], nothing)
n2 = get(num_nodes, k[3], nothing)
else
n1 = nothing
n2 = nothing
end
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
graph[k] = g
num_edges[k] = nedges
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
end
return graph, _num_nodes, num_edges
end
Expand Down
2 changes: 1 addition & 1 deletion GNNGraphs/src/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function rand_graph(rng::AbstractRNG, n::Integer, m::Integer; bidirected = true,
else
s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false)
end
return GNNGraph((s, t, edge_weight); kws...)
return GNNGraph((s, t, edge_weight); num_nodes=n, kws...)
end

"""
Expand Down
13 changes: 9 additions & 4 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ then all new self loops will have no weight.
If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same.
This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type.
"""
function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)

function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
get(g.graph, edge_t, (nothing, nothing, nothing))[3]
end
Expand All @@ -69,13 +70,17 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where
n = get(g.num_nodes, src_t, 0)

if haskey(g.graph, edge_t)
x = g.graph[edge_t]
s, t = x[1:2]
s, t = g.graph[edge_t][1:2]
nodes = convert(typeof(s), [1:n;])
s = [s; nodes]
t = [t; nodes]
else
nodes = convert(T, [1:n;])
if !isempty(g.graph)
T = typeof(first(values(g.graph))[1])
nodes = convert(T, [1:n;])
else
nodes = [1:n;]
end
s = nodes
t = nodes
end
Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
end
end

@testset "color_refinment" begin
@testset "color_refinement" begin
rng = MersenneTwister(17)
g = rand_graph(rng, 10, 20, graph_type = GRAPH_T)
x0 = ones(Int, 10)
Expand All @@ -104,4 +104,4 @@ end

x2, _, _ = color_refinement(g)
@test x2 == x
end
end

0 comments on commit 87c062b

Please sign in to comment.