diff --git a/GNNGraphs/src/abstracttypes.jl b/GNNGraphs/src/abstracttypes.jl index 1149157f2..73146160f 100644 --- a/GNNGraphs/src/abstracttypes.jl +++ b/GNNGraphs/src/abstracttypes.jl @@ -1,5 +1,5 @@ -const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: AbstractVector} +const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}} const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}} const ADJMAT_T = AbstractMatrix const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T diff --git a/GNNGraphs/src/generate.jl b/GNNGraphs/src/generate.jl index 0b595ce33..872150103 100644 --- a/GNNGraphs/src/generate.jl +++ b/GNNGraphs/src/generate.jl @@ -1,5 +1,5 @@ """ - rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...) + rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...) Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges. @@ -10,7 +10,7 @@ In any case, the output graph will contain no self-loops or multi-edges. A vector can be passed as `edge_weight`. Its length has to be equal to `m` in the directed case, and `m÷2` in the bidirected one. -Use a `seed > 0` for reproducibility. +Pass a random number generator as the first argument to make the generation reproducible. Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor. diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index 8e8c98d13..3710c545f 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -518,7 +518,6 @@ end Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`. The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph. These new edges are added without creating self-loops. -Optionally, a random `seed` can be provided to ensure reproducible perturbations. The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges. The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges. diff --git a/GNNGraphs/test/generate.jl b/GNNGraphs/test/generate.jl index 867fec399..263afb2e9 100644 --- a/GNNGraphs/test/generate.jl +++ b/GNNGraphs/test/generate.jl @@ -16,19 +16,23 @@ @test g.edata.e[:, (m2 + 1):end] == e end - g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T) + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) @test g.num_nodes == n @test g.num_edges == m - g2 = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T) + rng = MersenneTwister(17) + g2 = rand_graph(rng, n, m, bidirected = false, graph_type = GRAPH_T) @test edge_index(g2) == edge_index(g) ew = rand(m2) - g = rand_graph(n, m, bidirected = true, seed = 17, graph_type = GRAPH_T, edge_weight = ew) + rng = MersenneTwister(17) + g = rand_graph(rng, n, m, bidirected = true, graph_type = GRAPH_T, edge_weight = ew) @test get_edge_weight(g) == [ew; ew] broken=(GRAPH_T != :coo) ew = rand(m) - g = rand_graph(n, m, bidirected = false, seed = 17, graph_type = GRAPH_T, edge_weight = ew) + rng = MersenneTwister(17) + g = rand_graph(n, m, bidirected = false, graph_type = GRAPH_T, edge_weight = ew) @test get_edge_weight(g) == ew broken=(GRAPH_T != :coo) end diff --git a/GNNGraphs/test/utils.jl b/GNNGraphs/test/utils.jl index ca0c25b17..0bc27d5ad 100644 --- a/GNNGraphs/test/utils.jl +++ b/GNNGraphs/test/utils.jl @@ -94,7 +94,8 @@ end @testset "color_refinment" begin - g = rand_graph(10, 20, seed=17, graph_type = GRAPH_T) + rng = MersenneTwister(17) + g = rand_graph(rng, 10, 20, graph_type = GRAPH_T) x0 = ones(Int, 10) x, ncolors, niters = color_refinement(g, x0) @test ncolors == 8