diff --git a/src/GNNGraphs/GNNGraphs.jl b/src/GNNGraphs/GNNGraphs.jl
index 9238c5d06..8c98ff7ef 100644
--- a/src/GNNGraphs/GNNGraphs.jl
+++ b/src/GNNGraphs/GNNGraphs.jl
@@ -14,7 +14,7 @@ import KrylovKit
 using ChainRulesCore
 using LinearAlgebra, Random, Statistics
 import MLUtils
-using MLUtils: getobs, numobs, ones_like, zeros_like, chunk
+using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, rand_like
 import Functors
 
 include("chainrules.jl") # hacks for differentiability
@@ -78,6 +78,7 @@ export add_nodes,
        to_bidirected,
        to_unidirected,
        random_walk_pe,
+       perturb_edges,
        remove_nodes,
        ppr_diffusion,
        drop_nodes,
diff --git a/src/GNNGraphs/transform.jl b/src/GNNGraphs/transform.jl
index f05d14a12..6f6aebd86 100644
--- a/src/GNNGraphs/transform.jl
+++ b/src/GNNGraphs/transform.jl
@@ -502,6 +502,72 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
              ntypes, etypes)
 end
 
+"""
+    perturb_edges(g::GNNGraph, perturb_ratio; rng = Random.default_rng())
+
+Return a new graph obtained from `g` by adding random edges.
+
+The number of added edges is `ceil(Int, perturb_ratio * g.num_edges)`. Both endpoints
+of each new edge are drawn uniformly at random among the nodes of `g`, and self-loops
+are never created. New edges are not checked against the edges already in `g`, and
+they are passed to `add_edges` without weights. If no edge has to be added, `g`
+itself is returned.
+
+# Arguments
+
+- `g`: The graph to perturb.
+- `perturb_ratio`: Number of edges to add, expressed as a fraction of the current number
+  of edges. Must be in `[0, 1]`.
+
+# Keywords
+
+- `rng`: Random number generator used to sample the new edges. Pass a seeded
+  generator to get reproducible perturbations.
+
+# Examples
+
+```julia
+julia> using Random
+
+julia> g = GNNGraph(([1, 2, 3, 4, 1], [2, 3, 4, 1, 3]))
+GNNGraph:
+  num_nodes: 4
+  num_edges: 5
+
+julia> perturb_edges(g, 0.2)  # ceil(0.2 * 5) = 1 edge added
+GNNGraph:
+  num_nodes: 4
+  num_edges: 6
+
+julia> perturb_edges(g, 0.5; rng = MersenneTwister(42))  # ceil(0.5 * 5) = 3 edges added
+GNNGraph:
+  num_nodes: 4
+  num_edges: 8
+```
+"""
+function perturb_edges(g::GNNGraph{<:COO_T}, perturb_ratio::Real;
+                       rng::AbstractRNG = Random.default_rng())
+    0 <= perturb_ratio <= 1 ||
+        throw(ArgumentError("perturb_ratio must be between 0 and 1"))
+
+    num_edges_to_add = ceil(Int, g.num_edges * perturb_ratio)
+    num_edges_to_add == 0 && return g
+
+    num_nodes = g.num_nodes
+    num_nodes > 1 ||
+        throw(ArgumentError("Graph must contain at least 2 nodes to add edges"))
+
+    # Sample node indices directly as integers: scaling a float in [0, 1) and taking
+    # `ceil` yields the invalid index 0 whenever the float is exactly 0.
+    snew = rand(rng, 1:num_nodes, num_edges_to_add)
+    # Draw each target among the `num_nodes - 1` nodes other than its source, then
+    # shift the draws at or above the source by one. This excludes self-loops while
+    # keeping the target uniform, with no rejection loop.
+    tnew = rand(rng, 1:(num_nodes - 1), num_edges_to_add)
+    tnew .+= (tnew .>= snew)
+
+    return add_edges(g, (snew, tnew, nothing))
+end
 
 
 ### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable
diff --git a/test/GNNGraphs/transform.jl b/test/GNNGraphs/transform.jl
index af414bbd1..70570d155 100644
--- a/test/GNNGraphs/transform.jl
+++ b/test/GNNGraphs/transform.jl
@@ -177,6 +177,19 @@
     end
 end end
 
+@testset "perturb_edges" begin if GRAPH_T == :coo
+    s, t = [1, 2, 3, 4, 5], [2, 3, 4, 5, 1]
+    g = GNNGraph((s, t))
+    rng = MersenneTwister(42)
+    g_per = perturb_edges(g, 0.5, rng=rng)
+    @test g_per.num_edges == 8
+    snew, tnew = edge_index(g_per)
+    @test all(snew .!= tnew) # no self-loops are introduced
+    # the same seed must yield the same perturbation
+    g_rep = perturb_edges(g, 0.5, rng=MersenneTwister(42))
+    @test edge_index(g_rep) == (snew, tnew)
+end end
+
 @testset "remove_nodes" begin if GRAPH_T == :coo
     #single node
     s = [1, 1, 2, 3]