Skip to content

Commit

Permalink
rng instead of seed for rand_graph (#482)
Browse files Browse the repository at this point in the history
* rng instead of seed for rand_graph

* add tests

* fix tests

* rand_bipartite

* more

* relu -> tanh in tests
  • Loading branch information
CarloLucibello authored Aug 7, 2024
1 parent 3ce025b commit 6b58b75
Show file tree
Hide file tree
Showing 16 changed files with 256 additions and 114 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ Manifest.toml
.vscode
LocalPreferences.toml
.DS_Store
docs/src/democards/gridtheme.css
docs/src/democards/gridtheme.css
test.jl
2 changes: 1 addition & 1 deletion GNNGraphs/src/abstracttypes.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V <: Union{Nothing, AbstractVector}}
const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}}
const ADJMAT_T = AbstractMatrix
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
Expand Down
37 changes: 17 additions & 20 deletions GNNGraphs/src/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,24 @@ function to_coo(data::EDict; num_nodes = nothing, kws...)
graph = EDict{COO_T}()
_num_nodes = NDict{Int}()
num_edges = EDict{Int}()
if !isempty(data)
for k in keys(data)
d = data[k]
@assert d isa Tuple
if length(d) == 2
d = (d..., nothing)
end
if num_nodes !== nothing
n1 = get(num_nodes, k[1], nothing)
n2 = get(num_nodes, k[3], nothing)
else
n1 = nothing
n2 = nothing
end
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
graph[k] = g
num_edges[k] = nedges
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
for k in keys(data)
d = data[k]
@assert d isa Tuple
if length(d) == 2
d = (d..., nothing)
end
graph = Dict([k => v for (k, v) in pairs(graph)]...) # try to restrict the key/value types
if num_nodes !== nothing
n1 = get(num_nodes, k[1], nothing)
n2 = get(num_nodes, k[3], nothing)
else
n1 = nothing
n2 = nothing
end
g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...)
graph[k] = g
num_edges[k] = nedges
_num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1])
_num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2])
end
return graph, _num_nodes, num_edges
end
Expand Down
124 changes: 84 additions & 40 deletions GNNGraphs/src/generate.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...)
rand_graph([rng,] n, m; bidirected=true, edge_weight = nothing, kws...)
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes and `m` edges.
Expand All @@ -10,7 +10,7 @@ In any case, the output graph will contain no self-loops or multi-edges.
A vector can be passed as `edge_weight`. Its length has to be equal to `m`
in the directed case, and `m÷2` in the bidirected one.
Use a `seed > 0` for reproducibility.
Pass a random number generator as the first argument to make the generation reproducible.
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
Expand All @@ -36,25 +36,42 @@ GNNGraph:
# Each edge has a reverse
julia> edge_index(g)
([1, 3, 3, 4], [3, 4, 1, 3])
```
"""
function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_weight = nothing, kws...)
function rand_graph(n::Integer, m::Integer; seed=-1, kws...)
if seed != -1
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_graph)
rng = MersenneTwister(seed)
else
rng = Random.default_rng()
end
return rand_graph(rng, n, m; kws...)
end

function rand_graph(rng::AbstractRNG, n::Integer, m::Integer;
bidirected::Bool = true,
edge_weight::Union{AbstractVector, Nothing} = nothing, kws...)
if bidirected
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
@assert iseven(m) lazy"Need even number of edges for bidirected graphs, given m=$m."
s, t, _ = _rand_edges(rng, n, m ÷ 2; directed=false, self_loops=false)
s, t = vcat(s, t), vcat(t, s)
if edge_weight !== nothing
edge_weight = vcat(edge_weight, edge_weight)
end
else
s, t, _ = _rand_edges(rng, n, m; directed=true, self_loops=false)
end
m2 = bidirected ? m ÷ 2 : m
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); edge_weight, kws...)
return GNNGraph((s, t, edge_weight); num_nodes=n, kws...)
end

"""
rand_heterograph(n, m; seed=-1, bidirected=false, kws...)
rand_heterograph([rng,] n, m; bidirected=false, kws...)
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
Construct an [`GNNHeteroGraph`](@ref) with random edges and with number of nodes and edges
specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
specifing node/edge types and their numbers.
Use a `seed > 0` for reproducibility.
Pass a random number generator as a first argument to make the generation reproducible.
Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
Expand All @@ -76,17 +93,27 @@ function rand_heterograph end

# for generic iterators of pairs
rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
rand_heterograph(rng::AbstractRNG, n, m; kws...) = rand_heterograph(rng, Dict(n), Dict(m); kws...)

function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws...)
rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG
function rand_heterograph(n::NDict, m::EDict; seed=-1, kws...)
if seed != -1
Base.depwarn("Keyword argument `seed` is deprecated, pass an rng as first argument instead.", :rand_heterograph)
rng = MersenneTwister(seed)
else
rng = Random.default_rng()
end
return rand_heterograph(rng, n, m; kws...)
end

function rand_heterograph(rng::AbstractRNG, n::NDict, m::EDict; bidirected::Bool = false, kws...)
if bidirected
return _rand_bidirected_heterograph(rng, n, m; kws...)
end
graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
end

function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
function _rand_bidirected_heterograph(rng::AbstractRNG, n::NDict, m::EDict; kws...)
for k in keys(m)
if reverse(k) keys(m)
@assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
Expand All @@ -104,43 +131,60 @@ function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
return GNNHeteroGraph(graphs; num_nodes = n, kws...)
end

function _rand_edges(rng, (n1, n2), m)
idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false)
s, t = edge_decoding(idx, n1, n2)
val = nothing
return s, t, val
end

"""
rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...])
rand_bipartite_heterograph((n1, n2), m; ...)
rand_bipartite_heterograph((n1, n2), (m1, m2); ...)
rand_bipartite_heterograph([rng,]
(n1, n2), (m12, m21);
bidirected = true,
node_t = (:A, :B),
edge_t = :to,
kws...)
Construct an [`GNNHeteroGraph`](@ref) with number of nodes and edges
specified by `n1`, `n2` and `m1` and `m2` respectively.
Construct an [`GNNHeteroGraph`](@ref) with random edges representing a bipartite graph.
The graph will have two types of nodes, and edges will only connect nodes of different types.
See [`rand_heterograph`](@ref) for a more general version.
The first argument is a tuple `(n1, n2)` specifying the number of nodes of each type.
The second argument is a tuple `(m12, m21)` specifying the number of edges connecting nodes of type `1` to nodes of type `2`
and vice versa.
# Keyword arguments
The type of nodes and edges can be specified with the `node_t` and `edge_t` keyword arguments,
which default to `(:A, :B)` and `:to` respectively.
- `bidirected`: whether to generate a bidirected graph. Default is `true`.
- `seed`: random seed. Default is `-1` (no seed).
- `node_t`: node types. If `bipartite=true`, this should be a tuple of two node types, otherwise it should be a single node type.
- `edge_t`: edge types. If `bipartite=true`, this should be a tuple of two edge types, otherwise it should be a single edge type.
"""
function rand_bipartite_heterograph end
If `bidirected=true` (default), the reverse edge of each edge will be present. In this case
`m12 == m21` is required.
A random number generator can be passed as the first argument to make the generation reproducible.
Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
See [`rand_heterograph`](@ref) for a more general version.
# Examples
rand_bipartite_heterograph(n1::Int, n2::Int, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
```julia-repl
julia> g = rand_bipartite_heterograph((10, 15), 20)
GNNHeteroGraph:
num_nodes: (:A => 10, :B => 15)
num_edges: ((:A, :to, :B) => 20, (:B, :to, :A) => 20)
rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
julia> g = rand_bipartite_heterograph((10, 15), (20, 0), node_t=(:user, :item), edge_t=:-, bidirected=false)
GNNHeteroGraph:
num_nodes: Dict(:item => 15, :user => 10)
num_edges: Dict((:item, :-, :user) => 0, (:user, :-, :item) => 20)
```
"""
rand_bipartite_heterograph(n, m; kws...) = rand_bipartite_heterograph(Random.default_rng(), n, m; kws...)

function rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, (m1, m2)::NTuple{2,Int}; bidirected=true,
node_t = (:A, :B), edge_t = :to, kws...)
if edge_t isa Symbol
edge_t = (edge_t, edge_t)
function rand_bipartite_heterograph(rng::AbstractRNG, (n1, n2)::NTuple{2,Int}, m; bidirected=true,
node_t = (:A, :B), edge_t::Symbol = :to, kws...)
if m isa Integer
m12 = m21 = m
else
m12, m21 = m
end
return rand_heterograph(Dict(node_t[1] => n1, node_t[2] => n2),
Dict((node_t[1], edge_t[1], node_t[2]) => m1, (node_t[2], edge_t[2], node_t[1]) => m2);

return rand_heterograph(rng, Dict(node_t[1] => n1, node_t[2] => n2),
Dict((node_t[1], edge_t, node_t[2]) => m12, (node_t[2], edge_t, node_t[1]) => m21);
bidirected, kws...)
end

Expand Down
14 changes: 9 additions & 5 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ then all new self loops will have no weight.
If `edge_t` is not passed as argument, for the entire graph self-loop is added to each node for every edge type in the graph where the source and destination node types are the same.
This iterates over all edge types present in the graph, applying the self-loop addition logic to each applicable edge type.
"""
function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
function add_self_loops(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)

function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
get(g.graph, edge_t, (nothing, nothing, nothing))[3]
end
Expand All @@ -69,13 +70,17 @@ function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where
n = get(g.num_nodes, src_t, 0)

if haskey(g.graph, edge_t)
x = g.graph[edge_t]
s, t = x[1:2]
s, t = g.graph[edge_t][1:2]
nodes = convert(typeof(s), [1:n;])
s = [s; nodes]
t = [t; nodes]
else
nodes = convert(T, [1:n;])
if !isempty(g.graph)
T = typeof(first(values(g.graph))[1])
nodes = convert(T, [1:n;])
else
nodes = [1:n;]
end
s = nodes
t = nodes
end
Expand Down Expand Up @@ -518,7 +523,6 @@ end
Return a new graph obtained from `g` by adding random edges, based on a specified `perturb_ratio`.
The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph.
These new edges are added without creating self-loops.
Optionally, a random `seed` can be provided to ensure reproducible perturbations.
The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges.
The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges.
Expand Down
Loading

0 comments on commit 6b58b75

Please sign in to comment.