Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Aug 6, 2024
1 parent 3daf120 commit 1f80cb6
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion GNNGraphs/src/abstracttypes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: AbstractVector}
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}}
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
const ADJMAT_T = AbstractMatrix
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/src/generate.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...)
rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...)
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges.
Expand All @@ -10,7 +10,7 @@ In any case, the output graph will contain no self-loops or multi-edges.
A vector can be passed as `edge_weight`. Its length has to be equal to `m`
in the directed case, and `m÷2` in the bidirected one.
Use a `seed > 0` for reproducibility.
Pass a random number generator as the first argument to make the generation reproducible.
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
Expand Down
1 change: 0 additions & 1 deletion GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,6 @@ end
Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`.
The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph.
These new edges are added without creating self-loops.
Optionally, a random `seed` can be provided to ensure reproducible perturbations.
The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges.
The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges.
Expand Down
12 changes: 8 additions & 4 deletions GNNGraphs/test/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@
@test g.edata.e[:, (m2 + 1):end] == e
end

g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T)
rng = MersenneTwister(17)
g = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T)
@test g.num_nodes == n
@test g.num_edges == m

g2 = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T)
rng = MersenneTwister(17)
g2 = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T)
@test edge_index(g2) == edge_index(g)

ew = rand(m2)
g = rand_graph(n, m, bidirected = true, seed = 17, graph_type = GRAPH_T, edge_weight = ew)
rng = MersenneTwister(17)
g = rand_graph(rng, n, m, bidirected = true, graph_type = GRAPH_T, edge_weight = ew)
@test get_edge_weight(g) == [ew; ew] broken=(GRAPH_T != :coo)

ew = rand(m)
g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T, edge_weight = ew)
rng = MersenneTwister(17)
g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T, edge_weight = ew)
@test get_edge_weight(g) == ew broken=(GRAPH_T != :coo)
end

Expand Down
3 changes: 2 additions & 1 deletion GNNGraphs/test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@
end

@testset "color_refinment" begin
g = rand_graph(10, 20, seed=17, graph_type = GRAPH_T)
rng = MersenneTwister(17)
g = rand_graph(rng, 10, 20, graph_type = GRAPH_T)
x0 = ones(Int, 10)
x, ncolors, niters = color_refinement(g, x0)
@test ncolors == 8
Expand Down

0 comments on commit 1f80cb6

Please sign in to comment.