Skip to content

Commit

Permalink
Added perturb_edges function (#423)
Browse files Browse the repository at this point in the history
* add edge perturbation

* add to gnngraphs

* Update src/GNNGraphs/transform.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* loop fix

* Update src/GNNGraphs/transform.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update src/GNNGraphs/transform.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update src/GNNGraphs/transform.jl

Co-authored-by: Carlo Lucibello <[email protected]>

* Update transform.jl

* Update transform.jl

* gpu compat

* include package

* Update test/GNNGraphs/transform.jl

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
rbSparky and CarloLucibello authored Jul 18, 2024
1 parent 0f8e13c commit e2623eb
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import KrylovKit
using ChainRulesCore
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, rand_like
import Functors

include("chainrules.jl") # hacks for differentiability
Expand Down Expand Up @@ -78,6 +78,7 @@ export add_nodes,
to_bidirected,
to_unidirected,
random_walk_pe,
perturb_edges,
remove_nodes,
ppr_diffusion,
drop_nodes,
Expand Down
66 changes: 66 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,72 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
ntypes, etypes)
end

"""
    perturb_edges(g::GNNGraph, perturb_ratio; rng = Random.default_rng())

Perturb the graph `g` by adding random edges, based on a specified `perturb_ratio`. The `perturb_ratio` determines the fraction of new edges to add relative to the current number of edges in the graph: the number of added edges is `ceil(Int, g.num_edges * perturb_ratio)`. These new edges are added without creating self-loops. Optionally, a random number generator can be provided through the `rng` keyword argument to control how the new edges are sampled.

The function returns a new `GNNGraph` instance that shares some of the underlying data with `g` but includes the additional edges. The nodes for the new edges are selected randomly, and no edge data (`edata`) or weights (`w`) are assigned to these new edges. If the number of edges to add is zero, `g` itself is returned.

# Parameters

- `g::GNNGraph`: The graph to be perturbed.
- `perturb_ratio`: The ratio of the number of new edges to add relative to the current number of edges in the graph. It must be between 0 and 1. For example, a `perturb_ratio` of 0.1 means that 10% of the current number of edges (rounded up) will be added as new random edges.
- `rng`: An optional keyword argument giving the random number generator used to sample the new edges. Defaults to `Random.default_rng()`.

# Examples

```julia
julia> s, t = [1, 2, 3, 4, 1], [2, 3, 4, 1, 3];

julia> g = GNNGraph((s, t))
GNNGraph:
  num_nodes: 4
  num_edges: 5

julia> perturbed_g = perturb_edges(g, 0.2)
GNNGraph:
  num_nodes: 4
  num_edges: 6 # One new edge added, as 0.2 of 5 is 1.

julia> perturbed_g = perturb_edges(g, 0.5, rng = MersenneTwister(42))
GNNGraph:
  num_nodes: 4
  num_edges: 8 # Three new edges added, as 0.5 of 5 is 2.5, which is rounded up to 3.
```
"""
function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Real;
                       rng::AbstractRNG = Random.default_rng())
    # Add `ceil(Int, g.num_edges * perturb_ratio)` random edges to `g`, none of which is a
    # self-loop, and return the graph built by `add_edges`. When there is nothing to add,
    # `g` itself is returned.
    #
    # - `perturb_ratio`: fraction of the current edge count to add; must lie in [0, 1].
    # - `rng`: generator every endpoint is drawn from. It is used exactly as passed in and
    #   is never re-seeded, so a seeded generator gives reproducible perturbations.
    #
    # Throws an `AssertionError` if `perturb_ratio` is outside [0, 1], or if edges have to
    # be added to a graph with fewer than 2 nodes.
    @assert perturb_ratio >= 0 && perturb_ratio <= 1 "perturb_ratio must be between 0 and 1"

    num_edges_to_add = ceil(Int, g.num_edges * perturb_ratio)
    if num_edges_to_add == 0
        return g
    end

    num_nodes = g.num_nodes
    # With a single node every candidate edge is a self-loop, so the rejection loop
    # below could never terminate.
    @assert num_nodes > 1 "Graph must contain at least 2 nodes to add edges"

    snew = Int[]
    tnew = Int[]
    sizehint!(snew, num_edges_to_add)
    sizehint!(tnew, num_edges_to_add)

    # Rejection sampling: draw exactly as many candidate pairs as are still missing and
    # keep the ones that are not self-loops. Since a round never keeps more pairs than it
    # draws, the loop ends with exactly `num_edges_to_add` edges.
    while length(snew) < num_edges_to_add
        n = num_edges_to_add - length(snew)
        # Sampling from the integer range can neither produce the invalid index 0 nor
        # miss nodes in very large graphs, unlike scaling a Float32 in [0, 1).
        s = rand(rng, 1:num_nodes, n)
        t = rand(rng, 1:num_nodes, n)
        keep = s .!= t
        append!(snew, s[keep])
        append!(tnew, t[keep])
    end

    # `nothing` in the third slot: the new edges carry no weights.
    return add_edges(g, (snew, tnew, nothing))
end


### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable
Expand Down
8 changes: 8 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ end
end
end

@testset "perturb_edges" begin
    if GRAPH_T == :coo
        # Directed 5-cycle: 5 nodes, 5 edges, no self-loops.
        src = [1, 2, 3, 4, 5]
        dst = [2, 3, 4, 5, 1]
        ring = GNNGraph((src, dst))

        # ceil(5 * 0.5) == 3 edges are added, whatever the generator draws.
        perturbed = perturb_edges(ring, 0.5; rng = MersenneTwister(42))
        @test perturbed.num_edges == 8
    end
end

@testset "remove_nodes" begin if GRAPH_T == :coo
#single node
s = [1, 1, 2, 3]
Expand Down

0 comments on commit e2623eb

Please sign in to comment.