From f8a0270047afa9dc814557a54fce146102b4fcd6 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 4 May 2024 16:00:02 +0200 Subject: [PATCH] create GNNlib.jl --- GNNlib/Project.toml | 69 + GNNlib/README.md | 14 + GNNlib/ext/GNNlibCUDAExt/GNNGraphs/query.jl | 2 + .../ext/GNNlibCUDAExt/GNNGraphs/transform.jl | 2 + GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl | 8 + GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl | 17 + GNNlib/ext/GNNlibCUDAExt/msgpass.jl | 37 + .../GNNlibSimpleWeightedGraphsExt.jl | 12 + GNNlib/src/GNNGraphs/GNNGraphs.jl | 108 ++ GNNlib/src/GNNGraphs/abstracttypes.jl | 11 + GNNlib/src/GNNGraphs/chainrules.jl | 15 + GNNlib/src/GNNGraphs/convert.jl | 240 ++++ GNNlib/src/GNNGraphs/datastore.jl | 222 ++++ GNNlib/src/GNNGraphs/gatherscatter.jl | 18 + GNNlib/src/GNNGraphs/generate.jl | 460 +++++++ GNNlib/src/GNNGraphs/gnngraph.jl | 347 +++++ GNNlib/src/GNNGraphs/gnnheterograph.jl | 299 +++++ GNNlib/src/GNNGraphs/operators.jl | 13 + GNNlib/src/GNNGraphs/query.jl | 633 +++++++++ GNNlib/src/GNNGraphs/sampling.jl | 118 ++ .../GNNGraphs/temporalsnapshotsgnngraph.jl | 244 ++++ GNNlib/src/GNNGraphs/transform.jl | 1131 +++++++++++++++++ GNNlib/src/GNNGraphs/utils.jl | 304 +++++ GNNlib/src/GNNlib.jl | 95 ++ GNNlib/src/layers/basic.jl | 3 + GNNlib/src/layers/conv.jl | 590 +++++++++ GNNlib/src/layers/pool.jl | 40 + GNNlib/src/layers/temporalconv.jl | 12 + GNNlib/src/mldatasets.jl | 41 + GNNlib/src/msgpass.jl | 259 ++++ GNNlib/src/utils.jl | 133 ++ 31 files changed, 5497 insertions(+) create mode 100644 GNNlib/Project.toml create mode 100644 GNNlib/README.md create mode 100644 GNNlib/ext/GNNlibCUDAExt/GNNGraphs/query.jl create mode 100644 GNNlib/ext/GNNlibCUDAExt/GNNGraphs/transform.jl create mode 100644 GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl create mode 100644 GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl create mode 100644 GNNlib/ext/GNNlibCUDAExt/msgpass.jl create mode 100644 GNNlib/ext/GNNlibSimpleWeightedGraphsExt/GNNlibSimpleWeightedGraphsExt.jl create mode 100644 GNNlib/src/GNNGraphs/GNNGraphs.jl create mode 100644 GNNlib/src/GNNGraphs/abstracttypes.jl create mode 100644 GNNlib/src/GNNGraphs/chainrules.jl create mode 100644 GNNlib/src/GNNGraphs/convert.jl create mode 100644 GNNlib/src/GNNGraphs/datastore.jl create mode 100644 GNNlib/src/GNNGraphs/gatherscatter.jl create mode 100644 GNNlib/src/GNNGraphs/generate.jl create mode 100644 GNNlib/src/GNNGraphs/gnngraph.jl create mode 100644 GNNlib/src/GNNGraphs/gnnheterograph.jl create mode 100644 GNNlib/src/GNNGraphs/operators.jl create mode 100644 GNNlib/src/GNNGraphs/query.jl create mode 100644 GNNlib/src/GNNGraphs/sampling.jl create mode 100644 GNNlib/src/GNNGraphs/temporalsnapshotsgnngraph.jl create mode 100644 GNNlib/src/GNNGraphs/transform.jl create mode 100644 GNNlib/src/GNNGraphs/utils.jl create mode 100644 GNNlib/src/GNNlib.jl create mode 100644 GNNlib/src/layers/basic.jl create mode 100644 GNNlib/src/layers/conv.jl create mode 100644 GNNlib/src/layers/pool.jl create mode 100644 GNNlib/src/layers/temporalconv.jl create mode 100644 GNNlib/src/mldatasets.jl create mode 100644 GNNlib/src/msgpass.jl create mode 100644 GNNlib/src/utils.jl diff --git a/GNNlib/Project.toml b/GNNlib/Project.toml new file mode 100644 index 000000000..e8195c7df --- /dev/null +++ b/GNNlib/Project.toml @@ -0,0 +1,69 @@ +name = "GNNlib" +uuid = "a6a84749-d869-43f8-aacc-be26a1996e48" +authors = ["Carlo Lucibello and contributors"] +version = "0.1.0" + +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" 
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
+KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
+
+[extensions]
+GNNlibCUDAExt = "CUDA"
+GNNlibSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
+
+[compat]
+Adapt = "3, 4"
+CUDA = "4, 5"
+ChainRulesCore = "1"
+DataStructures = "0.18"
+Functors = "0.4.1"
+Graphs = "1.4"
+KrylovKit = "0.6, 0.7"
+LinearAlgebra = "1"
+MLDatasets = "0.7"
+MLUtils = "0.4"
+MacroTools = "0.5"
+NNlib = "0.9"
+NearestNeighbors = "0.4"
+Random = "1"
+Reexport = "1"
+SimpleWeightedGraphs = "1.4.0"
+SparseArrays = "1"
+Statistics = "1"
+StatsBase = "0.34"
+cuDNN = "1"
+julia = "1.9"
+
+[extras]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+
+[targets]
+test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
diff --git a/GNNlib/README.md b/GNNlib/README.md
new file mode 100644
index 000000000..a293a812b
--- /dev/null
+++ b/GNNlib/README.md
@@ -0,0 +1,14 @@
+# GNNlib.jl
+
+This package contains a collection of deep-learning-framework-agnostic
+building blocks for graph neural networks, such as graph convolutional layers and the
+`GNNGraph` graph type.
+
+In the future it will serve as the foundation of GraphNeuralNetworks.jl (based on Flux.jl).
+GNNlib.jl will be to GraphNeuralNetworks.jl what NNlib.jl is to Flux.jl and Lux.jl, as sketched below.
+
+This package is currently under development and may break frequently.
+It is not meant for final users but for GNN library developers.
+Final users should use GraphNeuralNetworks.jl instead.
+
+
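To make the intended division of labor concrete, here is a minimal sketch (illustrative only, not part of the package API; `aggregate_neighbors_mean` is a hypothetical placeholder for a message-passing primitive) of the pattern this split enables, with a purely functional core on the GNNlib.jl side and a thin parameterized wrapper on the Flux side:

```julia
using Flux

# Functional core, GNNlib.jl style: a pure function of parameters, graph, and features.
graph_conv(weight, bias, g, x) = weight * aggregate_neighbors_mean(g, x) .+ bias

# Thin layer wrapper, GraphNeuralNetworks.jl style: only holds the parameters.
struct GraphConvLayer{W, B}
    weight::W
    bias::B
end
Flux.@functor GraphConvLayer

# The layer's forward pass just delegates to the functional core.
(l::GraphConvLayer)(g, x) = graph_conv(l.weight, l.bias, g, x)
```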
diff --git a/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/query.jl b/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/query.jl
new file mode 100644
index 000000000..0e74f725e
--- /dev/null
+++ b/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/query.jl
@@ -0,0 +1,2 @@
+
+GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))
diff --git a/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/transform.jl b/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/transform.jl
new file mode 100644
index 000000000..d2ee417fc
--- /dev/null
+++ b/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/transform.jl
@@ -0,0 +1,2 @@
+
+GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)
diff --git a/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl b/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl
new file mode 100644
index 000000000..c3d78e9c1
--- /dev/null
+++ b/GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl
@@ -0,0 +1,8 @@
+
+GNNGraphs.iscuarray(x::AnyCuArray) = true
+
+
+function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
+    # TODO: proper CUDA-friendly implementation (Flux is not a dependency of this extension, so round-trip through the CPU)
+    map(cu, sort_edge_index(Array(u), Array(v)))
+end
\ No newline at end of file
diff --git a/GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl b/GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl
new file mode 100644
index 000000000..bd3c919ae
--- /dev/null
+++ b/GNNlib/ext/GNNlibCUDAExt/GNNlibCUDAExt.jl
@@ -0,0 +1,17 @@
+module GNNlibCUDAExt
+
+using CUDA
+using Random, Statistics, LinearAlgebra
+using GNNlib
+using GNNlib.GNNGraphs
+using GNNlib.GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
+import GNNlib: propagate
+
+const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
+
+include("GNNGraphs/query.jl")
+include("GNNGraphs/transform.jl")
+include("GNNGraphs/utils.jl")
+include("msgpass.jl")
+
+end #module
diff --git a/GNNlib/ext/GNNlibCUDAExt/msgpass.jl b/GNNlib/ext/GNNlibCUDAExt/msgpass.jl
new file mode 100644
index 000000000..ded99bef8
--- /dev/null
+++ b/GNNlib/ext/GNNlibCUDAExt/msgpass.jl
@@ -0,0 +1,37 @@
+
+###### PROPAGATE SPECIALIZATIONS ####################
+
+## COPY_XJ
+
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                   xi, xj::AnyCuMatrix, e)
+    propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
+end
+
+## E_MUL_XJ
+
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                   xi, xj::AnyCuMatrix, e::AbstractVector)
+    propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e)
+end
+
+## W_MUL_XJ
+
+## avoid the fast path on gpu until we have better cuda support
+function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+),
+                   xi, xj::AnyCuMatrix, e::Nothing)
+    propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e)
+end
+
+# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
+#     A = adjacency_matrix(g, weighted=false)
+#     D = compute_degree(A)
+#     return xj * A * D
+# end
+
+# # Zygote bug. Error with sparse matrix without nograd
+# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2)))
+
+# Flux.Zygote.@nograd compute_degree
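The specializations above route the fused built-ins back through the generic gather/scatter path on GPU. As an illustrative sketch (not the actual implementation), this is what that generic path computes for `copy_xj` with `+` aggregation on a small COO graph, spelled out with NNlib primitives:

```julia
using NNlib

s = [1, 1, 2, 3]           # source node of each edge
t = [2, 3, 3, 1]           # target node of each edge
x = rand(Float32, 4, 3)    # node features, one column per node

m = NNlib.gather(x, s)                           # copy_xj: each edge carries its source's features
out = NNlib.scatter(+, m, t; dstsize = (4, 3))   # sum incoming messages at each target node
```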
diff --git a/GNNlib/ext/GNNlibSimpleWeightedGraphsExt/GNNlibSimpleWeightedGraphsExt.jl b/GNNlib/ext/GNNlibSimpleWeightedGraphsExt/GNNlibSimpleWeightedGraphsExt.jl
new file mode 100644
index 000000000..48c8f0f3b
--- /dev/null
+++ b/GNNlib/ext/GNNlibSimpleWeightedGraphsExt/GNNlibSimpleWeightedGraphsExt.jl
@@ -0,0 +1,12 @@
+module GNNlibSimpleWeightedGraphsExt
+
+using GNNlib
+using Graphs
+using SimpleWeightedGraphs
+
+function GNNlib.GNNGraph(g::T; kws...) where
+    {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
+    return GNNGraph(g.weights; kws...) # forward kws as keyword arguments
+end
+
+end #module
\ No newline at end of file
diff --git a/GNNlib/src/GNNGraphs/GNNGraphs.jl b/GNNlib/src/GNNGraphs/GNNGraphs.jl
new file mode 100644
index 000000000..2e7f05207
--- /dev/null
+++ b/GNNlib/src/GNNGraphs/GNNGraphs.jl
@@ -0,0 +1,108 @@
+module GNNGraphs
+
+using SparseArrays
+using Functors: @functor
+import Graphs
+using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
+              has_self_loops, is_directed
+import MLUtils
+using MLUtils: getobs, numobs, ones_like, zeros_like, batch
+import NearestNeighbors
+import NNlib
+import StatsBase
+import KrylovKit
+using ChainRulesCore
+using LinearAlgebra, Random, Statistics
+import MLUtils
+import Functors
+
+include("chainrules.jl") # hacks for differentiability
+
+include("datastore.jl")
+export DataStore
+
+include("abstracttypes.jl")
+export AbstractGNNGraph
+
+include("gnngraph.jl")
+export GNNGraph,
+       node_features,
+       edge_features,
+       graph_features
+
+include("gnnheterograph.jl")
+export GNNHeteroGraph,
+       num_edge_types,
+       num_node_types,
+       edge_type_subgraph
+
+include("temporalsnapshotsgnngraph.jl")
+export TemporalSnapshotsGNNGraph,
+       add_snapshot,
+       # add_snapshot!,
+       remove_snapshot
+       # remove_snapshot!
+ +include("query.jl") +export adjacency_list, + edge_index, + get_edge_weight, + graph_indicator, + has_multi_edges, + is_directed, + is_bidirected, + normalized_laplacian, + scaled_laplacian, + laplacian_lambda_max, +# from Graphs + adjacency_matrix, + degree, + has_self_loops, + has_isolated_nodes, + inneighbors, + outneighbors, + khop_adj + +include("transform.jl") +export add_nodes, + add_edges, + add_self_loops, + getgraph, + negative_sample, + rand_edge_split, + remove_self_loops, + remove_edges, + remove_multi_edges, + set_edge_weight, + to_bidirected, + to_unidirected, + random_walk_pe, + remove_nodes, +# from Flux + batch, + unbatch, +# from SparseArrays + blockdiag + +include("generate.jl") +export rand_graph, + rand_heterograph, + rand_bipartite_heterograph, + knn_graph, + radius_graph, + rand_temporal_radius_graph, + rand_temporal_hyperbolic_graph + +include("sampling.jl") +export sample_neighbors + +include("operators.jl") +# Base.intersect + +include("convert.jl") +include("utils.jl") + +include("gatherscatter.jl") +# _gather, _scatter + +end #module diff --git a/GNNlib/src/GNNGraphs/abstracttypes.jl b/GNNlib/src/GNNGraphs/abstracttypes.jl new file mode 100644 index 000000000..b8959b807 --- /dev/null +++ b/GNNlib/src/GNNGraphs/abstracttypes.jl @@ -0,0 +1,11 @@ + +const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V} +const ADJLIST_T = AbstractVector{T} where {T <: AbstractVector{<:Integer}} +const ADJMAT_T = AbstractMatrix +const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T + +const AVecI = AbstractVector{<:Integer} + +# All concrete graph types should be subtypes of AbstractGNNGraph{T}. +# GNNGraph and GNNHeteroGraph are the two concrete types. +abstract type AbstractGNNGraph{T} <: AbstractGraph{Int} end diff --git a/GNNlib/src/GNNGraphs/chainrules.jl b/GNNlib/src/GNNGraphs/chainrules.jl new file mode 100644 index 000000000..6ef0b65aa --- /dev/null +++ b/GNNlib/src/GNNGraphs/chainrules.jl @@ -0,0 +1,15 @@ +# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648 +# Remove when merged + +function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict} + ks = map(first, ps) + project_ks, project_vs = map(ProjectTo, ks), map(ProjectTo∘last, ps) + function Dict_pullback(ȳ) + dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v + dk, dv = proj_k(getkey(ȳ, k, NoTangent())), proj_v(get(ȳ, k, NoTangent())) + Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv) + end + return (NoTangent(), dps...) + end + return T(ps...), Dict_pullback +end diff --git a/GNNlib/src/GNNGraphs/convert.jl b/GNNlib/src/GNNGraphs/convert.jl new file mode 100644 index 000000000..1e103db8b --- /dev/null +++ b/GNNlib/src/GNNGraphs/convert.jl @@ -0,0 +1,240 @@ +### CONVERT_TO_COO REPRESENTATION ######## + +function to_coo(data::EDict; num_nodes = nothing, kws...) + graph = EDict{COO_T}() + _num_nodes = NDict{Int}() + num_edges = EDict{Int}() + if !isempty(data) + for k in keys(data) + d = data[k] + @assert d isa Tuple + if length(d) == 2 + d = (d..., nothing) + end + if num_nodes !== nothing + n1 = get(num_nodes, k[1], nothing) + n2 = get(num_nodes, k[3], nothing) + else + n1 = nothing + n2 = nothing + end + g, nnodes, nedges = to_coo(d; hetero = true, num_nodes = (n1, n2), kws...) + graph[k] = g + num_edges[k] = nedges + _num_nodes[k[1]] = max(get(_num_nodes, k[1], 0), nnodes[1]) + _num_nodes[k[3]] = max(get(_num_nodes, k[3], 0), nnodes[2]) + end + graph = Dict([k => v for (k, v) in pairs(graph)]...) 
# try to restrict the key/value types + end + return graph, _num_nodes, num_edges +end + +function to_coo(coo::COO_T; dir = :out, num_nodes = nothing, weighted = true, + hetero = false) + s, t, val = coo + + if isnothing(num_nodes) + ns = maximum(s) + nt = maximum(t) + num_nodes = hetero ? (ns, nt) : max(ns, nt) + elseif num_nodes isa Integer + ns = num_nodes + nt = num_nodes + elseif num_nodes isa Tuple + ns = isnothing(num_nodes[1]) ? maximum(s) : num_nodes[1] + nt = isnothing(num_nodes[2]) ? maximum(t) : num_nodes[2] + num_nodes = (ns, nt) + else + error("Invalid num_nodes $num_nodes") + end + @assert isnothing(val) || length(val) == length(s) + @assert length(s) == length(t) + if !isempty(s) + @assert minimum(s) >= 1 + @assert minimum(t) >= 1 + @assert maximum(s) <= ns + @assert maximum(t) <= nt + end + num_edges = length(s) + if !weighted + coo = (s, t, nothing) + end + return coo, num_nodes, num_edges +end + +function to_coo(A::SPARSE_T; dir = :out, num_nodes = nothing, weighted = true) + s, t, v = findnz(A) + if dir == :in + s, t = t, s + end + num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes + num_edges = length(s) + if !weighted + v = nothing + end + return (s, t, v), num_nodes, num_edges +end + +function _findnz_idx(A) + nz = findall(!=(0), A) # vec of cartesian indexes + s, t = ntuple(i -> map(t -> t[i], nz), 2) + return s, t, nz +end + +@non_differentiable _findnz_idx(A) + +function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true) + s, t, nz = _findnz_idx(A) + v = A[nz] + if dir == :in + s, t = t, s + end + num_nodes = isnothing(num_nodes) ? size(A, 1) : num_nodes + num_edges = length(s) + if !weighted + v = nothing + end + return (s, t, v), num_nodes, num_edges +end + +function to_coo(adj_list::ADJLIST_T; dir = :out, num_nodes = nothing, weighted = true) + @assert dir ∈ [:out, :in] + num_nodes = length(adj_list) + num_edges = sum(length.(adj_list)) + @assert num_nodes > 0 + s = similar(adj_list[1], eltype(adj_list[1]), num_edges) + t = similar(adj_list[1], eltype(adj_list[1]), num_edges) + e = 0 + for i in 1:num_nodes + for j in adj_list[i] + e += 1 + s[e] = i + t[e] = j + end + end + @assert e == num_edges + if dir == :in + s, t = t, s + end + (s, t, nothing), num_nodes, num_edges +end + +### CONVERT TO ADJACENCY MATRIX ################ + +### DENSE #################### + +to_dense(A::AbstractSparseMatrix, x...; kws...) = to_dense(collect(A), x...; kws...) + +function to_dense(A::ADJMAT_T, T = nothing; dir = :out, num_nodes = nothing, + weighted = true) + @assert dir ∈ [:out, :in] + T = T === nothing ? eltype(A) : T + num_nodes = size(A, 1) + @assert num_nodes == size(A, 2) + # @assert all(x -> (x == 1) || (x == 0), A) + num_edges = numnonzeros(A) + if dir == :in + A = A' + end + if T != eltype(A) + A = T.(A) + end + if !weighted + A = map(x -> ifelse(x > 0, T(1), T(0)), A) + end + return A, num_nodes, num_edges +end + +function to_dense(adj_list::ADJLIST_T, T = nothing; dir = :out, num_nodes = nothing, + weighted = true) + @assert dir ∈ [:out, :in] + num_nodes = length(adj_list) + num_edges = sum(length.(adj_list)) + @assert num_nodes > 0 + T = T === nothing ? 
eltype(adj_list[1]) : T
+    A = fill!(similar(adj_list[1], T, (num_nodes, num_nodes)), 0)
+    if dir == :out
+        for (i, neigs) in enumerate(adj_list)
+            A[i, neigs] .= 1
+        end
+    else
+        for (i, neigs) in enumerate(adj_list)
+            A[neigs, i] .= 1
+        end
+    end
+    A, num_nodes, num_edges
+end
+
+function to_dense(coo::COO_T, T = nothing; dir = :out, num_nodes = nothing, weighted = true)
+    # `dir` will be ignored since the input `coo` is always in source -> target format.
+    # The output will always be an adjacency matrix in :out format (i.e. A[i,j] denotes an edge from i to j).
+    s, t, val = coo
+    n::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
+    if T === nothing
+        T = isnothing(val) ? eltype(s) : eltype(val)
+    end
+    if val === nothing || !weighted
+        val = ones_like(s, T)
+    end
+    if eltype(val) != T
+        val = T.(val)
+    end
+
+    idxs = s .+ n .* (t .- 1)
+
+    ## using scatter instead of indexing since there could be multiple edges
+    # A = fill!(similar(s, T, (n, n)), 0)
+    # v = vec(A) # vec view of A
+    # A[idxs] .= val # exploiting linear indexing
+    v = NNlib.scatter(+, val, idxs, dstsize = n^2)
+    A = reshape(v, (n, n))
+    return A, n, length(s)
+end
+
+### SPARSE #############
+
+function to_sparse(A::ADJMAT_T, T = nothing; dir = :out, num_nodes = nothing,
+                   weighted = true)
+    @assert dir ∈ [:out, :in]
+    num_nodes = size(A, 1)
+    @assert num_nodes == size(A, 2)
+    T = T === nothing ? eltype(A) : T
+    num_edges = A isa AbstractSparseMatrix ? nnz(A) : count(!=(0), A)
+    if dir == :in
+        A = A'
+    end
+    if T != eltype(A)
+        A = T.(A)
+    end
+    if !(A isa AbstractSparseMatrix)
+        A = sparse(A)
+    end
+    if !weighted
+        A = map(x -> ifelse(x > 0, T(1), T(0)), A)
+    end
+    return A, num_nodes, num_edges
+end
+
+function to_sparse(adj_list::ADJLIST_T, T = nothing; dir = :out, num_nodes = nothing,
+                   weighted = true)
+    coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes)
+    return to_sparse(coo; num_nodes)
+end
+
+function to_sparse(coo::COO_T, T = nothing; dir = :out, num_nodes = nothing,
+                   weighted = true)
+    s, t, eweight = coo
+    T = T === nothing ? (eweight === nothing ? eltype(s) : eltype(eweight)) : T
+
+    if eweight === nothing || !weighted
+        eweight = fill!(similar(s, T), 1)
+    end
+
+    num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
+    A = sparse(s, t, eweight, num_nodes, num_nodes)
+    num_edges::Int = nnz(A)
+    if eltype(A) != T
+        A = T.(A)
+    end
+    return A, num_nodes, num_edges
+end
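A small sketch of how these converters relate (they are internal helpers, not exported; users normally go through the `GNNGraph` constructor and its `graph_type` keyword):

```julia
# A 3-node directed graph given as a dense adjacency matrix.
A = [0 1 1
     0 0 1
     0 0 0]

(s, t, v), num_nodes, num_edges = to_coo(A)          # COO form: source list s, target list t, weights v
Ad, _, _ = to_dense((s, t, v), Float32; num_nodes)   # back to a dense Float32 adjacency matrix
As, _, _ = to_sparse((s, t, v), Float32; num_nodes)  # or to a SparseMatrixCSC
```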
diff --git a/GNNlib/src/GNNGraphs/datastore.jl b/GNNlib/src/GNNGraphs/datastore.jl
new file mode 100644
index 000000000..5fe80b5f0
--- /dev/null
+++ b/GNNlib/src/GNNGraphs/datastore.jl
@@ -0,0 +1,222 @@
+"""
+    DataStore([n,] data)
+    DataStore([n,] k1 = x1, k2 = x2, ...)
+
+A container for feature arrays. The optional argument `n` enforces that
+`numobs(x) == n` for each array contained in the datastore.
+
+At construction time, the `data` can be provided as any iterable of pairs
+of symbols and arrays or as keyword arguments:
+
+```jldoctest
+julia> ds = DataStore(3, x = rand(Float32, 2, 3), y = rand(Float32, 3))
+DataStore(3) with 2 elements:
+  y = 3-element Vector{Float32}
+  x = 2×3 Matrix{Float32}
+
+julia> ds = DataStore(3, Dict(:x => rand(Float32, 2, 3), :y => rand(Float32, 3))); # equivalent to above
+
+julia> ds = DataStore(3, (x = rand(Float32, 2, 3), y = rand(Float32, 30)))
+ERROR: AssertionError: DataStore: data[y] has 30 observations, but n = 3
+Stacktrace:
+ [1] DataStore(n::Int64, data::Dict{Symbol, Any})
+   @ GNNlib.GNNGraphs ~/.julia/dev/GNNlib/src/GNNGraphs/datastore.jl:54
+ [2] DataStore(n::Int64, data::NamedTuple{(:x, :y), Tuple{Matrix{Float32}, Vector{Float32}}})
+   @ GNNlib.GNNGraphs ~/.julia/dev/GNNlib/src/GNNGraphs/datastore.jl:73
+ [3] top-level scope
+   @ REPL[13]:1
+
+julia> ds = DataStore(x = rand(Float32, 2, 3), y = rand(Float32, 30)) # no checks
+DataStore() with 2 elements:
+  y = 30-element Vector{Float32}
+  x = 2×3 Matrix{Float32}
+```
+
+The `DataStore` has an interface similar to both dictionaries and named tuples.
+Arrays can be accessed and added using either the indexing or the property syntax:
+
+```jldoctest
+julia> ds = DataStore(x = ones(Float32, 2, 3), y = zeros(Float32, 3))
+DataStore() with 2 elements:
+  y = 3-element Vector{Float32}
+  x = 2×3 Matrix{Float32}
+
+julia> ds.x # same as `ds[:x]`
+2×3 Matrix{Float32}:
+ 1.0  1.0  1.0
+ 1.0  1.0  1.0
+
+julia> ds.z = zeros(Float32, 3) # add a new feature array `z`, same as `ds[:z] = zeros(Float32, 3)`
+3-element Vector{Float32}:
+ 0.0
+ 0.0
+ 0.0
+```
+
+The `DataStore` can be iterated over, and the keys and values can be accessed
+using `keys(ds)` and `values(ds)`. `map(f, ds)` applies the function `f`
+to each feature array:
+
+```jldoctest
+julia> ds = DataStore(a = zeros(2), b = zeros(2));
+
+julia> ds2 = map(x -> x .+ 1, ds);
+
+julia> ds2.a
+2-element Vector{Float64}:
+ 1.0
+ 1.0
+```
+"""
+struct DataStore
+    _n::Int # either -1 or numobs(data)
+    _data::Dict{Symbol, Any}
+
+    function DataStore(n::Int, data::Dict{Symbol, Any})
+        if n >= 0
+            for (k, v) in data
+                @assert numobs(v)==n "DataStore: data[$k] has $(numobs(v)) observations, but n = $n"
+            end
+        end
+        return new(n, data)
+    end
+end
+
+@functor DataStore
+
+DataStore(data) = DataStore(-1, data)
+DataStore(n::Int, data::NamedTuple) = DataStore(n, Dict{Symbol, Any}(pairs(data)))
+DataStore(n::Int, data) = DataStore(n, Dict{Symbol, Any}(data))
+
+DataStore(; kws...) = DataStore(-1; kws...)
+DataStore(n::Int; kws...)
= DataStore(n, Dict{Symbol, Any}(kws...)) + +getdata(ds::DataStore) = getfield(ds, :_data) +getn(ds::DataStore) = getfield(ds, :_n) +# setn!(ds::DataStore, n::Int) = setfield!(ds, :n, n) + +function Base.getproperty(ds::DataStore, s::Symbol) + if s === :_n + return getn(ds) + elseif s === :_data + return getdata(ds) + else + return getdata(ds)[s] + end +end + +function Base.getproperty(vds::Vector{DataStore}, s::Symbol) + if s === :_n + return [getn(ds) for ds in vds] + elseif s === :_data + return [getdata(ds) for ds in vds] + else + return [getdata(ds)[s] for ds in vds] + end +end + +function Base.setproperty!(ds::DataStore, s::Symbol, x) + @assert s != :_n "cannot set _n directly" + @assert s != :_data "cannot set _data directly" + if getn(ds) >= 0 + numobs(x) == getn(ds) || throw(DimensionMismatch("expected $(getn(ds)) object features but got $(numobs(x)).")) + end + return getdata(ds)[s] = x +end + +Base.getindex(ds::DataStore, s::Symbol) = getproperty(ds, s) +Base.setindex!(ds::DataStore, x, s::Symbol) = setproperty!(ds, s, x) + +function Base.show(io::IO, ds::DataStore) + len = length(ds) + n = getn(ds) + if n < 0 + print(io, "DataStore()") + else + print(io, "DataStore($(getn(ds)))") + end + if len > 0 + print(io, " with $(length(getdata(ds))) element") + len > 1 && print(io, "s") + print(io, ":") + for (k, v) in getdata(ds) + print(io, "\n $(k) = $(summary(v))") + end + else + print(io, " with no elements") + end +end + +Base.iterate(ds::DataStore) = iterate(getdata(ds)) +Base.iterate(ds::DataStore, state) = iterate(getdata(ds), state) +Base.keys(ds::DataStore) = keys(getdata(ds)) +Base.values(ds::DataStore) = values(getdata(ds)) +Base.length(ds::DataStore) = length(getdata(ds)) +Base.haskey(ds::DataStore, k) = haskey(getdata(ds), k) +Base.get(ds::DataStore, k, default) = get(getdata(ds), k, default) +Base.pairs(ds::DataStore) = pairs(getdata(ds)) +Base.:(==)(ds1::DataStore, ds2::DataStore) = getdata(ds1) == getdata(ds2) +Base.isempty(ds::DataStore) = isempty(getdata(ds)) +Base.delete!(ds::DataStore, k) = delete!(getdata(ds), k) + +function Base.map(f, ds::DataStore) + d = getdata(ds) + newd = Dict{Symbol, Any}(k => f(v) for (k, v) in d) + return DataStore(getn(ds), newd) +end + +MLUtils.numobs(ds::DataStore) = numobs(getdata(ds)) + +function MLUtils.getobs(ds::DataStore, i::Int) + newdata = getobs(getdata(ds), i) + return DataStore(-1, newdata) +end + +function MLUtils.getobs(ds::DataStore, + i::AbstractVector{T}) where {T <: Union{Integer, Bool}} + newdata = getobs(getdata(ds), i) + n = getn(ds) + if n >= 0 + if length(ds) > 0 + n = numobs(newdata) + else + # if newdata is empty, then we can't get the number of observations from it + n = T == Bool ? sum(i) : length(i) + end + end + if !(newdata isa Dict{Symbol, Any}) + newdata = Dict{Symbol, Any}(newdata) + end + return DataStore(n, newdata) +end + +function cat_features(ds1::DataStore, ds2::DataStore) + n1, n2 = getn(ds1), getn(ds2) + n1 = n1 >= 0 ? n1 : 1 + n2 = n2 >= 0 ? n2 : 1 + return DataStore(n1 + n2, cat_features(getdata(ds1), getdata(ds2))) +end + +function cat_features(dss::AbstractVector{DataStore}; kws...) + ns = getn.(dss) + ns = map(n -> n >= 0 ? n : 1, ns) + return DataStore(sum(ns), cat_features(getdata.(dss); kws...)) +end + +# DataStore is always already normalized +normalize_graphdata(ds::DataStore; kws...) 
= ds
+
+_gather(x::DataStore, i) = map(x -> _gather(x, i), x)
+
+function _scatter(aggr, src::DataStore, idx, n)
+    newdata = _scatter(aggr, getdata(src), idx, n)
+    if !(newdata isa Dict{Symbol, Any})
+        newdata = Dict{Symbol, Any}(newdata)
+    end
+    return DataStore(n, newdata)
+end
+
+function Base.hash(ds::D, h::UInt) where {D <: DataStore}
+    fs = (getfield(ds, k) for k in fieldnames(D))
+    return foldl((h, f) -> hash(f, h), fs, init = hash(D, h))
+end
diff --git a/GNNlib/src/GNNGraphs/gatherscatter.jl b/GNNlib/src/GNNGraphs/gatherscatter.jl
new file mode 100644
index 000000000..e897399ed
--- /dev/null
+++ b/GNNlib/src/GNNGraphs/gatherscatter.jl
@@ -0,0 +1,18 @@
+_gather(x::NamedTuple, i) = map(x -> _gather(x, i), x)
+_gather(x::Dict, i) = Dict([k => _gather(v, i) for (k, v) in x]...)
+_gather(x::Tuple, i) = map(x -> _gather(x, i), x)
+_gather(x::AbstractArray, i) = NNlib.gather(x, i)
+_gather(x::Nothing, i) = nothing
+
+_scatter(aggr, src::Nothing, idx, n) = nothing
+_scatter(aggr, src::NamedTuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
+_scatter(aggr, src::Tuple, idx, n) = map(s -> _scatter(aggr, s, idx, n), src)
+_scatter(aggr, src::Dict, idx, n) = Dict([k => _scatter(aggr, v, idx, n) for (k, v) in src]...)
+
+function _scatter(aggr,
+                  src::AbstractArray,
+                  idx::AbstractVector{<:Integer},
+                  n::Integer)
+    dstsize = (size(src)[1:(end - 1)]..., n)
+    return NNlib.scatter(aggr, src, idx; dstsize)
+end
diff --git a/GNNlib/src/GNNGraphs/generate.jl b/GNNlib/src/GNNGraphs/generate.jl
new file mode 100644
index 000000000..07a31c0da
--- /dev/null
+++ b/GNNlib/src/GNNGraphs/generate.jl
@@ -0,0 +1,460 @@
+"""
+    rand_graph(n, m; bidirected=true, seed=-1, edge_weight = nothing, kws...)
+
+Generate a random (Erdős–Rényi) `GNNGraph` with `n` nodes and `m` edges.
+
+If `bidirected=true` the reverse edge of each edge will be present.
+If `bidirected=false` instead, `m` unrelated edges are generated.
+In any case, the output graph will contain no self-loops or multi-edges.
+
+A vector can be passed as `edge_weight`. Its length has to be equal to `m`
+in the directed case, and `m÷2` in the bidirected one.
+
+Use a `seed > 0` for reproducibility.
+
+Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
+
+# Examples
+
+```jldoctest
+julia> g = rand_graph(5, 4, bidirected=false)
+GNNGraph:
+    num_nodes = 5
+    num_edges = 4
+
+julia> edge_index(g)
+([1, 3, 3, 4], [5, 4, 5, 2])
+
+# In the bidirected case, edge data will be duplicated on the reverse edges if needed.
+julia> g = rand_graph(5, 4, edata=rand(Float32, 16, 2))
+GNNGraph:
+    num_nodes = 5
+    num_edges = 4
+    edata:
+        e => (16, 4)
+
+# Each edge has a reverse
+julia> edge_index(g)
+([1, 3, 3, 4], [3, 4, 1, 3])
+
+```
+"""
+function rand_graph(n::Integer, m::Integer; bidirected = true, seed = -1, edge_weight = nothing, kws...)
+    if bidirected
+        @assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
+    end
+    m2 = bidirected ? m ÷ 2 : m
+    return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed = !bidirected, seed); edge_weight, kws...)
+end
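A short sketch of the `edge_weight` convention described in the docstring above (illustrative; in the bidirected case each given weight is shared by an edge and its reverse):

```julia
# directed: one weight per edge
g_dir = rand_graph(10, 30; bidirected = false, edge_weight = rand(Float32, 30))

# bidirected (default): 30 edges total, 15 forward edges whose
# weights are duplicated on the 15 reverse edges
g_bi = rand_graph(10, 30; edge_weight = rand(Float32, 15))
```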
+
+"""
+    rand_heterograph(n, m; seed=-1, bidirected=false, kws...)
+
+Construct a [`GNNHeteroGraph`](@ref) with number of nodes and edges
+specified by `n` and `m` respectively. `n` and `m` can be any iterable of pairs
+specifying node/edge types and their numbers.
+
+Use a `seed > 0` for reproducibility.
+
+Setting `bidirected=true` will generate a bidirected graph, i.e. each edge will have a reverse edge.
+Therefore, for each edge type `(:A, :rel, :B)` a corresponding reverse edge type `(:B, :rel, :A)`
+will be generated.
+
+Additional keyword arguments will be passed to the [`GNNHeteroGraph`](@ref) constructor.
+
+# Examples
+
+```jldoctest
+julia> g = rand_heterograph((:user => 10, :movie => 20),
+                            (:user, :rate, :movie) => 30)
+GNNHeteroGraph:
+  num_nodes: (:user => 10, :movie => 20)
+  num_edges: ((:user, :rate, :movie) => 30,)
+```
+"""
+function rand_heterograph end
+
+# for generic iterators of pairs
+rand_heterograph(n, m; kws...) = rand_heterograph(Dict(n), Dict(m); kws...)
+
+function rand_heterograph(n::NDict, m::EDict; bidirected = false, seed = -1, kws...)
+    rng = seed > 0 ? MersenneTwister(seed) : Random.GLOBAL_RNG
+    if bidirected
+        return _rand_bidirected_heterograph(rng, n, m; kws...)
+    end
+    graphs = Dict(k => _rand_edges(rng, (n[k[1]], n[k[3]]), m[k]) for k in keys(m))
+    return GNNHeteroGraph(graphs; num_nodes = n, kws...)
+end
+
+function _rand_bidirected_heterograph(rng, n::NDict, m::EDict; kws...)
+    for k in keys(m)
+        if reverse(k) ∈ keys(m)
+            @assert m[k] == m[reverse(k)] "Number of edges must be the same in reverse edge types for bidirected graphs."
+        else
+            m[reverse(k)] = m[k]
+        end
+    end
+    graphs = Dict{EType, Tuple{Vector{Int}, Vector{Int}, Nothing}}()
+    for k in keys(m)
+        reverse(k) ∈ keys(graphs) && continue
+        s, t, val = _rand_edges(rng, (n[k[1]], n[k[3]]), m[k])
+        graphs[k] = s, t, val
+        graphs[reverse(k)] = t, s, val
+    end
+    return GNNHeteroGraph(graphs; num_nodes = n, kws...)
+end
+
+function _rand_edges(rng, (n1, n2), m)
+    idx = StatsBase.sample(rng, 1:(n1 * n2), m, replace = false)
+    s, t = edge_decoding(idx, n1, n2)
+    val = nothing
+    return s, t, val
+end
+
+"""
+    rand_bipartite_heterograph(n1, n2, m; [bidirected, seed, node_t, edge_t, kws...])
+    rand_bipartite_heterograph((n1, n2), m; ...)
+    rand_bipartite_heterograph((n1, n2), (m1, m2); ...)
+
+Construct a bipartite [`GNNHeteroGraph`](@ref) with `n1` and `n2` nodes of the two
+node types, and `m1` and `m2` edges in the two directions respectively.
+
+See [`rand_heterograph`](@ref) for a more general version.
+
+# Keyword arguments
+
+- `bidirected`: whether to generate a bidirected graph. Default is `true`.
+- `seed`: random seed. Default is `-1` (no seed).
+- `node_t`: a tuple of two symbols naming the two node types. Default is `(:A, :B)`.
+- `edge_t`: the edge type symbol, or a tuple of two symbols naming the forward and reverse edge types. Default is `:to`.
+"""
+function rand_bipartite_heterograph end
+
+rand_bipartite_heterograph(n1::Int, n2::Int, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
+
+rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, m::Int; kws...) = rand_bipartite_heterograph((n1, n2), (m, m); kws...)
+
+function rand_bipartite_heterograph((n1, n2)::NTuple{2,Int}, (m1, m2)::NTuple{2,Int}; bidirected=true,
+                                    node_t = (:A, :B), edge_t = :to, kws...)
+    if edge_t isa Symbol
+        edge_t = (edge_t, edge_t)
+    end
+    return rand_heterograph(Dict(node_t[1] => n1, node_t[2] => n2),
+                            Dict((node_t[1], edge_t[1], node_t[2]) => m1, (node_t[2], edge_t[2], node_t[1]) => m2);
+                            bidirected, kws...)
+end
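A usage sketch of the constructor just defined (illustrative):

```julia
# 10 users, 20 movies, 30 edges in each direction (bidirected by default).
hg = rand_bipartite_heterograph((10, 20), 30; node_t = (:user, :movie), edge_t = :rate)

hg.num_edges[(:user, :rate, :movie)]   # 30
hg.num_edges[(:movie, :rate, :user)]   # 30, the reverse relation
```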
+
+"""
+    knn_graph(points::AbstractMatrix,
+              k::Int;
+              graph_indicator = nothing,
+              self_loops = false,
+              dir = :in,
+              kws...)
+
+Create a `k`-nearest neighbor graph where each node is linked
+to its `k` closest `points`.
+
+# Arguments
+
+- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
+- `k`: The number of neighbors considered in the kNN algorithm.
+- `graph_indicator`: Either nothing or a vector containing the graph assignment of each node,
+                     in which case the returned graph will be a batch of graphs.
+- `self_loops`: If `true`, consider the node itself among its `k` nearest neighbors, in which
+                case the graph will contain self-loops.
+- `dir`: The direction of the edges. If `dir=:in` edges go from the `k`
+         neighbors to the central node. If `dir=:out` we have the opposite
+         direction.
+- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
+
+# Examples
+
+```jldoctest
+julia> n, k = 10, 3;
+
+julia> x = rand(Float32, 3, n);
+
+julia> g = knn_graph(x, k)
+GNNGraph:
+    num_nodes = 10
+    num_edges = 30
+
+julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];
+
+julia> g = knn_graph(x, k; graph_indicator)
+GNNGraph:
+    num_nodes = 10
+    num_edges = 30
+    num_graphs = 2
+
+```
+"""
+function knn_graph(points::AbstractMatrix, k::Int;
+                   graph_indicator = nothing,
+                   self_loops = false,
+                   dir = :in,
+                   kws...)
+    if graph_indicator !== nothing
+        d, n = size(points)
+        @assert graph_indicator isa AbstractVector{<:Integer}
+        @assert length(graph_indicator) == n
+        # All graphs in the batch must have at least k nodes.
+        cm = StatsBase.countmap(graph_indicator)
+        @assert all(values(cm) .>= k)
+
+        # Make sure that the distance between points in different graphs
+        # is always larger than any distance within the same graph.
+        points = points .- minimum(points)
+        points = points ./ maximum(points)
+        dummy_feature = 2d .* reshape(graph_indicator, 1, n)
+        points = vcat(points, dummy_feature)
+    end
+
+    kdtree = NearestNeighbors.KDTree(points)
+    if !self_loops
+        k += 1
+    end
+    sortres = false
+    idxs, dists = NearestNeighbors.knn(kdtree, points, k, sortres)
+
+    g = GNNGraph(idxs; dir, graph_indicator, kws...)
+    if !self_loops
+        g = remove_self_loops(g)
+    end
+    return g
+end
+
+"""
+    radius_graph(points::AbstractMatrix,
+                 r::AbstractFloat;
+                 graph_indicator = nothing,
+                 self_loops = false,
+                 dir = :in,
+                 kws...)
+
+Create a graph where each node is linked
+to its neighbors within a given distance `r`.
+
+# Arguments
+
+- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
+- `r`: The radius.
+- `graph_indicator`: Either nothing or a vector containing the graph assignment of each node,
+                     in which case the returned graph will be a batch of graphs.
+- `self_loops`: If `true`, consider the node itself among its neighbors, in which
+                case the graph will contain self-loops.
+- `dir`: The direction of the edges. If `dir=:in` edges go from the
+         neighbors to the central node. If `dir=:out` we have the opposite
+         direction.
+- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
+
+# Examples
+
+```jldoctest
+julia> n, r = 10, 0.75;
+
+julia> x = rand(Float32, 3, n);
+
+julia> g = radius_graph(x, r)
+GNNGraph:
+    num_nodes = 10
+    num_edges = 46
+
+julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];
+
+julia> g = radius_graph(x, r; graph_indicator)
+GNNGraph:
+    num_nodes = 10
+    num_edges = 20
+    num_graphs = 2
+
+```
+# References
+Section B paragraphs 1 and 2 of the paper [Dynamic Hidden-Variable Network Models](https://arxiv.org/pdf/2101.00414.pdf)
+"""
+function radius_graph(points::AbstractMatrix, r::AbstractFloat;
+                      graph_indicator = nothing,
+                      self_loops = false,
+                      dir = :in,
+                      kws...)
+ if graph_indicator !== nothing + d, n = size(points) + @assert graph_indicator isa AbstractVector{<:Integer} + @assert length(graph_indicator) == n + + # Make sure that the distance between points in different graphs + # is always larger than r. + dummy_feature = 2r .* reshape(graph_indicator, 1, n) + points = vcat(points, dummy_feature) + end + + balltree = NearestNeighbors.BallTree(points) + + sortres = false + idxs = NearestNeighbors.inrange(balltree, points, r, sortres) + + g = GNNGraph(idxs; dir, graph_indicator, kws...) + if !self_loops + g = remove_self_loops(g) + end + return g +end + +""" + rand_temporal_radius_graph(number_nodes::Int, + number_snapshots::Int, + speed::AbstractFloat, + r::AbstractFloat; + self_loops = false, + dir = :in, + kws...) + +Create a random temporal graph given `number_nodes` nodes and `number_snapshots` snapshots. +First, the positions of the nodes are randomly generated in the unit square. Two nodes are connected if their distance is less than a given radius `r`. +Each following snapshot is obtained by applying the same construction to new positions obtained as follows. +For each snapshot, the new positions of the points are determined by applying random independent displacement vectors to the previous positions. The direction of the displacement is chosen uniformly at random and its length is chosen uniformly in `[0, speed]`. Then the connections are recomputed. +If a point happens to move outside the boundary, its position is updated as if it had bounced off the boundary. + +# Arguments + +- `number_nodes`: The number of nodes of each snapshot. +- `number_snapshots`: The number of snapshots. +- `speed`: The speed to update the nodes. +- `r`: The radius of connection. +- `self_loops`: If `true`, consider the node itself among its neighbors, in which + case the graph will contain self-loops. +- `dir`: The direction of the edges. If `dir=:in` edges go from the + neighbors to the central node. If `dir=:out` we have the opposite + direction. +- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor of each snapshot. + +# Example + +```jldoctest +julia> n, snaps, s, r = 10, 5, 0.1, 1.5; + +julia> tg = rand_temporal_radius_graph(n,snaps,s,r) # complete graph at each snapshot +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10, 10, 10] + num_edges: [90, 90, 90, 90, 90] + num_snapshots: 5 +``` + +""" +function rand_temporal_radius_graph(number_nodes::Int, + number_snapshots::Int, + speed::AbstractFloat, + r::AbstractFloat; + self_loops = false, + dir = :in, + kws...) + points=rand(2, number_nodes) + tg = Vector{GNNGraph}(undef, number_snapshots) + for t in 1:number_snapshots + tg[t] = radius_graph(points, r; graph_indicator = nothing, self_loops, dir, kws...) + for i in 1:number_nodes + ρ = 2 * speed * rand() - speed + theta=2*pi*rand() + points[1,i]=1-abs(1-(abs(points[1,i]+ρ*cos(theta)))) + points[2,i]=1-abs(1-(abs(points[2,i]+ρ*sin(theta)))) + end + end + return TemporalSnapshotsGNNGraph(tg) +end + + +function _hyperbolic_distance(nodeA::Array{Float64, 1},nodeB::Array{Float64, 1}; ζ::Real) + if nodeA != nodeB + a = cosh(ζ * nodeA[1]) * cosh(ζ * nodeB[1]) + b = sinh(ζ * nodeA[1]) * sinh(ζ * nodeB[1]) + c = cos(pi - abs(pi - abs(nodeA[2] - nodeB[2]))) + d = acosh(a - (b * c)) / ζ + else + d = 0.0 + end + return d +end + +""" + rand_temporal_hyperbolic_graph(number_nodes::Int, + number_snapshots::Int; + α::Real, + R::Real, + speed::Real, + ζ::Real=1, + self_loop = false, + kws...) 
+
+Create a random temporal graph given `number_nodes` nodes and `number_snapshots` snapshots.
+First, the positions of the nodes are generated with a quasi-uniform distribution (depending on the parameter `α`) in hyperbolic space within a disk of radius `R`. Two nodes are connected if their hyperbolic distance is less than `R`. Each subsequent snapshot is generated so that the distribution of positions stays the same.
+
+# Arguments
+
+- `number_nodes`: The number of nodes of each snapshot.
+- `number_snapshots`: The number of snapshots.
+- `α`: The parameter that controls the position of the points. If `α=ζ`, the points are uniformly distributed on the disk of radius `R`. If `α>ζ`, the points are more concentrated in the center of the disk. If `α<ζ`, the points are more concentrated at the boundary of the disk.
+- `R`: The radius of the disk and of connection.
+- `speed`: The speed to update the nodes.
+- `ζ`: The parameter that controls the curvature of the disk.
+- `self_loop`: If `true`, consider the node itself among its neighbors, in which
+               case the graph will contain self-loops.
+- `kws`: Further keyword arguments will be passed to the [`GNNGraph`](@ref) constructor of each snapshot.
+
+# Example
+
+```jldoctest
+julia> n, snaps, α, R, speed, ζ = 10, 5, 1.0, 4.0, 0.1, 1.0;
+
+julia> thg = rand_temporal_hyperbolic_graph(n, snaps; α, R, speed, ζ)
+TemporalSnapshotsGNNGraph:
+  num_nodes: [10, 10, 10, 10, 10]
+  num_edges: [44, 46, 48, 42, 38]
+  num_snapshots: 5
+```
+
+# References
+Section D of the paper [Dynamic Hidden-Variable Network Models](https://arxiv.org/pdf/2101.00414.pdf) and the paper
+[Hyperbolic Geometry of Complex Networks](https://arxiv.org/pdf/1006.5169.pdf)
+"""
+function rand_temporal_hyperbolic_graph(number_nodes::Int,
+                                        number_snapshots::Int;
+                                        α::Real,
+                                        R::Real,
+                                        speed::Real,
+                                        ζ::Real=1,
+                                        self_loop = false,
+                                        kws...)
+    @assert number_snapshots > 1 "The number of snapshots must be greater than 1"
+    @assert α > 0 "α must be greater than 0"
+
+    probabilities = rand(number_nodes)
+
+    points = Array{Float64}(undef, 2, number_nodes)
+    points[1, :] .= (1 / α) * acosh.(1 .+ (cosh(α * R) - 1) * probabilities)
+    points[2, :] .= 2 * pi * rand(number_nodes)
+
+    tg = Vector{GNNGraph}(undef, number_snapshots)
+
+    for time in 1:number_snapshots
+        adj = zeros(number_nodes, number_nodes)
+        for i in 1:number_nodes
+            for j in 1:number_nodes
+                if !self_loop && i == j
+                    continue
+                elseif _hyperbolic_distance(points[:, i], points[:, j]; ζ) <= R
+                    adj[i, j] = adj[j, i] = 1
+                end
+            end
+        end
+        tg[time] = GNNGraph(adj)
+
+        probabilities .= probabilities .+ (2 * speed * rand(number_nodes) .- speed)
+        probabilities[probabilities .> 1] .= 1 .- (probabilities[probabilities .> 1] .% 1)
+        probabilities[probabilities .< 0] .= abs.(probabilities[probabilities .< 0])
+
+        points[1, :] .= (1 / α) * acosh.(1 .+ (cosh(α * R) - 1) * probabilities)
+        points[2, :] .= points[2, :] .+ (2 * speed * rand(number_nodes) .- speed)
+    end
+    return TemporalSnapshotsGNNGraph(tg)
+end
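For reference, `_hyperbolic_distance` implements the hyperbolic law of cosines in native polar coordinates `(r, θ)` with curvature parameter `ζ`. A minimal check (illustrative):

```julia
# d(A, B) = acosh(cosh(ζ r_A) cosh(ζ r_B) - sinh(ζ r_A) sinh(ζ r_B) cos(Δθ)) / ζ
a = [1.0, 0.0]   # node A: radius 1, angle 0
b = [1.0, pi]    # node B: radius 1, angle π (diametrically opposite)
d = _hyperbolic_distance(a, b; ζ = 1.0)  # 2.0: the geodesic passes through the origin
```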
diff --git a/GNNlib/src/GNNGraphs/gnngraph.jl b/GNNlib/src/GNNGraphs/gnngraph.jl
new file mode 100644
index 000000000..a26652d94
--- /dev/null
+++ b/GNNlib/src/GNNGraphs/gnngraph.jl
@@ -0,0 +1,347 @@
+#===================================
+Define GNNGraph type as a subtype of Graphs.AbstractGraph.
+For the core methods to be implemented by any AbstractGraph, see
+https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
+https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types
+=============================================#
+
+"""
+    GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
+    GNNGraph(g::GNNGraph; [ndata, edata, gdata])
+
+A type representing a graph structure that also stores
+feature arrays associated to nodes, edges, and the graph itself.
+
+The feature arrays are stored in the fields `ndata`, `edata`, and `gdata`
+as [`DataStore`](@ref) objects offering a convenient dictionary-like
+and namedtuple-like interface. The features can be passed at construction
+time or added later.
+
+A `GNNGraph` can be constructed out of different `data` objects
+expressing the connections inside the graph. The internal representation type
+is determined by `graph_type`.
+
+When constructed from another `GNNGraph`, the internal graph representation
+is preserved and shared. The node/edge/graph features are retained
+as well, unless explicitly set by the keyword arguments
+`ndata`, `edata`, and `gdata`.
+
+A `GNNGraph` can also represent multiple graphs batched together
+(see [`MLUtils.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)).
+The field `g.graph_indicator` contains the graph membership
+of each node.
+
+`GNNGraph`s are always directed graphs, therefore each edge is defined
+by a source node and a target node (see [`edge_index`](@ref)).
+Self loops (edges connecting a node to itself) and multiple edges
+(more than one edge between the same pair of nodes) are supported.
+
+A `GNNGraph` is a Graphs.jl `AbstractGraph`, therefore it supports most
+functionality from that library.
+
+# Arguments
+
+- `data`: Some data representing the graph topology. Possible types are
+    - An adjacency matrix.
+    - An adjacency list.
+    - A tuple containing the source and target vectors (COO representation).
+    - A Graphs.jl graph.
+- `graph_type`: A keyword argument that specifies
+                the underlying representation used by the GNNGraph.
+                Currently supported values are
+    - `:coo`. Graph represented as a tuple `(source, target)`, such that the `k`-th edge
+              connects the node `source[k]` to node `target[k]`.
+              Optionally, also edge weights can be given: `(source, target, weights)`.
+    - `:sparse`. A sparse adjacency matrix representation.
+    - `:dense`. A dense adjacency matrix representation.
+    Defaults to `:coo`, currently the most supported type.
+- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
+         Possible values are `:out` and `:in`. Default `:out`.
+- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
+- `graph_indicator`: For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
+- `ndata`: Node features. An array or named tuple of arrays whose last dimension has size `num_nodes`.
+- `edata`: Edge features. An array or named tuple of arrays whose last dimension has size `num_edges`.
+- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.
+ +# Examples + +```julia +using Flux, GraphNeuralNetworks, CUDA + +# Construct from adjacency list representation +data = [[2,3], [1,4,5], [1], [2,5], [2,4]] +g = GNNGraph(data) + +# Number of nodes, edges, and batched graphs +g.num_nodes # 5 +g.num_edges # 10 +g.num_graphs # 1 + +# Same graph in COO representation +s = [1,1,2,2,2,3,4,4,5,5] +t = [2,3,1,4,5,3,2,5,2,4] +g = GNNGraph(s, t) + +# From a Graphs' graph +g = GNNGraph(erdos_renyi(100, 20)) + +# Add 2 node feature arrays at creation time +g = GNNGraph(g, ndata = (x=rand(Float32,100,g.num_nodes), y=rand(Float32,g.num_nodes))) + +# Add 1 edge feature array, after the graph creation +g.edata.z = rand(Float32,16,g.num_edges) + +# Add node features and edge features with default names `x` and `e` +g = GNNGraph(g, ndata = rand(Float32,100,g.num_nodes), edata = rand(Float32,16,g.num_edges)) + +g.ndata.x # or just g.x +g.edata.e # or just g.e + +# Send to gpu +g = g |> gpu + +# Collect edges' source and target nodes. +# Both source and target are vectors of length num_edges +source, target = edge_index(g) +``` +""" +struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T} + graph::T + num_nodes::Int + num_edges::Int + num_graphs::Int + graph_indicator::Union{Nothing, AVecI} # vector of ints or nothing + ndata::DataStore + edata::DataStore + gdata::DataStore +end + +@functor GNNGraph + +function GNNGraph(data::D; + num_nodes = nothing, + graph_indicator = nothing, + graph_type = :coo, + dir = :out, + ndata = nothing, + edata = nothing, + gdata = nothing) where {D <: Union{COO_T, ADJMAT_T, ADJLIST_T}} + @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" + @assert dir ∈ [:in, :out] + + if graph_type == :coo + graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) + elseif graph_type == :dense + graph, num_nodes, num_edges = to_dense(data; num_nodes, dir) + elseif graph_type == :sparse + graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir) + end + + num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1 + + ndata = normalize_graphdata(ndata, default_name = :x, n = num_nodes) + edata = normalize_graphdata(edata, default_name = :e, n = num_edges, + duplicate_if_needed = true) + + # don't force the shape of the data when there is only one graph + gdata = normalize_graphdata(gdata, default_name = :u, + n = num_graphs > 1 ? num_graphs : -1) + + GNNGraph(graph, + num_nodes, num_edges, num_graphs, + graph_indicator, + ndata, edata, gdata) +end + +GNNGraph(; kws...) = GNNGraph(0; kws...) + +function (::Type{<:GNNGraph})(num_nodes::T; kws...) where {T <: Integer} + s, t = T[], T[] + return GNNGraph(s, t; num_nodes, kws...) +end + +Base.zero(::Type{G}) where {G <: GNNGraph} = G(0) + +# COO convenience constructors +function GNNGraph(s::AbstractVector, t::AbstractVector, v = nothing; kws...) + GNNGraph((s, t, v); kws...) +end +GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...) + +# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...) + +function GNNGraph(g::AbstractGraph; edge_weight = nothing, kws...) + s = Graphs.src.(Graphs.edges(g)) + t = Graphs.dst.(Graphs.edges(g)) + w = edge_weight + if !Graphs.is_directed(g) + # add reverse edges since GNNGraph is directed + s, t = [s; t], [t; s] + if !isnothing(w) + @assert length(w) == Graphs.ne(g) "edge_weight must have length equal to the number of undirected edges" + w = [w; w] + end + end + num_nodes::Int = Graphs.nv(g) + GNNGraph((s, t, w); num_nodes = num_nodes, kws...) 
+end + +function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata, + graph_type = nothing) + ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes) + edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges, + duplicate_if_needed = true) + gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs) + + if !isnothing(graph_type) + if graph_type == :coo + graph, num_nodes, num_edges = to_coo(g.graph; g.num_nodes) + elseif graph_type == :dense + graph, num_nodes, num_edges = to_dense(g.graph; g.num_nodes) + elseif graph_type == :sparse + graph, num_nodes, num_edges = to_sparse(g.graph; g.num_nodes) + end + @assert num_nodes == g.num_nodes + @assert num_edges == g.num_edges + else + graph = g.graph + end + GNNGraph(graph, + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + ndata, edata, gdata) +end + +""" + copy(g::GNNGraph; deep=false) + +Create a copy of `g`. If `deep` is `true`, then copy will be a deep copy (equivalent to `deepcopy(g)`), +otherwise it will be a shallow copy with the same underlying graph data. +""" +function Base.copy(g::GNNGraph; deep = false) + if deep + GNNGraph(deepcopy(g.graph), + g.num_nodes, g.num_edges, g.num_graphs, + deepcopy(g.graph_indicator), + deepcopy(g.ndata), deepcopy(g.edata), deepcopy(g.gdata)) + else + GNNGraph(g.graph, + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) + end +end + +function print_feature(io::IO, feature) + if !isempty(feature) + if length(keys(feature)) == 1 + k = first(keys(feature)) + v = first(values(feature)) + print(io, "$(k): $(dims2string(size(v)))") + else + print(io, "(") + for (i, (k, v)) in enumerate(pairs(feature)) + print(io, "$k: $(dims2string(size(v)))") + if i == length(feature) + print(io, ")") + else + print(io, ", ") + end + end + end + end +end + +function print_all_features(io::IO, feat1, feat2, feat3) + n1 = length(feat1) + n2 = length(feat2) + n3 = length(feat3) + if n1 == 0 && n2 == 0 && n3 == 0 + print(io, "no") + elseif n1 != 0 && (n2 != 0 || n3 != 0) + print_feature(io, feat1) + print(io, ", ") + elseif n2 == 0 && n3 == 0 + print_feature(io, feat1) + end + if n2 != 0 && n3 != 0 + print_feature(io, feat2) + print(io, ", ") + elseif n2 != 0 && n3 == 0 + print_feature(io, feat2) + end + print_feature(io, feat3) +end + +function Base.show(io::IO, g::GNNGraph) + print(io, "GNNGraph($(g.num_nodes), $(g.num_edges)) with ") + print_all_features(io, g.ndata, g.edata, g.gdata) + print(io, " data") +end + +function Base.show(io::IO, ::MIME"text/plain", g::GNNGraph) + if get(io, :compact, false) + print(io, "GNNGraph($(g.num_nodes), $(g.num_edges)) with ") + print_all_features(io, g.ndata, g.edata, g.gdata) + print(io, " data") + else + print(io, + "GNNGraph:\n num_nodes: $(g.num_nodes)\n num_edges: $(g.num_edges)") + g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)") + if !isempty(g.ndata) + print(io, "\n ndata:") + for k in keys(g.ndata) + print(io, "\n\t$k = $(shortsummary(g.ndata[k]))") + end + end + if !isempty(g.edata) + print(io, "\n edata:") + for k in keys(g.edata) + print(io, "\n\t$k = $(shortsummary(g.edata[k]))") + end + end + if !isempty(g.gdata) + print(io, "\n gdata:") + for k in keys(g.gdata) + print(io, "\n\t$k = $(shortsummary(g.gdata[k]))") + end + end + end +end + +MLUtils.numobs(g::GNNGraph) = g.num_graphs +MLUtils.getobs(g::GNNGraph, i) = getgraph(g, i) + +######################### + +function Base.:(==)(g1::GNNGraph, g2::GNNGraph) + g1 === g2 && return true + for k in 
fieldnames(typeof(g1)) + k === :graph_indicator && continue + getfield(g1, k) != getfield(g2, k) && return false + end + return true +end + +function Base.hash(g::T, h::UInt) where {T <: GNNGraph} + fs = (getfield(g, k) for k in fieldnames(T) if k !== :graph_indicator) + return foldl((h, f) -> hash(f, h), fs, init = hash(T, h)) +end + +function Base.getproperty(g::GNNGraph, s::Symbol) + if s in fieldnames(GNNGraph) + return getfield(g, s) + end + if (s in keys(g.ndata)) + (s in keys(g.edata)) + (s in keys(g.gdata)) > 1 + throw(ArgumentError("Ambiguous property name $s")) + end + if s in keys(g.ndata) + return g.ndata[s] + elseif s in keys(g.edata) + return g.edata[s] + elseif s in keys(g.gdata) + return g.gdata[s] + else + throw(ArgumentError("$(s) is not a field of GNNGraph")) + end +end diff --git a/GNNlib/src/GNNGraphs/gnnheterograph.jl b/GNNlib/src/GNNGraphs/gnnheterograph.jl new file mode 100644 index 000000000..72d67b34b --- /dev/null +++ b/GNNlib/src/GNNGraphs/gnnheterograph.jl @@ -0,0 +1,299 @@ + +const EType = Tuple{Symbol, Symbol, Symbol} +const NType = Symbol +const EDict{T} = Dict{EType, T} +const NDict{T} = Dict{NType, T} + +""" + GNNHeteroGraph(data; [ndata, edata, gdata, num_nodes]) + GNNHeteroGraph(pairs...; [ndata, edata, gdata, num_nodes]) + +A type representing a heterogeneous graph structure. +It is similar to [`GNNGraph`](@ref) but nodes and edges are of different types. + +# Constructor Arguments + +- `data`: A dictionary or an iterable object that maps `(source_type, edge_type, target_type)` + triples to `(source, target)` index vectors (or to `(source, target, weight)` if also edge weights are present). +- `pairs`: Passing multiple relations as pairs is equivalent to passing `data=Dict(pairs...)`. +- `ndata`: Node features. A dictionary of arrays or named tuple of arrays. + The size of the last dimension of each array must be given by `g.num_nodes`. +- `edata`: Edge features. A dictionary of arrays or named tuple of arrays. Default `nothing`. + The size of the last dimension of each array must be given by `g.num_edges`. Default `nothing`. +- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`. Default `nothing`. +- `num_nodes`: The number of nodes for each type. If not specified, inferred from `data`. Default `nothing`. + +# Fields + +- `graph`: A dictionary that maps (source_type, edge_type, target_type) triples to (source, target) index vectors. +- `num_nodes`: The number of nodes for each type. +- `num_edges`: The number of edges for each type. +- `ndata`: Node features. +- `edata`: Edge features. +- `gdata`: Graph features. +- `ntypes`: The node types. +- `etypes`: The edge types. 
+ +# Examples + +```julia +julia> using GraphNeuralNetworks + +julia> nA, nB = 10, 20; + +julia> num_nodes = Dict(:A => nA, :B => nB); + +julia> edges1 = (rand(1:nA, 20), rand(1:nB, 20)) +([4, 8, 6, 3, 4, 7, 2, 7, 3, 2, 3, 4, 9, 4, 2, 9, 10, 1, 3, 9], [6, 4, 20, 8, 16, 7, 12, 16, 5, 4, 6, 20, 11, 19, 17, 9, 12, 2, 18, 12]) + +julia> edges2 = (rand(1:nB, 30), rand(1:nA, 30)) +([17, 5, 2, 4, 5, 3, 8, 7, 9, 7 … 19, 8, 20, 7, 16, 2, 9, 15, 8, 13], [1, 1, 3, 1, 1, 3, 2, 7, 4, 4 … 7, 10, 6, 3, 4, 9, 1, 5, 8, 5]) + +julia> data = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2); + +julia> hg = GNNHeteroGraph(data; num_nodes) +GNNHeteroGraph: + num_nodes: (:A => 10, :B => 20) + num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) + +julia> hg.num_edges +Dict{Tuple{Symbol, Symbol, Symbol}, Int64} with 2 entries: +(:A, :rel1, :B) => 20 +(:B, :rel2, :A) => 30 + +# Let's add some node features +julia> ndata = Dict(:A => (x = rand(2, nA), y = rand(3, num_nodes[:A])), + :B => rand(10, nB)); + +julia> hg = GNNHeteroGraph(data; num_nodes, ndata) +GNNHeteroGraph: + num_nodes: (:A => 10, :B => 20) + num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30) + ndata: + :A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64}) + :B => x = 10×20 Matrix{Float64} + +# Access features of nodes of type :A +julia> hg.ndata[:A].x +2×10 Matrix{Float64}: + 0.825882 0.0797502 0.245813 0.142281 0.231253 0.685025 0.821457 0.888838 0.571347 0.53165 + 0.631286 0.316292 0.705325 0.239211 0.533007 0.249233 0.473736 0.595475 0.0623298 0.159307 +``` + +See also [`GNNGraph`](@ref) for a homogeneous graph type and [`rand_heterograph`](@ref) for a function to generate random heterographs. +""" +struct GNNHeteroGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T} + graph::EDict{T} + num_nodes::NDict{Int} + num_edges::EDict{Int} + num_graphs::Int + graph_indicator::Union{Nothing, NDict} + ndata::NDict{DataStore} + edata::EDict{DataStore} + gdata::DataStore + ntypes::Vector{NType} + etypes::Vector{EType} +end + +@functor GNNHeteroGraph + +GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...) +GNNHeteroGraph(data::Pair...; kws...) = GNNHeteroGraph(Dict(data...); kws...) + +GNNHeteroGraph() = GNNHeteroGraph(Dict{Tuple{Symbol,Symbol,Symbol}, Any}()) + +function GNNHeteroGraph(data::Dict; kws...) + all(k -> k isa EType, keys(data)) || throw(ArgumentError("Keys of data must be tuples of the form `(source_type, edge_type, target_type)`")) + return GNNHeteroGraph(Dict([k => v for (k, v) in pairs(data)]...); kws...) +end + +function GNNHeteroGraph(data::EDict; + num_nodes = nothing, + graph_indicator = nothing, + graph_type = :coo, + dir = :out, + ndata = nothing, + edata = nothing, + gdata = (;)) + @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested" + @assert dir ∈ [:in, :out] + @assert graph_type==:coo "only :coo graph_type is supported for now" + + if num_nodes !== nothing + num_nodes = Dict(num_nodes) + end + + ntypes = union([[k[1] for k in keys(data)]; [k[3] for k in keys(data)]]) + etypes = collect(keys(data)) + + if graph_type == :coo + graph, num_nodes, num_edges = to_coo(data; num_nodes, dir) + elseif graph_type == :dense + graph, num_nodes, num_edges = to_dense(data; num_nodes, dir) + elseif graph_type == :sparse + graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir) + end + + num_graphs = !isnothing(graph_indicator) ? 
+ maximum([maximum(gi) for gi in values(graph_indicator)]) : 1 + + + if length(keys(graph)) == 0 + ndata = Dict{Symbol, DataStore}() + edata = Dict{Tuple{Symbol, Symbol, Symbol}, DataStore}() + gdata = DataStore() + else + ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes) + edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, + duplicate_if_needed = true) + gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs) + end + + return GNNHeteroGraph(graph, + num_nodes, num_edges, num_graphs, + graph_indicator, + ndata, edata, gdata, + ntypes, etypes) +end + +function show_sorted_dict(io::IO, d::Dict, compact::Bool) + # if compact + print(io, "Dict") + # end + print(io, "(") + if !isempty(d) + _keys = sort!(collect(keys(d))) + for key in _keys[1:end-1] + print(io, "$(_str(key)) => $(d[key]), ") + end + print(io, "$(_str(_keys[end])) => $(d[_keys[end]])") + end + # if length(d) == 1 + # print(io, ",") + # end + print(io, ")") +end + +function Base.show(io::IO, g::GNNHeteroGraph) + print(io, "GNNHeteroGraph(") + show_sorted_dict(io, g.num_nodes, true) + print(io, ", ") + show_sorted_dict(io, g.num_edges, true) + print(io, ")") +end + +function Base.show(io::IO, ::MIME"text/plain", g::GNNHeteroGraph) + if get(io, :compact, false) + print(io, "GNNHeteroGraph(") + show_sorted_dict(io, g.num_nodes, true) + print(io, ", ") + show_sorted_dict(io, g.num_edges, true) + print(io, ")") + else + print(io, "GNNHeteroGraph:\n num_nodes: ") + show_sorted_dict(io, g.num_nodes, false) + print(io, "\n num_edges: ") + show_sorted_dict(io, g.num_edges, false) + g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)") + if !isempty(g.ndata) && !all(isempty, values(g.ndata)) + print(io, "\n ndata:") + for k in sort(collect(keys(g.ndata))) + isempty(g.ndata[k]) && continue + print(io, "\n\t", _str(k), " => $(shortsummary(g.ndata[k]))") + end + end + if !isempty(g.edata) && !all(isempty, values(g.edata)) + print(io, "\n edata:") + for k in sort(collect(keys(g.edata))) + isempty(g.edata[k]) && continue + print(io, "\n\t$k => $(shortsummary(g.edata[k]))") + end + end + if !isempty(g.gdata) + print(io, "\n gdata:\n\t") + shortsummary(io, g.gdata) + end + end +end + +_str(s::Symbol) = ":$s" +_str(s) = "$s" + +MLUtils.numobs(g::GNNHeteroGraph) = g.num_graphs +# MLUtils.getobs(g::GNNHeteroGraph, i) = getgraph(g, i) + + +""" + num_edge_types(g) + +Return the number of edge types in the graph. For [`GNNGraph`](@ref)s, this is always 1. +For [`GNNHeteroGraph`](@ref)s, this is the number of unique edge types. +""" +num_edge_types(g::GNNGraph) = 1 + +num_edge_types(g::GNNHeteroGraph) = length(g.etypes) + +""" + num_node_types(g) + +Return the number of node types in the graph. For [`GNNGraph`](@ref)s, this is always 1. +For [`GNNHeteroGraph`](@ref)s, this is the number of unique node types. +""" +num_node_types(g::GNNGraph) = 1 + +num_node_types(g::GNNHeteroGraph) = length(g.ntypes) + +""" + edge_type_subgraph(g::GNNHeteroGraph, edge_ts) + +Return a subgraph of `g` that contains only the edges of type `edge_ts`. +Edge types can be specified as a single edge type (i.e. a tuple containing 3 symbols) or a vector of edge types. 
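+
+# Examples
+
+A minimal usage sketch; the random graph and the printed summaries below are
+illustrative:
+
+```julia
+julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false)
+GNNHeteroGraph:
+  num_nodes: (:A => 2, :B => 2)
+  num_edges: ((:A, :to, :B) => 4, (:B, :to, :A) => 0)
+
+julia> sg = edge_type_subgraph(g, (:A, :to, :B))
+GNNHeteroGraph:
+  num_nodes: (:A => 2, :B => 2)
+  num_edges: ((:A, :to, :B) => 4)
+```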
+""" +edge_type_subgraph(g::GNNHeteroGraph, edge_t::EType) = edge_type_subgraph(g, [edge_t]) + +function edge_type_subgraph(g::GNNHeteroGraph, edge_ts::AbstractVector{<:EType}) + for edge_t in edge_ts + @assert edge_t in g.etypes "Edge type $(edge_t) not found in graph" + end + node_ts = _ntypes_from_edges(edge_ts) + graph = Dict([edge_t => g.graph[edge_t] for edge_t in edge_ts]...) + num_nodes = Dict([node_t => g.num_nodes[node_t] for node_t in node_ts]...) + num_edges = Dict([edge_t => g.num_edges[edge_t] for edge_t in edge_ts]...) + if g.graph_indicator === nothing + graph_indicator = nothing + else + graph_indicator = Dict([node_t => g.graph_indicator[node_t] for node_t in node_ts]...) + end + ndata = Dict([node_t => g.ndata[node_t] for node_t in node_ts if node_t in keys(g.ndata)]...) + edata = Dict([edge_t => g.edata[edge_t] for edge_t in edge_ts if edge_t in keys(g.edata)]...) + + return GNNHeteroGraph(graph, num_nodes, num_edges, g.num_graphs, + graph_indicator, ndata, edata, g.gdata, + node_ts, edge_ts) +end + +# TODO this is not correct but Zygote cannot differentiate +# through dictionary generation +# @non_differentiable edge_type_subgraph(::Any...) + +function _ntypes_from_edges(edge_ts::AbstractVector{<:EType}) + ntypes = Symbol[] + for edge_t in edge_ts + node1_t, _, node2_t = edge_t + !in(node1_t, ntypes) && push!(ntypes, node1_t) + !in(node2_t, ntypes) && push!(ntypes, node2_t) + end + return ntypes +end + +@non_differentiable _ntypes_from_edges(::Any...) + +function Base.getindex(g::GNNHeteroGraph, node_t::NType) + return g.ndata[node_t] +end + +Base.getindex(g::GNNHeteroGraph, n1_t::Symbol, rel::Symbol, n2_t::Symbol) = g[(n1_t, rel, n2_t)] + +function Base.getindex(g::GNNHeteroGraph, edge_t::EType) + return g.edata[edge_t] +end diff --git a/GNNlib/src/GNNGraphs/operators.jl b/GNNlib/src/GNNGraphs/operators.jl new file mode 100644 index 000000000..655602b41 --- /dev/null +++ b/GNNlib/src/GNNGraphs/operators.jl @@ -0,0 +1,13 @@ +# 2 or more args graph operators +function Base.intersect(g1::GNNGraph, g2::GNNGraph) + @assert g1.num_nodes == g2.num_nodes + @assert graph_type_symbol(g1) == graph_type_symbol(g2) + graph_type = graph_type_symbol(g1) + num_nodes = g1.num_nodes + + idx1, _ = edge_encoding(edge_index(g1)..., num_nodes) + idx2, _ = edge_encoding(edge_index(g2)..., num_nodes) + idx = intersect(idx1, idx2) + s, t = edge_decoding(idx, num_nodes) + return GNNGraph(s, t; num_nodes, graph_type) +end diff --git a/GNNlib/src/GNNGraphs/query.jl b/GNNlib/src/GNNGraphs/query.jl new file mode 100644 index 000000000..6e32a2df1 --- /dev/null +++ b/GNNlib/src/GNNGraphs/query.jl @@ -0,0 +1,633 @@ + +""" + edge_index(g::GNNGraph) + +Return a tuple containing two vectors, respectively storing +the source and target nodes for each edges in `g`. + +```julia +s, t = edge_index(g) +``` +""" +edge_index(g::GNNGraph{<:COO_T}) = g.graph[1:2] + +edge_index(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][1:2] + +""" + edge_index(g::GNNHeteroGraph, [edge_t]) + +Return a tuple containing two vectors, respectively storing the source and target nodes +for each edges in `g` of type `edge_t = (src_t, rel_t, trg_t)`. + +If `edge_t` is not provided, it will error if `g` has more than one edge type. 
+""" +edge_index(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][1:2] +edge_index(g::GNNHeteroGraph{<:COO_T}) = only(g.graph)[2][1:2] + +get_edge_weight(g::GNNGraph{<:COO_T}) = g.graph[3] + +get_edge_weight(g::GNNGraph{<:ADJMAT_T}) = to_coo(g.graph, num_nodes = g.num_nodes)[1][3] + +get_edge_weight(g::GNNHeteroGraph{<:COO_T}, edge_t::EType) = g.graph[edge_t][3] + +Graphs.edges(g::GNNGraph) = Graphs.Edge.(edge_index(g)...) + +Graphs.edgetype(g::GNNGraph) = Graphs.Edge{eltype(g)} + +# """ +# eltype(g::GNNGraph) +# +# Type of nodes in `g`, +# an integer type like `Int`, `Int32`, `Uint16`, .... +# """ +function Base.eltype(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + w = get_edge_weight(g) + return w !== nothing ? eltype(w) : eltype(s) +end + +Base.eltype(g::GNNGraph{<:ADJMAT_T}) = eltype(g.graph) + +function Graphs.has_edge(g::GNNGraph{<:COO_T}, i::Integer, j::Integer) + s, t = edge_index(g) + return any((s .== i) .& (t .== j)) +end + +Graphs.has_edge(g::GNNGraph{<:ADJMAT_T}, i::Integer, j::Integer) = g.graph[i, j] != 0 + +""" + has_edge(g::GNNHeteroGraph, edge_t, i, j) + +Return `true` if there is an edge of type `edge_t` from node `i` to node `j` in `g`. + +# Examples + +```jldoctest +julia> g = rand_bipartite_heterograph((2, 2), (4, 0), bidirected=false) +GNNHeteroGraph: + num_nodes: (:A => 2, :B => 2) + num_edges: ((:A, :to, :B) => 4, (:B, :to, :A) => 0) + +julia> has_edge(g, (:A,:to,:B), 1, 1) +true + +julia> has_edge(g, (:B,:to,:A), 1, 1) +false +``` +""" +function Graphs.has_edge(g::GNNHeteroGraph, edge_t::EType, i::Integer, j::Integer) + s, t = edge_index(g, edge_t) + return any((s .== i) .& (t .== j)) +end + +graph_type_symbol(::GNNGraph{<:COO_T}) = :coo +graph_type_symbol(::GNNGraph{<:SPARSE_T}) = :sparse +graph_type_symbol(::GNNGraph{<:ADJMAT_T}) = :dense + +Graphs.nv(g::GNNGraph) = g.num_nodes +Graphs.ne(g::GNNGraph) = g.num_edges +Graphs.has_vertex(g::GNNGraph, i::Int) = 1 <= i <= g.num_nodes +Graphs.vertices(g::GNNGraph) = 1:(g.num_nodes) + +function Graphs.neighbors(g::GNNGraph, i; dir = :out) + @assert dir ∈ (:in, :out) + if dir == :out + outneighbors(g, i) + else + inneighbors(g, i) + end +end + +function Graphs.outneighbors(g::GNNGraph{<:COO_T}, i::Integer) + s, t = edge_index(g) + return t[s .== i] +end + +function Graphs.outneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) + A = g.graph + return findall(!=(0), A[i, :]) +end + +function Graphs.inneighbors(g::GNNGraph{<:COO_T}, i::Integer) + s, t = edge_index(g) + return s[t .== i] +end + +function Graphs.inneighbors(g::GNNGraph{<:ADJMAT_T}, i::Integer) + A = g.graph + return findall(!=(0), A[:, i]) +end + +Graphs.is_directed(::GNNGraph) = true +Graphs.is_directed(::Type{<:GNNGraph}) = true + +""" + adjacency_list(g; dir=:out) + adjacency_list(g, nodes; dir=:out) + +Return the adjacency list representation (a vector of vectors) +of the graph `g`. + +Calling `a` the adjacency list, if `dir=:out` than +`a[i]` will contain the neighbors of node `i` through +outgoing edges. If `dir=:in`, it will contain neighbors from +incoming edges instead. + +If `nodes` is given, return the neighborhood of the nodes in `nodes` only. 
+""" +function adjacency_list(g::GNNGraph, nodes; dir = :out, with_eid = false) + @assert dir ∈ [:out, :in] + s, t = edge_index(g) + if dir == :in + s, t = t, s + end + T = eltype(s) + idict = 0 + dmap = Dict(n => (idict += 1) for n in nodes) + adjlist = [T[] for _ in 1:length(dmap)] + eidlist = [T[] for _ in 1:length(dmap)] + for (eid, (i, j)) in enumerate(zip(s, t)) + inew = get(dmap, i, 0) + inew == 0 && continue + push!(adjlist[inew], j) + push!(eidlist[inew], eid) + end + if with_eid + return adjlist, eidlist + else + return adjlist + end +end + +# function adjacency_list(g::GNNGraph, nodes; dir=:out) +# @assert dir ∈ [:out, :in] +# fneighs = dir == :out ? outneighbors : inneighbors +# return [fneighs(g, i) for i in nodes] +# end + +adjacency_list(g::GNNGraph; dir = :out) = adjacency_list(g, 1:(g.num_nodes); dir) + +""" + adjacency_matrix(g::GNNGraph, T=eltype(g); dir=:out, weighted=true) + +Return the adjacency matrix `A` for the graph `g`. + +If `dir=:out`, `A[i,j] > 0` denotes the presence of an edge from node `i` to node `j`. +If `dir=:in` instead, `A[i,j] > 0` denotes the presence of an edge from node `j` to node `i`. + +User may specify the eltype `T` of the returned matrix. + +If `weighted=true`, the `A` will contain the edge weights if any, otherwise the elements of `A` will be either 0 or 1. +""" +function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType = eltype(g); dir = :out, + weighted = true) + if iscuarray(g.graph[1]) + # Revisit after + # https://github.com/JuliaGPU/CUDA.jl/issues/1113 + A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted) + else + A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted) + end + @assert size(A) == (n, n) + return dir == :out ? A : A' +end + +function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g); + dir = :out, weighted = true) + @assert dir ∈ [:in, :out] + A = g.graph + if !weighted + A = binarize(A) + end + A = T != eltype(A) ? T.(A) : A + return dir == :out ? 
A : A' +end + +function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; + dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}} + A = adjacency_matrix(g, T; dir, weighted) + if !weighted + function adjacency_matrix_pullback_noweight(Δ) + return (NoTangent(), ZeroTangent(), NoTangent()) + end + return A, adjacency_matrix_pullback_noweight + else + function adjacency_matrix_pullback_weighted(Δ) + dg = Tangent{G}(; graph = Δ .* binarize(A)) + return (NoTangent(), dg, NoTangent()) + end + return A, adjacency_matrix_pullback_weighted + end +end + +function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType; + dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}} + A = adjacency_matrix(g, T; dir, weighted) + w = get_edge_weight(g) + if !weighted || w === nothing + function adjacency_matrix_pullback_noweight(Δ) + return (NoTangent(), ZeroTangent(), NoTangent()) + end + return A, adjacency_matrix_pullback_noweight + else + function adjacency_matrix_pullback_weighted(Δ) + s, t = edge_index(g) + dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t))) + return (NoTangent(), dg, NoTangent()) + end + return A, adjacency_matrix_pullback_weighted + end +end + +function _get_edge_weight(g, edge_weight::Bool) + if edge_weight === true + return get_edge_weight(g) + elseif edge_weight === false + return nothing + end +end + +_get_edge_weight(g, edge_weight::AbstractVector) = edge_weight + +""" + degree(g::GNNGraph, T=nothing; dir=:out, edge_weight=true) + +Return a vector containing the degrees of the nodes in `g`. + +The gradient is propagated through this function only if `edge_weight` is `true` +or a vector. + +# Arguments + +- `g`: A graph. +- `T`: Element type of the returned vector. If `nothing`, is + chosen based on the graph type and will be an integer + if `edge_weight = false`. Default `nothing`. +- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges. + For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two. +- `edge_weight`: If `true` and the graph contains weighted edges, the degree will + be weighted. Set to `false` instead to just count the number of + outgoing/ingoing edges. + Finally, you can also pass a vector of weights to be used + instead of the graph's own weights. + Default `true`. + +""" +function Graphs.degree(g::GNNGraph{<:COO_T}, T::TT = nothing; dir = :out, + edge_weight = true) where { + TT <: Union{Nothing, Type{<:Number}}} + s, t = edge_index(g) + + ew = _get_edge_weight(g, edge_weight) + + T = if isnothing(T) + if !isnothing(ew) + eltype(ew) + else + eltype(s) + end + else + T + end + return _degree((s, t), T, dir, ew, g.num_nodes) +end + +# TODO:: Make efficient +Graphs.degree(g::GNNGraph, i::Union{Int, AbstractVector}; dir = :out) = degree(g; dir)[i] + +function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out, + edge_weight = true) where {TT<:Union{Nothing, Type{<:Number}}} + + # edge_weight=true or edge_weight=nothing act the same here + @assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations" + @assert dir ∈ (:in, :out, :both) + if T === nothing + Nt = eltype(g) + if edge_weight === false && !(Nt <: Integer) + T = Nt == Float32 ? Int32 : + Nt == Float16 ? 
Int16 : Int + else + T = Nt + end + end + A = adjacency_matrix(g) + return _degree(A, T, dir, edge_weight, g.num_nodes) +end + +""" + degree(g::GNNHeteroGraph, edge_type::EType; dir = :in) + +Return a vector containing the degrees of the nodes in `g` GNNHeteroGraph +given `edge_type`. + +# Arguments + +- `g`: A graph. +- `edge_type`: A tuple of symbols `(source_t, edge_t, target_t)` representing the edge type. +- `T`: Element type of the returned vector. If `nothing`, is + chosen based on the graph type. Default `nothing`. +- `dir`: For `dir = :out` the degree of a node is counted based on the outgoing edges. + For `dir = :in`, the ingoing edges are used. If `dir = :both` we have the sum of the two. + Default `dir = :out`. + +""" +function Graphs.degree(g::GNNHeteroGraph, edge::EType, + T::TT = nothing; dir = :out) where { + TT <: Union{Nothing, Type{<:Number}}} + + s, t = edge_index(g, edge) + + T = isnothing(T) ? eltype(s) : T + + n_type = dir == :in ? g.ntypes[2] : g.ntypes[1] + + return _degree((s, t), T, dir, nothing, g.num_nodes[n_type]) +end + +function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::Nothing, num_nodes::Int) + _degree((s, t), T, dir, ones_like(s, T), num_nodes) +end + +function _degree((s, t)::Tuple, T::Type, dir::Symbol, edge_weight::AbstractVector, num_nodes::Int) + degs = fill!(similar(s, T, num_nodes), 0) + + if dir ∈ [:out, :both] + degs = degs .+ NNlib.scatter(+, edge_weight, s, dstsize = (num_nodes,)) + end + if dir ∈ [:in, :both] + degs = degs .+ NNlib.scatter(+, edge_weight, t, dstsize = (num_nodes,)) + end + return degs +end + +function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num_nodes::Int) + if edge_weight === false + A = binarize(A) + end + A = eltype(A) != T ? T.(A) : A + return dir == :out ? vec(sum(A, dims = 2)) : + dir == :in ? vec(sum(A, dims = 1)) : + vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2)) +end + +function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes) + degs = _degree(graph, T, dir, edge_weight, num_nodes) + function _degree_pullback(Δ) + return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()) + end + return degs, _degree_pullback +end + +function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes) + degs = _degree(A, T, dir, edge_weight, num_nodes) + if edge_weight === false + function _degree_pullback_noweights(Δ) + return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent()) + end + return degs, _degree_pullback_noweights + else + function _degree_pullback_weights(Δ) + # We propagate the gradient only to the non-zero elements + # of the adjacency matrix. + bA = binarize(A) + if dir == :in + dA = bA .* Δ' + elseif dir == :out + dA = Δ .* bA + else # dir == :both + dA = Δ .* bA + Δ' .* bA + end + return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent()) + end + return degs, _degree_pullback_weights + end +end + +""" + has_isolated_nodes(g::GNNGraph; dir=:out) + +Return true if the graph `g` contains nodes with out-degree (if `dir=:out`) +or in-degree (if `dir = :in`) equal to zero. 
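+
+# Examples
+
+A minimal usage sketch:
+
+```julia
+julia> g = GNNGraph([1, 2], [2, 3]);  # node 3 has no outgoing edges
+
+julia> has_isolated_nodes(g)
+true
+
+julia> has_isolated_nodes(g; dir = :in)  # node 1 has no incoming edges
+true
+```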
+""" +function has_isolated_nodes(g::GNNGraph; dir = :out) + return any(iszero, degree(g; dir)) +end + +function Graphs.laplacian_matrix(g::GNNGraph, T::DataType = eltype(g); dir::Symbol = :out) + A = adjacency_matrix(g, T; dir = dir) + D = Diagonal(vec(sum(A; dims = 2))) + return D - A +end + +""" + normalized_laplacian(g, T=Float32; add_self_loops=false, dir=:out) + +Normalized Laplacian matrix of graph `g`. + +# Arguments + +- `g`: A `GNNGraph`. +- `T`: result element type. +- `add_self_loops`: add self-loops while calculating the matrix. +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function normalized_laplacian(g::GNNGraph, T::DataType = Float32; + add_self_loops::Bool = false, dir::Symbol = :out) + Ã = normalized_adjacency(g, T; dir, add_self_loops) + return I - Ã +end + +function normalized_adjacency(g::GNNGraph, T::DataType = Float32; + add_self_loops::Bool = false, dir::Symbol = :out) + A = adjacency_matrix(g, T; dir = dir) + if add_self_loops + A = A + I + end + degs = vec(sum(A; dims = 2)) + ChainRulesCore.ignore_derivatives() do + @assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`." + end + inv_sqrtD = Diagonal(inv.(sqrt.(degs))) + return inv_sqrtD * A * inv_sqrtD +end + +@doc raw""" + scaled_laplacian(g, T=Float32; dir=:out) + +Scaled Laplacian matrix of graph `g`, +defined as ``\hat{L} = \frac{2}{\lambda_{max}} L - I`` where ``L`` is the normalized Laplacian matrix. + +# Arguments + +- `g`: A `GNNGraph`. +- `T`: result element type. +- `dir`: the edge directionality considered (:out, :in, :both). +""" +function scaled_laplacian(g::GNNGraph, T::DataType = Float32; dir = :out) + L = normalized_laplacian(g, T) + # @assert issymmetric(L) "scaled_laplacian only works with symmetric matrices" + λmax = _eigmax(L) + return 2 / λmax * L - I +end + +# _eigmax(A) = eigmax(Symmetric(A)) # Doesn't work on sparse arrays +function _eigmax(A) + x0 = _rand_dense_vector(A) + KrylovKit.eigsolve(Symmetric(A), x0, 1, :LR)[1][1] # also eigs(A, x0, nev, mode) available +end + +_rand_dense_vector(A::AbstractMatrix{T}) where {T} = randn(float(T), size(A, 1)) + +# Eigenvalues for cuarray don't seem to be well supported. +# https://github.com/JuliaGPU/CUDA.jl/issues/154 +# https://discourse.julialang.org/t/cuda-eigenvalues-of-a-sparse-matrix/46851/5 + +""" + graph_indicator(g::GNNGraph; edges=false) + +Return a vector containing the graph membership +(an integer from `1` to `g.num_graphs`) of each node in the graph. +If `edges=true`, return the graph membership of each edge instead. +""" +function graph_indicator(g::GNNGraph; edges = false) + if isnothing(g.graph_indicator) + gi = ones_like(edge_index(g)[1], Int, g.num_nodes) + else + gi = g.graph_indicator + end + if edges + s, t = edge_index(g) + return gi[s] + else + return gi + end +end + +""" + graph_indicator(g::GNNHeteroGraph, [node_t]) + +Return a Dict of vectors containing the graph membership +(an integer from `1` to `g.num_graphs`) of each node in the graph for each node type. +If `node_t` is provided, return the graph membership of each node of type `node_t` instead. + +See also [`batch`](@ref). 
+""" +function graph_indicator(g::GNNHeteroGraph) + return g.graph_indicator +end + +function graph_indicator(g::GNNHeteroGraph, node_t::Symbol) + @assert node_t ∈ g.ntypes + if isnothing(g.graph_indicator) + gi = ones_like(edge_index(g, first(g.etypes))[1], Int, g.num_nodes[node_t]) + else + gi = g.graph_indicator[node_t] + end + return gi +end + +function node_features(g::GNNGraph) + if isempty(g.ndata) + return nothing + elseif length(g.ndata) > 1 + @error "Multiple feature arrays, access directly through `g.ndata`" + else + return first(values(g.ndata)) + end +end + +function edge_features(g::GNNGraph) + if isempty(g.edata) + return nothing + elseif length(g.edata) > 1 + @error "Multiple feature arrays, access directly through `g.edata`" + else + return first(values(g.edata)) + end +end + +function graph_features(g::GNNGraph) + if isempty(g.gdata) + return nothing + elseif length(g.gdata) > 1 + @error "Multiple feature arrays, access directly through `g.gdata`" + else + return first(values(g.gdata)) + end +end + +""" + is_bidirected(g::GNNGraph) + +Check if the directed graph `g` essentially corresponds +to an undirected graph, i.e. if for each edge it also contains the +reverse edge. +""" +function is_bidirected(g::GNNGraph) + s, t = edge_index(g) + s1, t1 = sort_edge_index(s, t) + s2, t2 = sort_edge_index(t, s) + all((s1 .== s2) .& (t1 .== t2)) +end + +""" + has_self_loops(g::GNNGraph) + +Return `true` if `g` has any self loops. +""" +function Graphs.has_self_loops(g::GNNGraph) + s, t = edge_index(g) + any(s .== t) +end + +""" + has_multi_edges(g::GNNGraph) + +Return `true` if `g` has any multiple edges. +""" +function has_multi_edges(g::GNNGraph) + s, t = edge_index(g) + idxs, _ = edge_encoding(s, t, g.num_nodes) + length(union(idxs)) < length(idxs) +end + +""" + khop_adj(g::GNNGraph,k::Int,T::DataType=eltype(g); dir=:out, weighted=true) + +Return ``A^k`` where ``A`` is the adjacency matrix of the graph 'g'. + +""" +function khop_adj(g::GNNGraph, k::Int, T::DataType = eltype(g); dir = :out, weighted = true) + return (adjacency_matrix(g, T; dir, weighted))^k +end + +""" + laplacian_lambda_max(g::GNNGraph, T=Float32; add_self_loops=false, dir=:out) + +Return the largest eigenvalue of the normalized symmetric Laplacian of the graph `g`. + +If the graph is batched from multiple graphs, return the list of the largest eigenvalue for each graph. +""" +function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32; + add_self_loops::Bool = false, dir::Symbol = :out) + if g.num_graphs == 1 + return _eigmax(normalized_laplacian(g, T; add_self_loops, dir)) + else + eigenvalues = zeros(g.num_graphs) + for i in 1:(g.num_graphs) + eigenvalues[i] = _eigmax(normalized_laplacian(getgraph(g, i), T; add_self_loops, + dir)) + end + return eigenvalues + end +end + +@non_differentiable edge_index(x...) +@non_differentiable adjacency_list(x...) +@non_differentiable graph_indicator(x...) +@non_differentiable has_multi_edges(x...) +@non_differentiable Graphs.has_self_loops(x...) +@non_differentiable is_bidirected(x...) +@non_differentiable normalized_adjacency(x...) # TODO remove this in the future +@non_differentiable normalized_laplacian(x...) # TODO remove this in the future +@non_differentiable scaled_laplacian(x...) 
# TODO remove this in the future diff --git a/GNNlib/src/GNNGraphs/sampling.jl b/GNNlib/src/GNNGraphs/sampling.jl new file mode 100644 index 000000000..01a601f5b --- /dev/null +++ b/GNNlib/src/GNNGraphs/sampling.jl @@ -0,0 +1,118 @@ +""" + sample_neighbors(g, nodes, K=-1; dir=:in, replace=false, dropnodes=false) + +Sample neighboring edges of the given nodes and return the induced subgraph. +For each node, a number of inbound (or outbound when `dir = :out``) edges will be randomly chosen. +If `dropnodes=false`, the graph returned will then contain all the nodes in the original graph, +but only the sampled edges. + +The returned graph will contain an edge feature `EID` corresponding to the id of the edge +in the original graph. If `dropnodes=true`, it will also contain a node feature `NID` with +the node ids in the original graph. + +# Arguments + +- `g`. The graph. +- `nodes`. A list of node IDs to sample neighbors from. +- `K`. The maximum number of edges to be sampled for each node. + If -1, all the neighboring edges will be selected. +- `dir`. Determines whether to sample inbound (`:in`) or outbound (``:out`) edges (Default `:in`). +- `replace`. If `true`, sample with replacement. +- `dropnodes`. If `true`, the resulting subgraph will contain only the nodes involved in the sampled edges. + +# Examples + +```julia +julia> g = rand_graph(20, 100) +GNNGraph: + num_nodes = 20 + num_edges = 100 + +julia> sample_neighbors(g, 2:3) +GNNGraph: + num_nodes = 20 + num_edges = 9 + edata: + EID => (9,) + +julia> sg = sample_neighbors(g, 2:3, dropnodes=true) +GNNGraph: + num_nodes = 10 + num_edges = 9 + ndata: + NID => (10,) + edata: + EID => (9,) + +julia> sg.ndata.NID +10-element Vector{Int64}: + 2 + 3 + 17 + 14 + 18 + 15 + 16 + 20 + 7 + 10 + +julia> sample_neighbors(g, 2:3, 5, replace=true) +GNNGraph: + num_nodes = 20 + num_edges = 10 + edata: + EID => (10,) +``` +""" +function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1; + dir = :in, replace = false, dropnodes = false) + @assert dir ∈ (:in, :out) + _, eidlist = adjacency_list(g, nodes; dir, with_eid = true) + for i in 1:length(eidlist) + if replace + k = K > 0 ? K : length(eidlist[i]) + else + k = K > 0 ? min(length(eidlist[i]), K) : length(eidlist[i]) + end + eidlist[i] = StatsBase.sample(eidlist[i], k; replace) + end + eids = reduce(vcat, eidlist) + s, t = edge_index(g) + w = get_edge_weight(g) + s = s[eids] + t = t[eids] + w = isnothing(w) ? nothing : w[eids] + + edata = getobs(g.edata, eids) + edata.EID = eids + + num_edges = length(eids) + + if !dropnodes + graph = (s, t, w) + + gnew = GNNGraph(graph, + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) + else + nodes_other = dir == :in ? setdiff(s, nodes) : setdiff(t, nodes) + nodes_all = [nodes; nodes_other] + nodemap = Dict(n => i for (i, n) in enumerate(nodes_all)) + s = [nodemap[s] for s in s] + t = [nodemap[t] for t in t] + graph = (s, t, w) + graph_indicator = g.graph_indicator !== nothing ? 
g.graph_indicator[nodes_all] : + nothing + num_nodes = length(nodes_all) + ndata = getobs(g.ndata, nodes_all) + ndata.NID = nodes_all + + gnew = GNNGraph(graph, + num_nodes, num_edges, g.num_graphs, + graph_indicator, + ndata, edata, g.gdata) + end + return gnew +end diff --git a/GNNlib/src/GNNGraphs/temporalsnapshotsgnngraph.jl b/GNNlib/src/GNNGraphs/temporalsnapshotsgnngraph.jl new file mode 100644 index 000000000..a08d069a2 --- /dev/null +++ b/GNNlib/src/GNNGraphs/temporalsnapshotsgnngraph.jl @@ -0,0 +1,244 @@ +""" + TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph}) + +A type representing a temporal graph as a sequence of snapshots. In this case a snapshot is a [`GNNGraph`](@ref). + +`TemporalSnapshotsGNNGraph` can store the feature array associated to the graph itself as a [`DataStore`](@ref) object, +and it uses the [`DataStore`](@ref) objects of each snapshot for the node and edge features. +The features can be passed at construction time or added later. + +# Constructor Arguments + +- `snapshot`: a vector of snapshots, where each snapshot must have the same number of nodes. + +# Examples + +```julia +julia> using GraphNeuralNetworks + +julia> snapshots = [rand_graph(10,20) for i in 1:5]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10, 10, 10] + num_edges: [20, 20, 20, 20, 20] + num_snapshots: 5 + +julia> tg.tgdata.x = rand(4); # add temporal graph feature + +julia> tg # show temporal graph with new feature +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10, 10, 10] + num_edges: [20, 20, 20, 20, 20] + num_snapshots: 5 + tgdata: + x = 4-element Vector{Float64} +``` +""" +struct TemporalSnapshotsGNNGraph + num_nodes::AbstractVector{Int} + num_edges::AbstractVector{Int} + num_snapshots::Int + snapshots::AbstractVector{<:GNNGraph} + tgdata::DataStore +end + +function TemporalSnapshotsGNNGraph(snapshots::AbstractVector{<:GNNGraph}) + @assert all([s.num_nodes == snapshots[1].num_nodes for s in snapshots]) "all snapshots must have the same number of nodes" + return TemporalSnapshotsGNNGraph( + [s.num_nodes for s in snapshots], + [s.num_edges for s in snapshots], + length(snapshots), + snapshots, + DataStore() + ) +end + +function Base.:(==)(tsg1::TemporalSnapshotsGNNGraph, tsg2::TemporalSnapshotsGNNGraph) + tsg1 === tsg2 && return true + for k in fieldnames(typeof(tsg1)) + getfield(tsg1, k) != getfield(tsg2, k) && return false + end + return true +end + +function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::Int) + return tg.snapshots[t] +end + +function Base.getindex(tg::TemporalSnapshotsGNNGraph, t::AbstractVector) + return TemporalSnapshotsGNNGraph(tg.num_nodes[t], tg.num_edges[t], length(t), tg.snapshots[t], tg.tgdata) +end + +""" + add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) + +Return a `TemporalSnapshotsGNNGraph` created starting from `tg` by adding the snapshot `g` at time index `t`. 
+ +# Examples + +```jldoctest +julia> using GraphNeuralNetworks + +julia> snapshots = [rand_graph(10, 20) for i in 1:5]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10, 10, 10] + num_edges: [20, 20, 20, 20, 20] + num_snapshots: 5 + +julia> new_tg = add_snapshot(tg, 3, rand_graph(10, 16)) # add a new snapshot at time 3 +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10, 10, 10, 10] + num_edges: [20, 20, 16, 20, 20, 20] + num_snapshots: 6 +``` +""" +function add_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) + if tg.num_snapshots > 0 + @assert g.num_nodes == first(tg.num_nodes) "number of nodes must match" + end + @assert t <= tg.num_snapshots + 1 "cannot add snapshot at time $t, the temporal graph has only $(tg.num_snapshots) snapshots" + num_nodes = tg.num_nodes |> copy + num_edges = tg.num_edges |> copy + snapshots = tg.snapshots |> copy + num_snapshots = tg.num_snapshots + 1 + insert!(num_nodes, t, g.num_nodes) + insert!(num_edges, t, g.num_edges) + insert!(snapshots, t, g) + return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata) +end + +# """ +# add_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) + +# Add to `tg` the snapshot `g` at time index `t`. + +# See also [`add_snapshot`](@ref) for a non-mutating version. +# """ +# function add_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int, g::GNNGraph) +# if t > tg.num_snapshots + 1 +# error("cannot add snapshot at time $t, the temporal graph has only $(tg.num_snapshots) snapshots") +# end +# if tg.num_snapshots > 0 +# @assert g.num_nodes == first(tg.num_nodes) "number of nodes must match" +# end +# insert!(tg.num_nodes, t, g.num_nodes) +# insert!(tg.num_edges, t, g.num_edges) +# insert!(tg.snapshots, t, g) +# return tg +# end + +""" + remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int) + +Return a [`TemporalSnapshotsGNNGraph`](@ref) created starting from `tg` by removing the snapshot at time index `t`. + +# Examples + +```jldoctest +julia> using GraphNeuralNetworks + +julia> snapshots = [rand_graph(10,20), rand_graph(10,14), rand_graph(10,22)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10] + num_edges: [20, 14, 22] + num_snapshots: 3 + +julia> new_tg = remove_snapshot(tg, 2) # remove snapshot at time 2 +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10] + num_edges: [20, 22] + num_snapshots: 2 +``` +""" +function remove_snapshot(tg::TemporalSnapshotsGNNGraph, t::Int) + num_nodes = tg.num_nodes |> copy + num_edges = tg.num_edges |> copy + snapshots = tg.snapshots |> copy + num_snapshots = tg.num_snapshots - 1 + deleteat!(num_nodes, t) + deleteat!(num_edges, t) + deleteat!(snapshots, t) + return TemporalSnapshotsGNNGraph(num_nodes, num_edges, num_snapshots, snapshots, tg.tgdata) +end + +# """ +# remove_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int) + +# Remove the snapshot at time index `t` from `tg` and return `tg`. + +# See [`remove_snapshot`](@ref) for a non-mutating version. 
+# """ +# function remove_snapshot!(tg::TemporalSnapshotsGNNGraph, t::Int) +# @assert t <= tg.num_snapshots "snapshot index $t out of bounds" +# tg.num_snapshots -= 1 +# deleteat!(tg.num_nodes, t) +# deleteat!(tg.num_edges, t) +# deleteat!(tg.snapshots, t) +# return tg +# end + +function Base.getproperty(tg::TemporalSnapshotsGNNGraph, prop::Symbol) + if prop ∈ fieldnames(TemporalSnapshotsGNNGraph) + return getfield(tg, prop) + elseif prop == :ndata + return [s.ndata for s in tg.snapshots] + elseif prop == :edata + return [s.edata for s in tg.snapshots] + elseif prop == :gdata + return [s.gdata for s in tg.snapshots] + else + return [getproperty(s,prop) for s in tg.snapshots] + end +end + +function Base.show(io::IO, tsg::TemporalSnapshotsGNNGraph) + print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") + print_feature_t(io, tsg.tgdata) + print(io, " data") +end + +function Base.show(io::IO, ::MIME"text/plain", tsg::TemporalSnapshotsGNNGraph) + if get(io, :compact, false) + print(io, "TemporalSnapshotsGNNGraph($(tsg.num_snapshots)) with ") + print_feature_t(io, tsg.tgdata) + print(io, " data") + else + print(io, + "TemporalSnapshotsGNNGraph:\n num_nodes: $(tsg.num_nodes)\n num_edges: $(tsg.num_edges)\n num_snapshots: $(tsg.num_snapshots)") + if !isempty(tsg.tgdata) + print(io, "\n tgdata:") + for k in keys(tsg.tgdata) + print(io, "\n\t$k = $(shortsummary(tsg.tgdata[k]))") + end + end + end +end + +function print_feature_t(io::IO, feature) + if !isempty(feature) + if length(keys(feature)) == 1 + k = first(keys(feature)) + v = first(values(feature)) + print(io, "$(k): $(dims2string(size(v)))") + else + print(io, "(") + for (i, (k, v)) in enumerate(pairs(feature)) + print(io, "$k: $(dims2string(size(v)))") + if i == length(feature) + print(io, ")") + else + print(io, ", ") + end + end + end + else + print(io, "no") + end +end + +@functor TemporalSnapshotsGNNGraph diff --git a/GNNlib/src/GNNGraphs/transform.jl b/GNNlib/src/GNNGraphs/transform.jl new file mode 100644 index 000000000..0c8e7d74b --- /dev/null +++ b/GNNlib/src/GNNGraphs/transform.jl @@ -0,0 +1,1131 @@ + +""" + add_self_loops(g::GNNGraph) + +Return a graph with the same features as `g` +but also adding edges connecting the nodes to themselves. + +Nodes with already existing self-loops will obtain a second self-loop. + +If the graphs has edge weights, the new edges will have weight 1. +""" +function add_self_loops(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + @assert isempty(g.edata) + ew = get_edge_weight(g) + n = g.num_nodes + nodes = convert(typeof(s), [1:n;]) + s = [s; nodes] + t = [t; nodes] + if ew !== nothing + ew = [ew; fill!(similar(ew, n), 1)] + end + + return GNNGraph((s, t, ew), + g.num_nodes, length(s), g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +function add_self_loops(g::GNNGraph{<:ADJMAT_T}) + A = g.graph + @assert isempty(g.edata) + num_edges = g.num_edges + g.num_nodes + A = A + I + return GNNGraph(A, + g.num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +""" + add_self_loops(g::GNNHeteroGraph, edge_t::EType) + add_self_loops(g::GNNHeteroGraph) + +If the source node type is the same as the destination node type in `edge_t`, +return a graph with the same features as `g` but also add self-loops +of the specified type, `edge_t`. Otherwise, it returns `g` unchanged. + +Nodes with already existing self-loops of type `edge_t` will obtain +a second set of self-loops of the same type. 
+
+If the graph has edge weights for edges of type `edge_t`, the new edges will have weight 1.
+
+If no edges of type `edge_t` exist, or all existing edges have no weight,
+then all new self loops will have no weight.
+
+If `edge_t` is not passed as an argument, a self-loop is added to each node for every
+edge type in the graph whose source and destination node types are the same,
+iterating the logic above over all applicable edge types.
+"""
+function add_self_loops(g::GNNHeteroGraph{Tuple{T, T, V}}, edge_t::EType) where {T <: AbstractVector{<:Integer}, V}
+    function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
+        get(g.graph, edge_t, (nothing, nothing, nothing))[3]
+    end
+
+    src_t, _, tgt_t = edge_t
+    (src_t === tgt_t) ||
+        return g
+
+    n = get(g.num_nodes, src_t, 0)
+
+    if haskey(g.graph, edge_t)
+        x = g.graph[edge_t]
+        s, t = x[1:2]
+        nodes = convert(typeof(s), [1:n;])
+        s = [s; nodes]
+        t = [t; nodes]
+    else
+        nodes = convert(T, [1:n;])
+        s = nodes
+        t = nodes
+    end
+
+    graph = g.graph |> copy
+    ew = get(g.graph, edge_t, (nothing, nothing, nothing))[3]
+
+    if ew !== nothing
+        ew = [ew; fill!(similar(ew, n), 1)]
+    end
+
+    graph[edge_t] = (s, t, ew)
+    edata = g.edata |> copy
+    ndata = g.ndata |> copy
+    ntypes = g.ntypes |> copy
+    etypes = g.etypes |> copy
+    num_nodes = g.num_nodes |> copy
+    num_edges = g.num_edges |> copy
+    num_edges[edge_t] = length(get(graph, edge_t, ([], []))[1])
+
+    return GNNHeteroGraph(graph,
+                          num_nodes, num_edges, g.num_graphs,
+                          g.graph_indicator,
+                          ndata, edata, g.gdata,
+                          ntypes, etypes)
+end
+
+function add_self_loops(g::GNNHeteroGraph)
+    for edge_t in keys(g.graph)
+        g = add_self_loops(g, edge_t)
+    end
+    return g
+end
+
+"""
+    remove_self_loops(g::GNNGraph)
+
+Return a graph constructed from `g` where self-loops (edges from a node to itself)
+are removed.
+
+See also [`add_self_loops`](@ref) and [`remove_multi_edges`](@ref).
+"""
+function remove_self_loops(g::GNNGraph{<:COO_T})
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+    edata = g.edata
+
+    mask_old_loops = s .!= t
+    s = s[mask_old_loops]
+    t = t[mask_old_loops]
+    edata = getobs(edata, mask_old_loops)
+    w = isnothing(w) ? nothing : getobs(w, mask_old_loops)
+
+    GNNGraph((s, t, w),
+             g.num_nodes, length(s), g.num_graphs,
+             g.graph_indicator,
+             g.ndata, edata, g.gdata)
+end
+
+function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
+    @assert isempty(g.edata)
+    A = g.graph
+    A[diagind(A)] .= 0
+    if A isa AbstractSparseMatrix
+        dropzeros!(A)
+    end
+    num_edges = numnonzeros(A)
+    return GNNGraph(A,
+                    g.num_nodes, num_edges, g.num_graphs,
+                    g.graph_indicator,
+                    g.ndata, g.edata, g.gdata)
+end
+
+"""
+    remove_edges(g::GNNGraph, edges_to_remove::AbstractVector{<:Integer})
+
+Remove specified edges from a GNNGraph.
+
+# Arguments
+- `g`: The input graph from which edges will be removed.
+- `edges_to_remove`: Vector of edge indices to be removed.
+
+# Returns
+A new GNNGraph with the specified edges removed.
+
+# Example
+```julia
+julia> using GraphNeuralNetworks
+
+# Construct a GNNGraph
+julia> g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
+GNNGraph:
+  num_nodes: 3
+  num_edges: 5
+
+# Remove the second edge
+julia> g_new = remove_edges(g, [2]);
+
+julia> g_new
+GNNGraph:
+  num_nodes: 3
+  num_edges: 4
+```
+"""
+function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:Integer})
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+    edata = g.edata
+
+    mask_to_keep = trues(length(s))
+
+    mask_to_keep[edges_to_remove] .= false
+
+    s = s[mask_to_keep]
+    t = t[mask_to_keep]
+    edata = getobs(edata, mask_to_keep)
+    w = isnothing(w) ? nothing : getobs(w, mask_to_keep)
+
+    return GNNGraph((s, t, w),
+                    g.num_nodes, length(s), g.num_graphs,
+                    g.graph_indicator,
+                    g.ndata, edata, g.gdata)
+end
+
+"""
+    remove_multi_edges(g::GNNGraph; aggr=+)
+
+Remove multiple edges (also called parallel edges or repeated edges) from graph `g`.
+Possible edge features are aggregated according to `aggr`, which can take the values
+`+`, `min`, `max` or `mean`.
+
+See also [`remove_self_loops`](@ref), [`has_multi_edges`](@ref), and [`to_bidirected`](@ref).
+"""
+function remove_multi_edges(g::GNNGraph{<:COO_T}; aggr = +)
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+    edata = g.edata
+    num_edges = g.num_edges
+    idxs, idxmax = edge_encoding(s, t, g.num_nodes)
+
+    perm = sortperm(idxs)
+    idxs = idxs[perm]
+    s, t = s[perm], t[perm]
+    edata = getobs(edata, perm)
+    w = isnothing(w) ? nothing : getobs(w, perm)
+    idxs = [-1; idxs]
+    mask = idxs[2:end] .> idxs[1:(end - 1)]
+    if !all(mask)
+        s, t = s[mask], t[mask]
+        idxs = similar(s, num_edges)
+        idxs .= 1:num_edges
+        idxs .= idxs .- cumsum(.!mask)
+        num_edges = length(s)
+        w = _scatter(aggr, w, idxs, num_edges)
+        edata = _scatter(aggr, edata, idxs, num_edges)
+    end
+
+    return GNNGraph((s, t, w),
+                    g.num_nodes, num_edges, g.num_graphs,
+                    g.graph_indicator,
+                    g.ndata, edata, g.gdata)
+end
+
+"""
+    remove_nodes(g::GNNGraph, nodes_to_remove::AbstractVector)
+
+Remove specified nodes, and their associated edges, from a GNNGraph. This operation reindexes the remaining nodes to maintain a continuous sequence of node indices, starting from 1. Similarly, edges are reindexed to account for the removal of edges connected to the removed nodes.
+
+# Arguments
+- `g`: The input graph from which nodes (and their edges) will be removed.
+- `nodes_to_remove`: Vector of node indices to be removed.
+
+# Returns
+A new GNNGraph with the specified nodes and all edges associated with these nodes removed.
+
+# Example
+```julia
+using GraphNeuralNetworks
+
+g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1])
+
+# Remove nodes with indices 2 and 3, for example
+g_new = remove_nodes(g, [2, 3])
+
+# g_new now does not contain nodes 2 and 3, and any edges that were connected to these nodes.
+println(g_new)
+```
+"""
+function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
+    nodes_to_remove = sort(union(nodes_to_remove))
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+    edata = g.edata
+    ndata = g.ndata
+
+    function find_edges_to_remove(nodes, nodes_to_remove)
+        return findall(node_id -> begin
+                           idx = searchsortedlast(nodes_to_remove, node_id)
+                           idx >= 1 && idx <= length(nodes_to_remove) && nodes_to_remove[idx] == node_id
+                       end, nodes)
+    end
+
+    edges_to_remove_s = find_edges_to_remove(s, nodes_to_remove)
+    edges_to_remove_t = find_edges_to_remove(t, nodes_to_remove)
+    edges_to_remove = union(edges_to_remove_s, edges_to_remove_t)
+
+    mask_edges_to_keep = trues(length(s))
+    mask_edges_to_keep[edges_to_remove] .= false
+    s = s[mask_edges_to_keep]
+    t = t[mask_edges_to_keep]
+
+    w = isnothing(w) ? nothing : getobs(w, mask_edges_to_keep)
+
+    for node in sort(nodes_to_remove, rev = true)
+        s[s .> node] .-= 1
+        t[t .> node] .-= 1
+    end
+
+    nodes_to_keep = setdiff(1:g.num_nodes, nodes_to_remove)
+    ndata = getobs(ndata, nodes_to_keep)
+    edata = getobs(edata, mask_edges_to_keep)
+
+    num_nodes = g.num_nodes - length(nodes_to_remove)
+
+    return GNNGraph((s, t, w),
+                    num_nodes, length(s), g.num_graphs,
+                    g.graph_indicator,
+                    ndata, edata, g.gdata)
+end
+
+"""
+    add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
+    add_edges(g::GNNGraph, (s, t); [edata])
+    add_edges(g::GNNGraph, (s, t, w); [edata])
+
+Add to graph `g` the edges with source nodes `s` and target nodes `t`.
+Optionally, pass the edge weight `w` and the features `edata` for the new edges.
+Returns a new graph sharing part of the underlying data with `g`.
+
+If `s` or `t` contain nodes that are not already present in the graph,
+they are added to the graph as well.
+
+# Examples
+
+```jldoctest
+julia> s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4];
+
+julia> w = Float32[1.0, 2.0, 3.0, 4.0, 5.0];
+
+julia> g = GNNGraph((s, t, w))
+GNNGraph:
+  num_nodes: 4
+  num_edges: 5
+
+julia> add_edges(g, ([2, 3], [4, 1], [10.0, 20.0]))
+GNNGraph:
+  num_nodes: 4
+  num_edges: 7
+```
+```jldoctest
+julia> g = GNNGraph()
+GNNGraph:
+  num_nodes: 0
+  num_edges: 0
+
+julia> add_edges(g, [1,2], [2,3])
+GNNGraph:
+  num_nodes: 3
+  num_edges: 2
+```
+"""
+add_edges(g::GNNGraph{<:COO_T}, snew::AbstractVector, tnew::AbstractVector; kws...) = add_edges(g, (snew, tnew, nothing); kws...)
+add_edges(g, data::Tuple{<:AbstractVector, <:AbstractVector}; kws...) = add_edges(g, (data..., nothing); kws...)
+ +function add_edges(g::GNNGraph{<:COO_T}, data::COO_T; edata = nothing) + snew, tnew, wnew = data + @assert length(snew) == length(tnew) + @assert isnothing(wnew) || length(wnew) == length(snew) + if length(snew) == 0 + return g + end + @assert minimum(snew) >= 1 + @assert minimum(tnew) >= 1 + num_new = length(snew) + edata = normalize_graphdata(edata, default_name = :e, n = num_new) + edata = cat_features(g.edata, edata) + + s, t = edge_index(g) + s = [s; snew] + t = [t; tnew] + w = get_edge_weight(g) + w = cat_features(w, wnew, g.num_edges, num_new) + + num_nodes = max(maximum(snew), maximum(tnew), g.num_nodes) + if num_nodes > g.num_nodes + ndata_new = normalize_graphdata((;), default_name = :x, n = num_nodes - g.num_nodes) + ndata = cat_features(g.ndata, ndata_new) + else + ndata = g.ndata + end + + return GNNGraph((s, t, w), + num_nodes, length(s), g.num_graphs, + g.graph_indicator, + ndata, edata, g.gdata) +end + +""" + add_edges(g::GNNHeteroGraph, edge_t, s, t; [edata, num_nodes]) + add_edges(g::GNNHeteroGraph, edge_t => (s, t); [edata, num_nodes]) + add_edges(g::GNNHeteroGraph, edge_t => (s, t, w); [edata, num_nodes]) + +Add to heterograph `g` edges of type `edge_t` with source node vector `s` and target node vector `t`. +Optionally, pass the edge weights `w` or the features `edata` for the new edges. +`edge_t` is a triplet of symbols `(src_t, rel_t, dst_t)`. + +If the edge type is not already present in the graph, it is added. +If it involves new node types, they are added to the graph as well. +In this case, a dictionary or named tuple of `num_nodes` can be passed to specify the number of nodes of the new types, +otherwise the number of nodes is inferred from the maximum node id in `s` and `t`. +""" +add_edges(g::GNNHeteroGraph{<:COO_T}, edge_t::EType, snew::AbstractVector, tnew::AbstractVector; kws...) = add_edges(g, edge_t => (snew, tnew, nothing); kws...) +add_edges(g::GNNHeteroGraph{<:COO_T}, data::Pair{EType, <:Tuple{<:AbstractVector, <:AbstractVector}}; kws...) = add_edges(g, data.first => (data.second..., nothing); kws...) 
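+
+# A minimal usage sketch for `add_edges` on heterographs; the graph, node indices,
+# and edge types below are illustrative, not part of the API:
+#
+#   g = GNNHeteroGraph((:A, :to, :B) => ([1, 2], [2, 3]))
+#   g = add_edges(g, (:B, :to, :A) => ([1, 3], [2, 1]))
+#   g.num_edges[(:B, :to, :A)]   # == 2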
+ +function add_edges(g::GNNHeteroGraph{<:COO_T}, + data::Pair{EType, <:COO_T}; + edata = nothing, + num_nodes = Dict{Symbol,Int}()) + edge_t, (snew, tnew, wnew) = data + @assert length(snew) == length(tnew) + if length(snew) == 0 + return g + end + @assert minimum(snew) >= 1 + @assert minimum(tnew) >= 1 + + is_existing_rel = haskey(g.graph, edge_t) + + edata = normalize_graphdata(edata, default_name = :e, n = length(snew)) + _edata = g.edata |> copy + if haskey(_edata, edge_t) + _edata[edge_t] = cat_features(g.edata[edge_t], edata) + else + _edata[edge_t] = edata + end + + graph = g.graph |> copy + etypes = g.etypes |> copy + ntypes = g.ntypes |> copy + _num_nodes = g.num_nodes |> copy + ndata = g.ndata |> copy + if !is_existing_rel + for (node_t, st) in [(edge_t[1], snew), (edge_t[3], tnew)] + if node_t ∉ ntypes + push!(ntypes, node_t) + if haskey(num_nodes, node_t) + _num_nodes[node_t] = num_nodes[node_t] + else + _num_nodes[node_t] = maximum(st) + end + ndata[node_t] = DataStore(_num_nodes[node_t]) + end + end + push!(etypes, edge_t) + else + s, t = edge_index(g, edge_t) + snew = [s; snew] + tnew = [t; tnew] + w = get_edge_weight(g, edge_t) + wnew = cat_features(w, wnew, length(s), length(snew)) + end + + if maximum(snew) > _num_nodes[edge_t[1]] + ndata_new = normalize_graphdata((;), default_name = :x, n = maximum(snew) - _num_nodes[edge_t[1]]) + ndata[edge_t[1]] = cat_features(ndata[edge_t[1]], ndata_new) + _num_nodes[edge_t[1]] = maximum(snew) + end + if maximum(tnew) > _num_nodes[edge_t[3]] + ndata_new = normalize_graphdata((;), default_name = :x, n = maximum(tnew) - _num_nodes[edge_t[3]]) + ndata[edge_t[3]] = cat_features(ndata[edge_t[3]], ndata_new) + _num_nodes[edge_t[3]] = maximum(tnew) + end + + graph[edge_t] = (snew, tnew, wnew) + num_edges = g.num_edges |> copy + num_edges[edge_t] = length(graph[edge_t][1]) + + return GNNHeteroGraph(graph, + _num_nodes, num_edges, g.num_graphs, + g.graph_indicator, + ndata, _edata, g.gdata, + ntypes, etypes) +end + + + +### TODO Cannot implement this since GNNGraph is immutable (cannot change num_edges). make it mutable +# function Graphs.add_edge!(g::GNNGraph{<:COO_T}, snew::T, tnew::T; edata=nothing) where T<:Union{Integer, AbstractVector} +# s, t = edge_index(g) +# @assert length(snew) == length(tnew) +# # TODO remove this constraint +# @assert get_edge_weight(g) === nothing + +# edata = normalize_graphdata(edata, default_name=:e, n=length(snew)) +# edata = cat_features(g.edata, edata) + +# s, t = edge_index(g) +# append!(s, snew) +# append!(t, tnew) +# g.num_edges += length(snew) +# return true +# end + +""" + to_bidirected(g) + +Adds a reverse edge for each edge in the graph, then calls +[`remove_multi_edges`](@ref) with `mean` aggregation to simplify the graph. + +See also [`is_bidirected`](@ref). 
+ +# Examples + +```jldoctest +julia> s, t = [1, 2, 3, 3, 4], [2, 3, 4, 4, 4]; + +julia> w = [1.0, 2.0, 3.0, 4.0, 5.0]; + +julia> e = [10.0, 20.0, 30.0, 40.0, 50.0]; + +julia> g = GNNGraph(s, t, w, edata = e) +GNNGraph: + num_nodes = 4 + num_edges = 5 + edata: + e => (5,) + +julia> g2 = to_bidirected(g) +GNNGraph: + num_nodes = 4 + num_edges = 7 + edata: + e => (7,) + +julia> edge_index(g2) +([1, 2, 2, 3, 3, 4, 4], [2, 1, 3, 2, 4, 3, 4]) + +julia> get_edge_weight(g2) +7-element Vector{Float64}: + 1.0 + 1.0 + 2.0 + 2.0 + 3.5 + 3.5 + 5.0 + +julia> g2.edata.e +7-element Vector{Float64}: + 10.0 + 10.0 + 20.0 + 20.0 + 35.0 + 35.0 + 50.0 +``` +""" +function to_bidirected(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + w = get_edge_weight(g) + snew = [s; t] + tnew = [t; s] + w = cat_features(w, w) + edata = cat_features(g.edata, g.edata) + + g = GNNGraph((snew, tnew, w), + g.num_nodes, length(snew), g.num_graphs, + g.graph_indicator, + g.ndata, edata, g.gdata) + + return remove_multi_edges(g; aggr = mean) +end + +""" + to_unidirected(g::GNNGraph) + +Return a graph that for each multiple edge between two nodes in `g` +keeps only an edge in one direction. +""" +function to_unidirected(g::GNNGraph{<:COO_T}) + s, t = edge_index(g) + w = get_edge_weight(g) + idxs, _ = edge_encoding(s, t, g.num_nodes, directed = false) + snew, tnew = edge_decoding(idxs, g.num_nodes, directed = false) + + g = GNNGraph((snew, tnew, w), + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) + + return remove_multi_edges(g; aggr = mean) +end + +function Graphs.SimpleGraph(g::GNNGraph) + G = Graphs.SimpleGraph(g.num_nodes) + for e in Graphs.edges(g) + Graphs.add_edge!(G, e) + end + return G +end +function Graphs.SimpleDiGraph(g::GNNGraph) + G = Graphs.SimpleDiGraph(g.num_nodes) + for e in Graphs.edges(g) + Graphs.add_edge!(G, e) + end + return G +end + +""" + add_nodes(g::GNNGraph, n; [ndata]) + +Add `n` new nodes to graph `g`. In the +new graph, these nodes will have indexes from `g.num_nodes + 1` +to `g.num_nodes + n`. +""" +function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata = (;)) + ndata = normalize_graphdata(ndata, default_name = :x, n = n) + ndata = cat_features(g.ndata, ndata) + + GNNGraph(g.graph, + g.num_nodes + n, g.num_edges, g.num_graphs, + g.graph_indicator, + ndata, g.edata, g.gdata) +end + +""" + set_edge_weight(g::GNNGraph, w::AbstractVector) + +Set `w` as edge weights in the returned graph. +""" +function set_edge_weight(g::GNNGraph, w::AbstractVector) + s, t = edge_index(g) + @assert length(w) == length(s) + + return GNNGraph((s, t, w), + g.num_nodes, g.num_edges, g.num_graphs, + g.graph_indicator, + g.ndata, g.edata, g.gdata) +end + +function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph) + nv1, nv2 = g1.num_nodes, g2.num_nodes + if g1.graph isa COO_T + s1, t1 = edge_index(g1) + s2, t2 = edge_index(g2) + s = vcat(s1, nv1 .+ s2) + t = vcat(t1, nv1 .+ t2) + w = cat_features(get_edge_weight(g1), get_edge_weight(g2)) + graph = (s, t, w) + ind1 = isnothing(g1.graph_indicator) ? ones_like(s1, nv1) : g1.graph_indicator + ind2 = isnothing(g2.graph_indicator) ? ones_like(s2, nv2) : g2.graph_indicator + elseif g1.graph isa ADJMAT_T + graph = blockdiag(g1.graph, g2.graph) + ind1 = isnothing(g1.graph_indicator) ? ones_like(graph, nv1) : g1.graph_indicator + ind2 = isnothing(g2.graph_indicator) ? 
ones_like(graph, nv2) : g2.graph_indicator + end + graph_indicator = vcat(ind1, g1.num_graphs .+ ind2) + + GNNGraph(graph, + nv1 + nv2, g1.num_edges + g2.num_edges, g1.num_graphs + g2.num_graphs, + graph_indicator, + cat_features(g1.ndata, g2.ndata), + cat_features(g1.edata, g2.edata), + cat_features(g1.gdata, g2.gdata)) +end + +# PIRACY +function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix) + m1, n1 = size(A1) + @assert m1 == n1 + m2, n2 = size(A2) + @assert m2 == n2 + O1 = fill!(similar(A1, eltype(A1), (m1, n2)), 0) + O2 = fill!(similar(A1, eltype(A1), (m2, n1)), 0) + return [A1 O1 + O2 A2] +end + +""" + blockdiag(xs::GNNGraph...) + +Equivalent to [`MLUtils.batch`](@ref). +""" +function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...) + g = g1 + for go in gothers + g = blockdiag(g, go) + end + return g +end + +""" + batch(gs::Vector{<:GNNGraph}) + +Batch together multiple `GNNGraph`s into a single one +containing the total number of original nodes and edges. + +Equivalent to [`SparseArrays.blockdiag`](@ref). +See also [`MLUtils.unbatch`](@ref). + +# Examples + +```jldoctest +julia> g1 = rand_graph(4, 6, ndata=ones(8, 4)) +GNNGraph: + num_nodes = 4 + num_edges = 6 + ndata: + x => (8, 4) + +julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7)) +GNNGraph: + num_nodes = 7 + num_edges = 4 + ndata: + x => (8, 7) + +julia> g12 = MLUtils.batch([g1, g2]) +GNNGraph: + num_nodes = 11 + num_edges = 10 + num_graphs = 2 + ndata: + x => (8, 11) + +julia> g12.ndata.x +8×11 Matrix{Float64}: + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 + 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 +``` +""" +function MLUtils.batch(gs::AbstractVector{<:GNNGraph}) + Told = eltype(gs) + # try to restrict the eltype + gs = [g for g in gs] + if eltype(gs) != Told + return MLUtils.batch(gs) + else + return blockdiag(gs...) + end +end + +function MLUtils.batch(gs::AbstractVector{<:GNNGraph{T}}) where {T <: COO_T} + v_num_nodes = [g.num_nodes for g in gs] + edge_indices = [edge_index(g) for g in gs] + nodesum = cumsum([0; v_num_nodes])[1:(end - 1)] + s = cat_features([ei[1] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) + t = cat_features([ei[2] .+ nodesum[ii] for (ii, ei) in enumerate(edge_indices)]) + w = cat_features([get_edge_weight(g) for g in gs]) + graph = (s, t, w) + + function materialize_graph_indicator(g) + g.graph_indicator === nothing ? ones_like(s, g.num_nodes) : g.graph_indicator + end + + v_gi = materialize_graph_indicator.(gs) + v_num_graphs = [g.num_graphs for g in gs] + graphsum = cumsum([0; v_num_graphs])[1:(end - 1)] + v_gi = [ng .+ gi for (ng, gi) in zip(graphsum, v_gi)] + graph_indicator = cat_features(v_gi) + + GNNGraph(graph, + sum(v_num_nodes), + sum([g.num_edges for g in gs]), + sum(v_num_graphs), + graph_indicator, + cat_features([g.ndata for g in gs]), + cat_features([g.edata for g in gs]), + cat_features([g.gdata for g in gs])) +end + +function MLUtils.batch(g::GNNGraph) + throw(ArgumentError("Cannot batch a `GNNGraph` (containing $(g.num_graphs) graphs). 
Pass a vector of `GNNGraph`s instead."))
+end
+
+
+function MLUtils.batch(gs::AbstractVector{<:GNNHeteroGraph})
+    function edge_index_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
+        if haskey(g.graph, edge_t)
+            g.graph[edge_t][1:2]
+        else
+            nothing
+        end
+    end
+
+    function get_edge_weight_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
+        get(g.graph, edge_t, (nothing, nothing, nothing))[3]
+    end
+
+    @assert length(gs) > 0
+    ntypes = union([g.ntypes for g in gs]...)
+    etypes = union([g.etypes for g in gs]...)
+
+    v_num_nodes = Dict(node_t => [get(g.num_nodes, node_t, 0) for g in gs] for node_t in ntypes)
+    num_nodes = Dict(node_t => sum(v_num_nodes[node_t]) for node_t in ntypes)
+    num_edges = Dict(edge_t => sum(get(g.num_edges, edge_t, 0) for g in gs) for edge_t in etypes)
+    edge_indices = Dict(edge_t => [edge_index_nullable(g, edge_t) for g in gs] for edge_t in etypes)
+    nodesum = Dict(node_t => cumsum([0; v_num_nodes[node_t]])[1:(end - 1)] for node_t in ntypes)
+    graphs = []
+    for edge_t in etypes
+        src_t, _, dst_t = edge_t
+        s = cat_features([ei[1] .+ nodesum[src_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing])
+        t = cat_features([ei[2] .+ nodesum[dst_t][ii] for (ii, ei) in enumerate(edge_indices[edge_t]) if ei !== nothing])
+        w = cat_features(filter(x -> x !== nothing, [get_edge_weight_nullable(g, edge_t) for g in gs]))
+        push!(graphs, edge_t => (s, t, w))
+    end
+    graph = Dict(graphs...)
+
+    #TODO relax this restriction
+    @assert all(g -> g.num_graphs == 1, gs)
+
+    s = edge_index(gs[1], gs[1].etypes[1])[1] # grab any source vector
+
+    function materialize_graph_indicator(g, node_t)
+        n = get(g.num_nodes, node_t, 0)
+        return ones_like(s, n)
+    end
+    v_gi = Dict(node_t => [materialize_graph_indicator(g, node_t) for g in gs] for node_t in ntypes)
+    v_num_graphs = [g.num_graphs for g in gs]
+    graphsum = cumsum([0; v_num_graphs])[1:(end - 1)]
+    v_gi = Dict(node_t => [ng .+ gi for (ng, gi) in zip(graphsum, v_gi[node_t])] for node_t in ntypes)
+    graph_indicator = Dict(node_t => cat_features(v_gi[node_t]) for node_t in ntypes)
+
+    function data_or_else(data, types)
+        Dict(type => get(data, type, DataStore(0)) for type in types)
+    end
+
+    return GNNHeteroGraph(graph,
+                          num_nodes,
+                          num_edges,
+                          sum(v_num_graphs),
+                          graph_indicator,
+                          cat_features([data_or_else(g.ndata, ntypes) for g in gs]),
+                          cat_features([data_or_else(g.edata, etypes) for g in gs]),
+                          cat_features([g.gdata for g in gs]),
+                          ntypes, etypes)
+end
+
+"""
+    unbatch(g::GNNGraph)
+
+Opposite of the [`MLUtils.batch`](@ref) operation, returns
+an array of the individual graphs batched together in `g`.
+
+See also [`MLUtils.batch`](@ref) and [`getgraph`](@ref).
+
+# Examples
+
+```jldoctest
+julia> gbatched = MLUtils.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4, 2)])
+GNNGraph:
+    num_nodes = 19
+    num_edges = 16
+    num_graphs = 3
+
+julia> MLUtils.unbatch(gbatched)
+3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
+ GNNGraph:
+    num_nodes = 5
+    num_edges = 6
+
+ GNNGraph:
+    num_nodes = 10
+    num_edges = 8
+
+ GNNGraph:
+    num_nodes = 4
+    num_edges = 2
+```
+"""
+function MLUtils.unbatch(g::GNNGraph{T}) where {T <: COO_T}
+    g.num_graphs == 1 && return [g]
+
+    nodemasks = _unbatch_nodemasks(g.graph_indicator, g.num_graphs)
+    num_nodes = length.(nodemasks)
+    cumnum_nodes = [0; cumsum(num_nodes)]
+
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+
+    edgemasks = _unbatch_edgemasks(s, t, g.num_graphs, cumnum_nodes)
+    num_edges = length.(edgemasks)
+    @assert sum(num_edges)==g.num_edges "Error in unbatching, likely the edges are not sorted (first the edges of the first graph, then those of the second graph, and so on)"
+
+    function build_graph(i)
+        node_mask = nodemasks[i]
+        edge_mask = edgemasks[i]
+        snew = s[edge_mask] .- cumnum_nodes[i]
+        tnew = t[edge_mask] .- cumnum_nodes[i]
+        wnew = w === nothing ? nothing : w[edge_mask]
+        graph = (snew, tnew, wnew)
+        graph_indicator = nothing
+        ndata = getobs(g.ndata, node_mask)
+        edata = getobs(g.edata, edge_mask)
+        gdata = getobs(g.gdata, i)
+
+        nedges = num_edges[i]
+        nnodes = num_nodes[i]
+        ngraphs = 1
+
+        return GNNGraph(graph,
+                        nnodes, nedges, ngraphs,
+                        graph_indicator,
+                        ndata, edata, gdata)
+    end
+
+    return [build_graph(i) for i in 1:(g.num_graphs)]
+end
+
+function MLUtils.unbatch(g::GNNGraph)
+    return [getgraph(g, i) for i in 1:(g.num_graphs)]
+end
+
+function _unbatch_nodemasks(graph_indicator, num_graphs)
+    @assert issorted(graph_indicator) "The graph_indicator vector must be sorted."
+    idxslast = [searchsortedlast(graph_indicator, i) for i in 1:num_graphs]
+
+    nodemasks = [1:idxslast[1]]
+    for i in 2:num_graphs
+        push!(nodemasks, (idxslast[i - 1] + 1):idxslast[i])
+    end
+    return nodemasks
+end
+
+function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
+    edgemasks = []
+    for i in 1:(num_graphs - 1)
+        lastedgeid = findfirst(s) do x
+            x > cumnum_nodes[i + 1] && x <= cumnum_nodes[i + 2]
+        end
+        firstedgeid = i == 1 ? 1 : last(edgemasks[i - 1]) + 1
+        # if nothing, make an empty range
+        lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1
+
+        push!(edgemasks, firstedgeid:lastedgeid)
+    end
+    push!(edgemasks, (last(edgemasks[end]) + 1):length(s))
+    return edgemasks
+end
+
+@non_differentiable _unbatch_nodemasks(::Any...)
+@non_differentiable _unbatch_edgemasks(::Any...)
+
+"""
+    getgraph(g::GNNGraph, i; nmap=false)
+
+Return the subgraph of `g` induced by those nodes `j`
+for which `g.graph_indicator[j] == i` or,
+if `i` is a collection, `g.graph_indicator[j] ∈ i`.
+In other words, it extracts the component graphs from a batched graph.
+
+If `nmap=true`, also return a vector `v` mapping the new nodes to the old ones.
+The node `i` in the subgraph will correspond to the node `v[i]` in `g`.
+"""
+getgraph(g::GNNGraph, i::Int; kws...) = getgraph(g, [i]; kws...)
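+
+# A minimal usage sketch (illustrative values, not a doctest):
+#   gbatch = MLUtils.batch([rand_graph(5, 6), rand_graph(7, 8)])
+#   g2, v = getgraph(gbatch, 2; nmap = true)  # extract the second component graph
+#   # g2.num_nodes == 7 and v == 6:12, the original indices of its nodes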
+
+function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap = false)
+    if g.graph_indicator === nothing
+        @assert i == [1]
+        if nmap
+            return g, 1:(g.num_nodes)
+        else
+            return g
+        end
+    end
+
+    node_mask = g.graph_indicator .∈ Ref(i)
+
+    nodes = (1:(g.num_nodes))[node_mask]
+    nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes))
+
+    graphmap = Dict(i => inew for (inew, i) in enumerate(i))
+    graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
+
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+    edge_mask = s .∈ Ref(nodes)
+
+    if g.graph isa COO_T
+        s = [nodemap[i] for i in s[edge_mask]]
+        t = [nodemap[i] for i in t[edge_mask]]
+        w = isnothing(w) ? nothing : w[edge_mask]
+        graph = (s, t, w)
+    elseif g.graph isa ADJMAT_T
+        graph = g.graph[nodes, nodes]
+    end
+
+    ndata = getobs(g.ndata, node_mask)
+    edata = getobs(g.edata, edge_mask)
+    gdata = getobs(g.gdata, i)
+
+    num_edges = sum(edge_mask)
+    num_nodes = length(graph_indicator)
+    num_graphs = length(i)
+
+    gnew = GNNGraph(graph,
+                    num_nodes, num_edges, num_graphs,
+                    graph_indicator,
+                    ndata, edata, gdata)
+
+    if nmap
+        return gnew, nodes
+    else
+        return gnew
+    end
+end
+
+"""
+    negative_sample(g::GNNGraph;
+                    num_neg_edges = g.num_edges,
+                    bidirected = is_bidirected(g))
+
+Return a graph containing random negative edges (i.e. non-edges) from graph `g` as edges.
+
+If `bidirected=true`, the output graph will be bidirected and there will be no
+leakage from the original graph.
+
+See also [`is_bidirected`](@ref).
+"""
+function negative_sample(g::GNNGraph;
+                         max_trials = 3,
+                         num_neg_edges = g.num_edges,
+                         bidirected = is_bidirected(g))
+    @assert g.num_graphs == 1
+    # Consider self-loops as positive edges
+    # Construct new graph dropping features
+    g = add_self_loops(GNNGraph(edge_index(g), num_nodes = g.num_nodes))
+
+    s, t = edge_index(g)
+    n = g.num_nodes
+    idx_pos, maxid = edge_encoding(s, t, n)
+    if bidirected
+        num_neg_edges = num_neg_edges ÷ 2
+        pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge
+    else
+        pneg = 1 - g.num_edges / 2maxid # prob of selecting negative edge
+    end
+    # pneg * sample_prob * maxid == num_neg_edges
+    sample_prob = min(1, num_neg_edges / (pneg * maxid) * 1.1)
+    idx_neg = Int[]
+    for _ in 1:max_trials
+        rnd = randsubseq(1:maxid, sample_prob)
+        setdiff!(rnd, idx_pos)
+        union!(idx_neg, rnd)
+        if length(idx_neg) >= num_neg_edges
+            idx_neg = idx_neg[1:num_neg_edges]
+            break
+        end
+    end
+    s_neg, t_neg = edge_decoding(idx_neg, n)
+    if bidirected
+        s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg]
+    end
+    return GNNGraph(s_neg, t_neg, num_nodes = n)
+end
+
+"""
+    rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g)) -> g1, g2
+
+Randomly partition the edges in `g` to form two graphs, `g1`
+and `g2`. Both will have the same number of nodes as `g`.
+`g1` will contain a fraction `frac` of the original edges,
+while `g2` will contain the rest.
+
+If `bidirected = true`, ensures that an edge and its reverse go into the same split.
+This option is supported only for bidirected graphs with no self-loops
+and no multi-edges.
+
+`rand_edge_split` is typically used to create train/test splits in link prediction tasks.
+"""
+function rand_edge_split(g::GNNGraph, frac; bidirected = is_bidirected(g))
+    s, t = edge_index(g)
+    ne = bidirected ? g.num_edges ÷ 2 : g.num_edges
+    eids = randperm(ne)
+    size1 = round(Int, ne * frac)
+
+    if !bidirected
+        s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
+        s2, t2 = s[eids[(size1 + 1):end]], t[eids[(size1 + 1):end]]
+    else
+        # @assert is_bidirected(g)
+        # @assert !has_self_loops(g)
+        # @assert !has_multi_edges(g)
+        mask = s .< t
+        s, t = s[mask], t[mask]
+        s1, t1 = s[eids[1:size1]], t[eids[1:size1]]
+        s1, t1 = [s1; t1], [t1; s1]
+        s2, t2 = s[eids[(size1 + 1):end]], t[eids[(size1 + 1):end]]
+        s2, t2 = [s2; t2], [t2; s2]
+    end
+    g1 = GNNGraph(s1, t1, num_nodes = g.num_nodes)
+    g2 = GNNGraph(s2, t2, num_nodes = g.num_nodes)
+    return g1, g2
+end
+
+"""
+    random_walk_pe(g, walk_length)
+
+Return the random-walk positional encoding of the graph `g`, introduced in the paper
+[Graph Neural Networks with Learnable Structural and Positional Representations](https://arxiv.org/abs/2110.07875),
+as a matrix of size `(walk_length, g.num_nodes)`, where `walk_length` is the length of the walk.
+"""
+function random_walk_pe(g::GNNGraph, walk_length::Int)
+    adj = adjacency_matrix(g, Float32; dir = :out)
+    matrix = dense_zeros_like(adj, Float32, (walk_length, g.num_nodes))
+    deg = sum(adj, dims = 2) |> vec
+    deg_inv = inv.(deg)
+    deg_inv[isinf.(deg_inv)] .= 0
+    RW = adj * Diagonal(deg_inv)
+    out = RW
+    matrix[1, :] .= diag(RW)
+    for i in 2:walk_length
+        out = out * RW
+        matrix[i, :] .= diag(out)
+    end
+    return matrix
+end
+
+dense_zeros_like(a::SparseMatrixCSC, T::Type, sz = size(a)) = zeros(T, sz)
+dense_zeros_like(a::AbstractArray, T::Type, sz = size(a)) = fill!(similar(a, T, sz), 0)
+dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)
+
+# """
+# Transform a vector of cartesian indices into a tuple of integer vectors.
+# """
+ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)
+
+@non_differentiable negative_sample(x...)
+@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
+@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
+@non_differentiable dense_zeros_like(x...)
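+
+# Usage sketch for the sampling helpers above (illustrative, not a doctest):
+#   g = rand_graph(10, 30, bidirected = true)
+#   gtrain, gtest = rand_edge_split(g, 0.8)  # an edge and its reverse stay together
+#   gneg = negative_sample(gtest)            # random non-edges of `gtest`
+#   pe = random_walk_pe(g, 4)                # 4×10 positional-encoding matrix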
+ diff --git a/GNNlib/src/GNNGraphs/utils.jl b/GNNlib/src/GNNGraphs/utils.jl new file mode 100644 index 000000000..e2b821369 --- /dev/null +++ b/GNNlib/src/GNNGraphs/utils.jl @@ -0,0 +1,304 @@ +function check_num_nodes(g::GNNGraph, x::AbstractArray) + @assert g.num_nodes==size(x, ndims(x)) "Got $(size(x, ndims(x))) as last dimension size instead of num_nodes=$(g.num_nodes)" + return true +end +function check_num_nodes(g::GNNGraph, x::Union{Tuple, NamedTuple}) + map(x -> check_num_nodes(g, x), x) + return true +end + +check_num_nodes(::GNNGraph, ::Nothing) = true + +function check_num_nodes(g::GNNGraph, x::Tuple) + @assert length(x) == 2 + check_num_nodes(g, x[1]) + check_num_nodes(g, x[2]) + return true +end + +# x = (Xsrc, Xdst) = (Xj, Xi) +function check_num_nodes(g::GNNHeteroGraph, x::Tuple) + @assert length(x) == 2 + @assert length(g.etypes) == 1 + nt1, _, nt2 = only(g.etypes) + if x[1] isa AbstractArray + @assert size(x[1], ndims(x[1])) == g.num_nodes[nt1] + end + if x[2] isa AbstractArray + @assert size(x[2], ndims(x[2])) == g.num_nodes[nt2] + end + return true +end + +function check_num_edges(g::GNNGraph, e::AbstractArray) + @assert g.num_edges==size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(g.num_edges)" + return true +end +function check_num_edges(g::AbstractGNNGraph, x::Union{Tuple, NamedTuple}) + map(x -> check_num_edges(g, x), x) + return true +end + +check_num_edges(::AbstractGNNGraph, ::Nothing) = true + +function check_num_edges(g::GNNHeteroGraph, e::AbstractArray) + num_edgs = only(g.num_edges)[2] + @assert only(num_edgs)==size(e, ndims(e)) "Got $(size(e, ndims(e))) as last dimension size instead of num_edges=$(num_edgs)" + return true +end + +sort_edge_index(eindex::Tuple) = sort_edge_index(eindex...) + +function sort_edge_index(u, v) + uv = collect(zip(u, v)) + p = sortperm(uv) # isless lexicographically defined for tuples + return u[p], v[p] +end + + + +cat_features(x1::Nothing, x2::Nothing) = nothing +cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1)) +function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector}) + cat(x1, x2, dims = 1) +end + +# workaround for issue #98 #104 +# See https://github.com/JuliaStrings/InlineStrings.jl/issues/21 +# Remove when minimum supported version is julia v1.8 +cat_features(x1::NamedTuple{(), Tuple{}}, x2::NamedTuple{(), Tuple{}}) = (;) +cat_features(xs::AbstractVector{NamedTuple{(), Tuple{}}}) = (;) + +function cat_features(x1::NamedTuple, x2::NamedTuple) + sort(collect(keys(x1))) == sort(collect(keys(x2))) || + @error "cannot concatenate feature data with different keys" + + return NamedTuple(k => cat_features(x1[k], x2[k]) for k in keys(x1)) +end + +function cat_features(x1::Dict{Symbol, T}, x2::Dict{Symbol, T}) where {T} + sort(collect(keys(x1))) == sort(collect(keys(x2))) || + @error "cannot concatenate feature data with different keys" + + return Dict{Symbol, T}([k => cat_features(x1[k], x2[k]) for k in keys(x1)]...) +end + +function cat_features(x::Dict) + return Dict([k => cat_features(v) for (k, v) in pairs(x)]...) 
+end
+
+
+function cat_features(xs::AbstractVector{<:AbstractArray{T, N}}) where {T <: Number, N}
+    cat(xs...; dims = N)
+end
+
+cat_features(xs::AbstractVector{Nothing}) = nothing
+cat_features(xs::AbstractVector{<:Number}) = xs
+
+function cat_features(xs::AbstractVector{<:NamedTuple})
+    symbols = [sort(collect(keys(x))) for x in xs]
+    all(y -> y == symbols[1], symbols) ||
+        @error "cannot concatenate feature data with different keys"
+    length(xs) == 1 && return xs[1]
+
+    # concatenate
+    syms = symbols[1]
+    NamedTuple(k => cat_features([x[k] for x in xs]) for k in syms)
+end
+
+# function cat_features(xs::AbstractVector{Dict{Symbol, T}}) where {T}
+#     symbols = [sort(collect(keys(x))) for x in xs]
+#     all(y -> y == symbols[1], symbols) ||
+#         @error "cannot concatenate feature data with different keys"
+#     length(xs) == 1 && return xs[1]

+#     # concatenate
+#     syms = symbols[1]
+#     return Dict{Symbol, T}([k => cat_features([x[k] for x in xs]) for k in syms]...)
+# end
+
+function cat_features(xs::AbstractVector{<:Dict})
+    _allkeys = [sort(collect(keys(x))) for x in xs]
+    _keys = union(_allkeys...)
+    length(xs) == 1 && return xs[1]
+
+    # concatenate
+    return Dict([k => cat_features([x[k] for x in xs if haskey(x, k)]) for k in _keys]...)
+end
+
+
+# Used to concatenate edge weights
+cat_features(w1::Nothing, w2::Nothing, n1::Int, n2::Int) = nothing
+cat_features(w1::AbstractVector, w2::Nothing, n1::Int, n2::Int) = cat_features(w1, ones_like(w1, n2))
+cat_features(w1::Nothing, w2::AbstractVector, n1::Int, n2::Int) = cat_features(ones_like(w2, n1), w2)
+cat_features(w1::AbstractVector, w2::AbstractVector, n1::Int, n2::Int) = cat_features(w1, w2)
+
+
+# Turns generic type into named tuple
+normalize_graphdata(data::Nothing; n, kws...) = DataStore(n)
+
+function normalize_graphdata(data; default_name::Symbol, kws...)
+    normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
+end
+
+function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false)
+    # This has to work around two Zygote bugs with NamedTuples:
+    # https://github.com/FluxML/Zygote.jl/issues/1071
+    # https://github.com/FluxML/Zygote.jl/issues/1072
+
+    if n > 1
+        @assert all(x -> x isa AbstractArray, data) "Non-array features provided."
+    end
+
+    if n <= 1
+        # If last array dimension is not 1, add a new dimension.
+        # This is mostly useful to reshape global feature vectors
+        # of size D to Dx1 matrices.
+        unsqz_last(v::AbstractArray) = size(v)[end] != 1 ? reshape(v, size(v)..., 1) : v
+        unsqz_last(v) = v
+
+        data = map(unsqz_last, data)
+    end
+
+    if n > 0
+        if duplicate_if_needed
+            function duplicate(v)
+                if v isa AbstractArray && size(v)[end] == n ÷ 2
+                    v = cat(v, v, dims = ndims(v))
+                end
+                return v
+            end
+            data = map(duplicate, data)
+        end
+
+        for x in data
+            if x isa AbstractArray
+                @assert size(x)[end]==n "Wrong size in last dimension for feature array, expected $n but got $(size(x)[end])."
+            end
+        end
+    end
+
+    return DataStore(n, data)
+end
+
+# For heterogeneous graphs
+function normalize_heterographdata(data::Nothing; default_name::Symbol, ns::Dict, kws...)
+    Dict([k => normalize_graphdata(nothing; default_name = default_name, n, kws...)
+          for (k, n) in ns]...)
+end
+
+normalize_heterographdata(data; kws...) = normalize_heterographdata(Dict(data); kws...)
+
+function normalize_heterographdata(data::Dict; default_name::Symbol, ns::Dict, kws...)
+    Dict([k => normalize_graphdata(get(data, k, nothing); default_name = default_name, n, kws...)
+          for (k, n) in ns]...)
+end
+
+numnonzeros(a::AbstractSparseMatrix) = nnz(a)
+numnonzeros(a::AbstractMatrix) = count(!=(0), a)
+
+# Each directed edge is represented by a number in 1:n^2,
+# each undirected edge by a number in 1:n*(n+1)÷2.
+function edge_encoding(s, t, n; directed = true)
+    if directed
+        # directed edges and self-loops allowed
+        idx = (s .- 1) .* n .+ t
+        maxid = n^2
+    else
+        # Undirected edges and self-loops allowed
+        maxid = n * (n + 1) ÷ 2
+
+        mask = s .> t
+        snew = copy(s)
+        tnew = copy(t)
+        snew[mask] .= t[mask]
+        tnew[mask] .= s[mask]
+        s, t = snew, tnew
+
+        # idx = ∑_{s' < s} (n - s' + 1) + (t - s + 1)
+        #     = (s - 1) * (2 * (n + 1) - s) ÷ 2 + (t - s + 1)
+        idx = @. (s - 1) * (2 * (n + 1) - s) ÷ 2 + (t - s + 1)
+    end
+    return idx, maxid
+end
+
+# Inverse of edge_encoding.
+function edge_decoding(idx, n; directed = true)
+    if directed
+        s = (idx .- 1) .÷ n .+ 1
+        t = (idx .- 1) .% n .+ 1
+    else
+        # Invert idx = (s - 1) * (2 * (n + 1) - s) ÷ 2 + (t - s + 1)
+        s = @. ceil(Int, -sqrt((n + 1 / 2)^2 - 2 * idx) + n + 1 / 2)
+        t = @. idx - (s - 1) * (2 * (n + 1) - s) ÷ 2 - 1 + s
+    end
+    return s, t
+end
+
+binarize(x) = map(>(0), x)
+
+@non_differentiable binarize(x...)
+@non_differentiable edge_encoding(x...)
+@non_differentiable edge_decoding(x...)
+
+### PRINTING #####
+
+function shortsummary(io::IO, x)
+    s = shortsummary(x)
+    s === nothing && return
+    print(io, s)
+end
+
+shortsummary(x) = summary(x)
+shortsummary(x::Number) = "$x"
+
+function shortsummary(x::NamedTuple)
+    if length(x) == 0
+        return nothing
+    elseif length(x) === 1
+        return "$(keys(x)[1]) = $(shortsummary(x[1]))"
+    else
+        "(" * join(("$k = $(shortsummary(x[k]))" for k in keys(x)), ", ") * ")"
+    end
+end
+
+function shortsummary(x::DataStore)
+    length(x) == 0 && return nothing
+    return "DataStore(" * join(("$k = [$(shortsummary(x[k]))]" for k in keys(x)), ", ") *
+           ")"
+end
+
+# from (2,2,3) output of size function to a string "2×2×3"
+function dims2string(d)
+    isempty(d) ? "0-dimensional" :
+    length(d) == 1 ? "$(d[1])-element" :
+    join(map(string, d), '×')
+end
+
+@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
+@non_differentiable normalize_graphdata(::Nothing)
+
+iscuarray(x::AbstractArray) = false
+@non_differentiable iscuarray(::Any)
\ No newline at end of file
diff --git a/GNNlib/src/GNNlib.jl b/GNNlib/src/GNNlib.jl
new file mode 100644
index 000000000..ca28a7343
--- /dev/null
+++ b/GNNlib/src/GNNlib.jl
@@ -0,0 +1,95 @@
+module GNNlib
+
+using Statistics: mean
+using LinearAlgebra, Random
+using Base: tail
+using MacroTools: @forward
+using MLUtils
+using NNlib
+using NNlib: scatter, gather
+using ChainRulesCore
+using SparseArrays, Graphs # not needed but if removed Documenter will complain
+using DataStructures: nlargest
+using Reexport: @reexport
+
+include("GNNGraphs/GNNGraphs.jl")
+
+@reexport using .GNNGraphs
+
+using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
+                  check_num_nodes, check_num_edges,
+                  EType, NType # for heteroconvs
+
+export
+
+# utils
+    reduce_nodes,
+    reduce_edges,
+    softmax_nodes,
+    softmax_edges,
+    broadcast_nodes,
+    broadcast_edges,
+    softmax_edge_neighbors,
+
+# msgpass
+    apply_edges,
+    aggregate_neighbors,
+    propagate,
+    copy_xj,
+    copy_xi,
+    xi_dot_xj,
+    xi_sub_xj,
+    xj_sub_xi,
+    e_mul_xj,
+    w_mul_xj,
+
+# mldatasets
+    mldataset2gnngraph
+
+## The following methods are defined but not exported
+
+# # layers/basic
+# dot_decoder,
+
+# # layers/conv
+# agnn_conv,
+# cg_conv,
+# cheb_conv,
+# edge_conv,
+# egnn_conv,
+# gat_conv,
+# gatv2_conv,
+# gated_graph_conv,
+# gcn_conv,
+# gin_conv,
+# gmm_conv,
+# graph_conv,
+# megnet_conv,
+# nn_conv,
+# res_gated_graph_conv,
+# sage_conv,
+# sg_conv,
+# transformer_conv,
+
+# # layers/temporalconv
+# a3tgcn_conv,
+
+# # layers/pool
+# global_pool,
+# global_attention_pool,
+# set2set_pool,
+# topk_pool,
+# topk_index,
+
+
+include("utils.jl")
+include("layers/basic.jl")
+include("layers/conv.jl")
+# include("layers/heteroconv.jl") # no functional part at the moment
+include("layers/temporalconv.jl")
+include("layers/pool.jl")
+include("msgpass.jl")
+include("mldatasets.jl") + +end diff --git a/GNNlib/src/layers/basic.jl b/GNNlib/src/layers/basic.jl new file mode 100644 index 000000000..c5f83e1aa --- /dev/null +++ b/GNNlib/src/layers/basic.jl @@ -0,0 +1,3 @@ +function dot_decoder(g, x) + return apply_edges(xi_dot_xj, g, xi = x, xj = x) +end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl new file mode 100644 index 000000000..c44a8210d --- /dev/null +++ b/GNNlib/src/layers/conv.jl @@ -0,0 +1,590 @@ + +check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) = + throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs")) + +function check_gcnconv_input(g::AbstractGNNGraph, edge_weight::AbstractVector) + if length(edge_weight) !== g.num_edges + throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))")) + end +end + +check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing + +function gcn_conv(l, g::AbstractGNNGraph, x, + edge_weight::EW = nothing, + norm_fn::Function = d -> 1 ./ sqrt.(d) + ) where {EW <: Union{Nothing, AbstractVector}} + + check_gcnconv_input(g, edge_weight) + + if l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + # Pad weights with ones + # TODO for ADJMAT_T the new edges are not generally at the end + edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] + @assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(l.weight) + if Dout < Din && !(g isa GNNHeteroGraph) + # multiply before convolution if it is more convenient, otherwise multiply after + # (this works only for homogenous graph) + x = l.weight * x + end + + xj, xi = expand_srcdst(g, x) # expand only after potential multiplication + T = eltype(xi) + + if g isa GNNHeteroGraph + din = degree(g, g.etypes[1], T; dir = :in) + dout = degree(g, g.etypes[1], T; dir = :out) + + cout = norm_fn(dout) + cin = norm_fn(din) + else + if edge_weight !== nothing + d = degree(g, T; dir = :in, edge_weight) + else + d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight) + end + cin = cout = norm_fn(d) + end + xj = xj .* cout' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj = xj) + else + x = propagate(copy_xj, g, +, xj = xj) + end + x = x .* cin' + if Dout >= Din || g isa GNNHeteroGraph + x = l.weight * x + end + return l.σ.(x .+ l.bias) +end + +function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, + edge_weight::AbstractVector, norm_fn::Function) + + g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO + return gcn_conv(l, g, x, edge_weight, norm_fn) +end + + +function cheb_conv(c, g::GNNGraph, X::AbstractMatrix{T}) where {T} + check_num_nodes(g, X) + @assert size(X, 1)==size(c.weight, 2) "Input feature size must match input channel size." 
+ + L̃ = scaled_laplacian(g, eltype(X)) + + Z_prev = X + Z = X * L̃ + Y = view(c.weight, :, :, 1) * Z_prev + Y += view(c.weight, :, :, 2) * Z + for k in 3:(c.k) + Z, Z_prev = 2 * Z * L̃ - Z_prev, Z + Y += view(c.weight, :, :, k) * Z + end + return Y .+ c.bias +end + +function graph_conv(l, g::AbstractGNNGraph, x) + check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) + m = propagate(copy_xj, g, l.aggr, xj = xj) + x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias) + return x +end + +function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) + check_num_nodes(g, x) + @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" + @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" + + xj, xi = expand_srcdst(g, x) + + if l.add_self_loops + @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." + g = add_self_loops(g) + end + + _, chout = l.channel + heads = l.heads + + Wxi = Wxj = l.dense_x(xj) + Wxi = Wxj = reshape(Wxj, chout, heads, :) + + if xi !== xj + Wxi = l.dense_x(xi) + Wxi = reshape(Wxi, chout, heads, :) + end + + # a hand-written message passing + message = Fix1(gat_message, l) + m = apply_edges(message, g, Wxi, Wxj, e) + α = softmax_edge_neighbors(g, m.logα) + α = dropout(α, l.dropout) + β = α .* m.Wxj + x = aggregate_neighbors(g, +, β) + + if !l.concat + x = mean(x, dims = 2) + end + x = reshape(x, :, size(x, 3)) # return a matrix + x = l.σ.(x .+ l.bias) + + return x +end + +function gat_message(l, Wxi, Wxj, e) + _, chout = l.channel + heads = l.heads + + if e === nothing + Wxx = vcat(Wxi, Wxj) + else + We = l.dense_e(e) + We = reshape(We, chout, heads, :) # chout × nheads × nnodes + Wxx = vcat(Wxi, Wxj, We) + end + aWW = sum(l.a .* Wxx, dims = 1) # 1 × nheads × nedges + logα = leakyrelu.(aWW, l.negative_slope) + return (; logα, Wxj) +end + +function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) + check_num_nodes(g, x) + @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" + @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" + + xj, xi = expand_srcdst(g, x) + + if l.add_self_loops + @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." + g = add_self_loops(g) + end + _, out = l.channel + heads = l.heads + + Wxi = reshape(l.dense_i(xi), out, heads, :) # out × heads × nnodes + Wxj = reshape(l.dense_j(xj), out, heads, :) # out × heads × nnodes + + message = Fix1(gatv2_message, l) + m = apply_edges(message, g, Wxi, Wxj, e) + α = softmax_edge_neighbors(g, m.logα) + α = dropout(α, l.dropout) + β = α .* m.Wxj + x = aggregate_neighbors(g, +, β) + + if !l.concat + x = mean(x, dims = 2) + end + x = reshape(x, :, size(x, 3)) + x = l.σ.(x .+ l.bias) + return x +end + +function gatv2_message(l, Wxi, Wxj, e) + _, out = l.channel + heads = l.heads + + Wx = Wxi + Wxj # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?" + if e !== nothing + Wx += reshape(l.dense_e(e), out, heads, :) + end + logα = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims = 1) # 1 × heads × nedges + return (; logα, Wxj) +end + + +# TODO remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521 +@non_differentiable fill!(x...) 
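+
+# Note on the functional API: every `*_conv` function in this file takes the layer
+# struct `l` as its first argument and reads parameters from its fields (e.g.
+# `gat_conv` above uses `l.dense_x`, `l.dense_e`, `l.a`, `l.bias`, `l.heads`, ...).
+# Front-end packages (e.g. GraphNeuralNetworks.jl) are expected to define those
+# structs and forward their layers' calls to these functions.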
+
+function gated_graph_conv(l, g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real}
+    check_num_nodes(g, H)
+    m, n = size(H)
+    @assert (m <= l.out_ch) "The number of input features must be less than or equal to the number of output features."
+    if m < l.out_ch
+        Hpad = similar(H, S, l.out_ch - m, n)
+        H = vcat(H, fill!(Hpad, 0))
+    end
+    for i in 1:(l.num_layers)
+        M = view(l.weight, :, :, i) * H
+        M = propagate(copy_xj, g, l.aggr; xj = M)
+        H, _ = l.gru(H, M)
+    end
+    return H
+end
+
+function edge_conv(l, g::AbstractGNNGraph, x)
+    check_num_nodes(g, x)
+    xj, xi = expand_srcdst(g, x)
+
+    message = Fix1(edge_conv_message, l)
+    x = propagate(message, g, l.aggr, xi = xi, xj = xj, e = nothing)
+    return x
+end
+
+edge_conv_message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi))
+
+
+function gin_conv(l, g::AbstractGNNGraph, x)
+    check_num_nodes(g, x)
+    xj, xi = expand_srcdst(g, x)
+
+    m = propagate(copy_xj, g, l.aggr, xj = xj)
+
+    return l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m)
+end
+
+function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e)
+    check_num_nodes(g, x)
+    message = Fix1(nn_conv_message, l)
+    m = propagate(message, g, l.aggr, xj = x, e = e)
+    return l.σ.(l.weight * x .+ m .+ l.bias)
+end
+
+function nn_conv_message(l, xi, xj, e)
+    nin, nedges = size(xj)
+    W = reshape(l.nn(e), (:, nin, nedges))
+    xj = reshape(xj, (nin, 1, nedges)) # needed by batched_mul
+    m = NNlib.batched_mul(W, xj)
+    return reshape(m, :, nedges)
+end
+
+function sage_conv(l, g::AbstractGNNGraph, x)
+    check_num_nodes(g, x)
+    xj, xi = expand_srcdst(g, x)
+    m = propagate(copy_xj, g, l.aggr, xj = xj)
+    x = l.σ.(l.weight * vcat(xi, m) .+ l.bias)
+    return x
+end
+
+function res_gated_graph_conv(l, g::AbstractGNNGraph, x)
+    check_num_nodes(g, x)
+    xj, xi = expand_srcdst(g, x)
+
+    message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) .* xj.Vx
+
+    Ax = l.A * xi
+    Bx = l.B * xj
+    Vx = l.V * xj
+
+    m = propagate(message, g, +, xi = (; Ax), xj = (; Bx, Vx))
+
+    return l.σ.(l.U * xi .+ m .+ l.bias)
+end
+
+function cg_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)
+    check_num_nodes(g, x)
+    xj, xi = expand_srcdst(g, x)
+
+    if e !== nothing
+        check_num_edges(g, e)
+    end
+
+    message = Fix1(cg_message, l)
+    m = propagate(message, g, +, xi = xi, xj = xj, e = e)
+
+    if l.residual
+        if size(x, 1) == size(m, 1)
+            m += x
+        else
+            @warn "number of output features different from number of input features, residual not applied."
+        end
+    end
+
+    return m
+end
+
+function cg_message(l, xi, xj, e)
+    if e !== nothing
+        z = vcat(xi, xj, e)
+    else
+        z = vcat(xi, xj)
+    end
+    return l.dense_f(z) .* l.dense_s(z)
+end
+
+
+function agnn_conv(l, g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    if l.add_self_loops
+        g = add_self_loops(g)
+    end
+
+    xn = x ./ sqrt.(sum(x .^ 2, dims = 1))
+    cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn)
+    α = softmax_edge_neighbors(g, l.β .* cos_dist)
+
+    x = propagate(g, +; xj = x, e = α) do xi, xj, α
+        α .* xj
+    end
+
+    return x
+end
+
+function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
+    check_num_nodes(g, x)
+
+    ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e
+        l.ϕe(vcat(xi, xj, e))
+    end
+
+    xᵉ = aggregate_neighbors(g, l.aggr, ē)
+
+    x̄ = l.ϕv(vcat(x, xᵉ))
+
+    return x̄, ē
+end
+
+function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
+    (nin, ein), out = l.ch # notational simplicity
+
+    @assert (ein == size(e)[1] && g.num_edges == size(e)[2]) "Pseudo-coordinate dimension is not equal to (ein, num_edges)"
+
+    num_edges = g.num_edges
+    w = reshape(e, (ein, 1, num_edges))
+    mu = reshape(l.mu, (ein, l.K, 1))
+
+    w = @. ((w - mu)^2) / 2
+    w = w .* reshape(l.sigma_inv .^ 2, (ein, l.K, 1))
+    w = exp.(sum(w, dims = 1)) # (1, K, num_edge)
+
+    xj = reshape(l.dense_x(x), (out, l.K, :)) # (out, K, num_nodes)
+
+    m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
+    m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)
+
+    m = l.σ(m .+ l.bias)
+
+    if l.residual
+        if size(x, 1) == size(m, 1)
+            m += x
+        else
+            @warn "Residual not applied: output feature size does not match input feature size."
+        end
+    end
+
+    return m
+end
+
+# this layer is not stable enough to be supported by GNNHeteroGraph type
+# due to its looping mechanism
+function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T},
+                  edge_weight::EW = nothing) where
+    {T, EW <: Union{Nothing, AbstractVector}}
+    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
+
+    if edge_weight !== nothing
+        @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
+    end
+
+    if l.add_self_loops
+        g = add_self_loops(g)
+        if edge_weight !== nothing
+            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            @assert length(edge_weight) == g.num_edges
+        end
+    end
+    Dout, Din = size(l.weight)
+    if Dout < Din
+        x = l.weight * x
+    end
+    if edge_weight !== nothing
+        d = degree(g, T; dir = :in, edge_weight)
+    else
+        d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
+    end
+    c = 1 ./ sqrt.(d)
+    for iter in 1:(l.k)
+        x = x .* c'
+        if edge_weight !== nothing
+            x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
+        elseif l.use_edge_weight
+            x = propagate(w_mul_xj, g, +, xj = x)
+        else
+            x = propagate(copy_xj, g, +, xj = x)
+        end
+        x = x .* c'
+    end
+    if Dout >= Din
+        x = l.weight * x
+    end
+    return (x .+ l.bias)
+end
+
+function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
+                  edge_weight::AbstractVector)
+    g = GNNGraph(edge_index(g)...; g.num_nodes)
+    return sgc_conv(l, g, x, edge_weight)
+end
+
+function egnn_conv(l, g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing)
+    if l.num_features.edge > 0
+        @assert e!==nothing "Edge features must be provided."
+    end
+    @assert size(h, 1)==l.num_features.in "Input features must match layer input size."
+
+    x_diff = apply_edges(xi_sub_xj, g, x, x)
+    sqnorm_xdiff = sum(x_diff .^ 2, dims = 1)
+    x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6)
+
+    message = Fix1(egnn_message, l)
+    msg = apply_edges(message, g,
+                      xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff))
+    h_aggr = aggregate_neighbors(g, +, msg.h)
+    x_aggr = aggregate_neighbors(g, mean, msg.x)
+
+    hnew = l.ϕh(vcat(h, h_aggr))
+    if l.residual
+        h = h .+ hnew
+    else
+        h = hnew
+    end
+    x = x .+ x_aggr
+    return h, x
+end
+
+function egnn_message(l, xi, xj, e)
+    if l.num_features.edge > 0
+        f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e)
+    else
+        f = vcat(xi.h, xj.h, e.sqnorm_xdiff)
+    end
+
+    msg_h = l.ϕe(f)
+    msg_x = l.ϕx(msg_h) .* e.x_diff
+    return (; x = msg_x, h = msg_h)
+end
+
+# this layer is not stable enough to be supported by GNNHeteroGraph type
+# due to its looping mechanism
+function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T},
+                 edge_weight::EW = nothing) where
+    {T, EW <: Union{Nothing, AbstractVector}}
+    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
+
+    if edge_weight !== nothing
+        @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
+    end
+
+    if l.add_self_loops
+        g = add_self_loops(g)
+        if edge_weight !== nothing
+            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            @assert length(edge_weight) == g.num_edges
+        end
+    end
+    Dout, Din = size(l.weight)
+    if Dout < Din
+        x = l.weight * x
+    end
+    if edge_weight !== nothing
+        d = degree(g, T; dir = :in, edge_weight)
+    else
+        d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
+    end
+    c = 1 ./ sqrt.(d)
+    for iter in 1:(l.k)
+        x = x .* c'
+        if edge_weight !== nothing
+            x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
+        elseif l.use_edge_weight
+            x = propagate(w_mul_xj, g, +, xj = x)
+        else
+            x = propagate(copy_xj, g, +, xj = x)
+        end
+        x = x .* c'
+    end
+    if Dout >= Din
+        x = l.weight * x
+    end
+    return (x .+ l.bias)
+end
+
+function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
+                 edge_weight::AbstractVector)
+    g = GNNGraph(edge_index(g)...; g.num_nodes)
+    return sg_conv(l, g, x, edge_weight)
+end
+
+function transformer_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing} = nothing)
+    check_num_nodes(g, x)
+
+    if l.add_self_loops
+        g = add_self_loops(g)
+    end
+
+    out = l.channels[2]
+    heads = l.heads
+    W1x = !isnothing(l.W1) ? l.W1(x) : nothing
+    W2x = reshape(l.W2(x), out, heads, :)
+    W3x = reshape(l.W3(x), out, heads, :)
+    W4x = reshape(l.W4(x), out, heads, :)
+    W6e = !isnothing(l.W6) ? 
reshape(l.W6(e), out, heads, :) : nothing + + message_uij = Fix1(transformer_message_uij, l) + m = apply_edges(message_uij, g; xi = (; W3x), xj = (; W4x), e = (; W6e)) + α = softmax_edge_neighbors(g, m) + α_val = propagate(transformer_message_main, g, +; + xi = (; W3x), xj = (; W2x), e = (; W6e, α)) + + h = α_val + if l.concat + h = reshape(h, out * heads, :) # concatenate heads + else + h = mean(h, dims = 2) # average heads + h = reshape(h, out, :) + end + + if !isnothing(W1x) # root_weight + if !isnothing(l.W5) # gating + β = l.W5(vcat(h, W1x, h .- W1x)) + h = β .* W1x + (1.0f0 .- β) .* h + else + h += W1x + end + end + + if l.skip_connection + @assert size(h, 1)==size(x, 1) "In-channels must correspond to out-channels * heads if skip_connection is used" + h += x + end + if !isnothing(l.BN1) + h = l.BN1(h) + end + + if !isnothing(l.FF) + h1 = h + h = l.FF(h) + if l.skip_connection + h += h1 + end + if !isnothing(l.BN2) + h = l.BN2(h) + end + end + + return h +end + +# TODO remove l dependence +function transformer_message_uij(l, xi, xj, e) + key = xj.W4x + if !isnothing(e.W6e) + key += e.W6e + end + uij = sum(xi.W3x .* key, dims = 1) ./ l.sqrt_out + return uij +end + +function transformer_message_main(xi, xj, e) + val = xj.W2x + if !isnothing(e.W6e) + val += e.W6e + end + return e.α .* val +end diff --git a/GNNlib/src/layers/pool.jl b/GNNlib/src/layers/pool.jl new file mode 100644 index 000000000..6c7f95a6d --- /dev/null +++ b/GNNlib/src/layers/pool.jl @@ -0,0 +1,40 @@ + + +function global_pool(aggr, g::GNNGraph, x::AbstractArray) + return reduce_nodes(aggr, g, x) +end + +function global_attention_pool(fgate, ffeat, g::GNNGraph, x::AbstractArray) + α = softmax_nodes(g, fgate(x)) + feats = α .* ffeat(x) + u = reduce_nodes(+, g, feats) + return u +end + +function topk_pool(t, X::AbstractArray) + y = t.p' * X / norm(t.p) + idx = topk_index(y, t.k) + t.Ã .= view(t.A, idx, idx) + X_ = view(X, :, idx) .* σ.(view(y, idx)') + return X_ +end + +function topk_index(y::AbstractVector, k::Int) + v = nlargest(k, y) + return collect(1:length(y))[y .>= v[end]] +end + +topk_index(y::Adjoint, k::Int) = topk_index(y', k) + +function set2set_pool(lstm, num_iters, g::GNNGraph, x::AbstractMatrix) + n_in = size(x, 1) + qstar = zeros_like(x, (2*n_in, g.num_graphs)) + for t in 1:num_iters + q = lstm(qstar) # [n_in, n_graphs] + qn = broadcast_nodes(g, q) # [n_in, n_nodes] + α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes] + r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs] + qstar = vcat(q, r) # [2*n_in, n_graphs] + end + return qstar +end diff --git a/GNNlib/src/layers/temporalconv.jl b/GNNlib/src/layers/temporalconv.jl new file mode 100644 index 000000000..8cff3f033 --- /dev/null +++ b/GNNlib/src/layers/temporalconv.jl @@ -0,0 +1,12 @@ +function a3tgcn_conv(a3tgcn, g::GNNGraph, x::AbstractArray) + h = a3tgcn.tgcn(g, x) + e = a3tgcn.dense1(h) + e = a3tgcn.dense2(e) + a = softmax(e, dims = 3) + c = sum(a .* h , dims = 3) + if length(size(c)) == 3 + c = dropdims(c, dims = 3) + end + return c +end + diff --git a/GNNlib/src/mldatasets.jl b/GNNlib/src/mldatasets.jl new file mode 100644 index 000000000..1f2bf7139 --- /dev/null +++ b/GNNlib/src/mldatasets.jl @@ -0,0 +1,41 @@ +# We load a Graph Dataset from MLDatasets without explicitly depending on it + +""" + mldataset2gnngraph(dataset) + +Convert a graph dataset from the package MLDatasets.jl into one or many [`GNNGraph`](@ref)s. 
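+
+The dataset is expected to expose a `graphs` field, and each graph in it to carry
+`num_nodes`, `edge_index`, `node_data`, and `edge_data` fields (as the graph
+datasets in MLDatasets.jl do).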
+
+# Examples
+
+```jldoctest
+julia> using MLDatasets, GraphNeuralNetworks
+
+julia> mldataset2gnngraph(Cora())
+GNNGraph:
+  num_nodes = 2708
+  num_edges = 10556
+  ndata:
+    features => 1433×2708 Matrix{Float32}
+    targets => 2708-element Vector{Int64}
+    train_mask => 2708-element BitVector
+    val_mask => 2708-element BitVector
+    test_mask => 2708-element BitVector
+```
+"""
+function mldataset2gnngraph(dataset::D) where {D}
+    @assert hasproperty(dataset, :graphs)
+    graphs = mlgraph2gnngraph.(dataset.graphs)
+    if length(graphs) == 1
+        return graphs[1]
+    else
+        return graphs
+    end
+end
+
+function mlgraph2gnngraph(g::G) where {G}
+    @assert hasproperty(g, :num_nodes)
+    @assert hasproperty(g, :edge_index)
+    @assert hasproperty(g, :node_data)
+    @assert hasproperty(g, :edge_data)
+    return GNNGraph(g.edge_index; ndata = g.node_data, edata = g.edge_data, g.num_nodes)
+end
diff --git a/GNNlib/src/msgpass.jl b/GNNlib/src/msgpass.jl
new file mode 100644
index 000000000..413a60556
--- /dev/null
+++ b/GNNlib/src/msgpass.jl
@@ -0,0 +1,259 @@
+"""
+    propagate(fmsg, g, aggr, [layer]; [xi, xj, e])
+    propagate(fmsg, g, aggr, [layer,] xi, xj, e=nothing)
+
+Performs message passing on graph `g`. Takes care of materializing the node features on each edge,
+applying the message function `fmsg`, and returning an aggregated message ``\\bar{\\mathbf{m}}``
+(depending on the return value of `fmsg`, an array or a named tuple of
+arrays with last dimension's size `g.num_nodes`).
+
+If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg`
+as a first argument.
+
+It can be decomposed into two steps:
+
+```julia
+m = apply_edges(fmsg, g, xi, xj, e)
+m̄ = aggregate_neighbors(g, aggr, m)
+```
+
+GNN layers typically call `propagate` in their forward pass,
+providing a closure as `fmsg`.
+
+# Arguments
+
+- `g`: A `GNNGraph`.
+- `xi`: An array or a named tuple containing arrays whose last dimension's size
+        is `g.num_nodes`. It will be appropriately materialized on the
+        target node of each edge (see also [`edge_index`](@ref)).
+- `xj`: As `xi`, but to be materialized on each edge's source node.
+- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`.
+- `fmsg`: A generic function that will be passed over to [`apply_edges`](@ref).
+          Has to take as inputs the edge-materialized `xi`, `xj`, and `e`
+          (arrays or named tuples of arrays whose last dimension's size is the size of
+          a batch of edges). Its output has to be an array or a named tuple of arrays
+          with the same batch size. If also `layer` is passed to propagate,
+          the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)`
+          instead of `fmsg(xi, xj, e)`.
+- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument.
+- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`.
+
+
+# Examples
+
+```julia
+using GraphNeuralNetworks, Flux
+
+struct GNNConv <: GNNLayer
+    W
+    b
+    σ
+end
+
+Flux.@functor GNNConv
+
+function GNNConv(ch::Pair{Int,Int}, σ=identity)
+    in, out = ch
+    W = Flux.glorot_uniform(out, in)
+    b = zeros(Float32, out)
+    GNNConv(W, b, σ)
+end
+
+function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
+    message(xi, xj, e) = l.W * xj
+    m̄ = propagate(message, g, +, xj=x)
+    return l.σ.(m̄ .+ l.b)
+end
+
+l = GNNConv(10 => 20)
+l(g, x)
+```
+
+See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref).
+
+"""
+function propagate end
+
+function propagate(f, g::AbstractGNNGraph, aggr; xi = nothing, xj = nothing, e = nothing)
+    propagate(f, g, aggr, xi, xj, e)
+end
+
+function propagate(f, g::AbstractGNNGraph, aggr, xi, xj, e = nothing)
+    m = apply_edges(f, g, xi, xj, e)
+    m̄ = aggregate_neighbors(g, aggr, m)
+    return m̄
+end
+
+## APPLY EDGES
+
+"""
+    apply_edges(fmsg, g, [layer]; [xi, xj, e])
+    apply_edges(fmsg, g, [layer,] xi, xj, e=nothing)
+
+Returns the message from node `j` to node `i` applying
+the message function `fmsg` on the edges in graph `g`.
+In the message-passing scheme, the incoming messages
+from the neighborhood of `i` will later be aggregated
+in order to update the features of node `i` (see [`aggregate_neighbors`](@ref)).
+
+The function `fmsg` operates on batches of edges, therefore
+`xi`, `xj`, and `e` are tensors whose last dimension
+is the batch size, or can be named tuples of
+such tensors.
+
+If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg`
+as a first argument.
+
+# Arguments
+
+- `g`: An `AbstractGNNGraph`.
+- `xi`: An array or a named tuple containing arrays whose last dimension's size
+        is `g.num_nodes`. It will be appropriately materialized on the
+        target node of each edge (see also [`edge_index`](@ref)).
+- `xj`: As `xi`, but now to be materialized on each edge's source node.
+- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`.
+- `fmsg`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`.
+          These are arrays (or named tuples of arrays) whose last dimension's size is the size of
+          a batch of edges. The output of `fmsg` has to be an array (or a named tuple of arrays)
+          with the same batch size. If also `layer` is passed to `apply_edges`,
+          the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)`
+          instead of `fmsg(xi, xj, e)`.
+- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument.
+
+See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref).
+"""
+function apply_edges end
+
+function apply_edges(f, g::AbstractGNNGraph; xi = nothing, xj = nothing, e = nothing)
+    apply_edges(f, g, xi, xj, e)
+end
+
+function apply_edges(f, g::AbstractGNNGraph, xi, xj, e = nothing)
+    check_num_nodes(g, (xj, xi))
+    check_num_edges(g, e)
+    s, t = edge_index(g) # for heterographs, errors if more than one edge type
+    xi = GNNGraphs._gather(xi, t) # size: (D, num_nodes) -> (D, num_edges)
+    xj = GNNGraphs._gather(xj, s)
+    m = f(xi, xj, e)
+    return m
+end
+
+## AGGREGATE NEIGHBORS
+@doc raw"""
+    aggregate_neighbors(g, aggr, m)
+
+Given a graph `g`, edge features `m`, and an aggregation
+operator `aggr` (e.g. `+`, `min`, `max`, `mean`), returns the new node
+features
+```math
+\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i}
+```
+
+Neighborhood aggregation is the second step of [`propagate`](@ref),
+where it comes after [`apply_edges`](@ref).
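+
+# Examples
+
+A minimal sketch (feature sizes are illustrative):
+
+```julia
+g = rand_graph(4, 6)
+m = rand(Float32, 8, g.num_edges)  # one 8-dimensional message per edge
+x̄ = aggregate_neighbors(g, +, m)   # summed messages, size (8, g.num_nodes)
+```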
+""" +function aggregate_neighbors(g::GNNGraph, aggr, m) + check_num_edges(g, m) + s, t = edge_index(g) + return GNNGraphs._scatter(aggr, m, t, g.num_nodes) +end + +function aggregate_neighbors(g::GNNHeteroGraph, aggr, m) + check_num_edges(g, m) + s, t = edge_index(g) + dest_node_t = only(g.etypes)[3] + return GNNGraphs._scatter(aggr, m, t, g.num_nodes[dest_node_t]) +end + +### MESSAGE FUNCTIONS ### +""" + copy_xj(xi, xj, e) = xj +""" +copy_xj(xi, xj, e) = xj + +""" + copy_xi(xi, xj, e) = xi +""" +copy_xi(xi, xj, e) = xi + +""" + xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1) +""" +xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims = 1) + +""" + xi_sub_xj(xi, xj, e) = xi .- xj +""" +xi_sub_xj(xi, xj, e) = xi .- xj + +""" + xj_sub_xi(xi, xj, e) = xj .- xi +""" +xj_sub_xi(xi, xj, e) = xj .- xi + +""" + e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj + +Reshape `e` into broadcast compatible shape with `xj` +(by prepending singleton dimensions) then perform +broadcasted multiplication. +""" +function e_mul_xj(xi, xj::AbstractArray{Tj, Nj}, + e::AbstractArray{Te, Ne}) where {Tj, Te, Nj, Ne} + @assert Ne <= Nj + e = reshape(e, ntuple(_ -> 1, Nj - Ne)..., size(e)...) + return e .* xj +end + +""" + w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj + +Similar to [`e_mul_xj`](@ref) but specialized on scalar edge features (weights). +""" +w_mul_xj(xi, xj::AbstractArray, w::Nothing) = xj # same as copy_xj if no weights + +function w_mul_xj(xi, xj::AbstractArray{Tj, Nj}, w::AbstractVector) where {Tj, Nj} + w = reshape(w, ntuple(_ -> 1, Nj - 1)..., length(w)) + return w .* xj +end + +###### PROPAGATE SPECIALIZATIONS #################### +## See also the methods defined in the package extensions. + +## COPY_XJ + +function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e) + A = adjacency_matrix(g, weighted = false) + return xj * A +end + +## E_MUL_XJ + +# for weighted convolution +function propagate(::typeof(e_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, + e::AbstractVector) + g = set_edge_weight(g, e) + A = adjacency_matrix(g, weighted = true) + return xj * A +end + + +## W_MUL_XJ + +# for weighted convolution +function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, + e::Nothing) + A = adjacency_matrix(g, weighted = true) + return xj * A +end + + +# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) +# A = adjacency_matrix(g, weighted=false) +# D = compute_degree(A) +# return xj * A * D +# end + +# # Zygote bug. Error with sparse matrix without nograd +# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) + +# Flux.Zygote.@nograd compute_degree diff --git a/GNNlib/src/utils.jl b/GNNlib/src/utils.jl new file mode 100644 index 000000000..8c739f3d9 --- /dev/null +++ b/GNNlib/src/utils.jl @@ -0,0 +1,133 @@ +ofeltype(x, y) = convert(float(eltype(x)), y) + +""" + reduce_nodes(aggr, g, x) + +For a batched graph `g`, return the graph-wise aggregation of the node +features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. +The returned array will have last dimension `g.num_graphs`. + +See also: [`reduce_edges`](@ref). +""" +function reduce_nodes(aggr, g::GNNGraph, x) + @assert size(x)[end] == g.num_nodes + indexes = graph_indicator(g) + return NNlib.scatter(aggr, x, indexes) +end + +""" + reduce_nodes(aggr, indicator::AbstractVector, x) + +Return the graph-wise aggregation of the node features `x` given the +graph indicator `indicator`. 
The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
+
+See also [`graph_indicator`](@ref).
+"""
+function reduce_nodes(aggr, indicator::AbstractVector, x)
+    return NNlib.scatter(aggr, x, indicator)
+end
+
+"""
+    reduce_edges(aggr, g, e)
+
+For a batched graph `g`, return the graph-wise aggregation of the edge
+features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`.
+The returned array will have last dimension `g.num_graphs`.
+"""
+function reduce_edges(aggr, g::GNNGraph, e)
+    @assert size(e)[end] == g.num_edges
+    s, t = edge_index(g)
+    indexes = graph_indicator(g)[s]
+    return NNlib.scatter(aggr, e, indexes)
+end
+
+"""
+    softmax_nodes(g, x)
+
+Graph-wise softmax of the node features `x`.
+"""
+function softmax_nodes(g::GNNGraph, x)
+    @assert size(x)[end] == g.num_nodes
+    gi = graph_indicator(g)
+    max_ = gather(scatter(max, x, gi), gi)
+    num = exp.(x .- max_)
+    den = reduce_nodes(+, g, num)
+    den = gather(den, gi)
+    return num ./ den
+end
+
+"""
+    softmax_edges(g, e)
+
+Graph-wise softmax of the edge features `e`.
+"""
+function softmax_edges(g::GNNGraph, e)
+    @assert size(e)[end] == g.num_edges
+    gi = graph_indicator(g, edges = true)
+    max_ = gather(scatter(max, e, gi), gi)
+    num = exp.(e .- max_)
+    den = reduce_edges(+, g, num)
+    den = gather(den, gi)
+    return num ./ (den .+ eps(eltype(e)))
+end
+
+@doc raw"""
+    softmax_edge_neighbors(g, e)
+
+Softmax over each node's neighborhood of the edge features `e`.
+
+```math
+\mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}}
+                    {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}.
+```
+"""
+function softmax_edge_neighbors(g::AbstractGNNGraph, e)
+    if g isa GNNHeteroGraph
+        for (key, value) in g.num_edges
+            @assert size(e)[end] == value
+        end
+    else
+        @assert size(e)[end] == g.num_edges
+    end
+    s, t = edge_index(g)
+    max_ = gather(scatter(max, e, t), t)
+    num = exp.(e .- max_)
+    den = gather(scatter(+, num, t), t)
+    return num ./ den
+end
+
+"""
+    broadcast_nodes(g, x)
+
+Graph-wise broadcast of the array `x` of size `(*, g.num_graphs)`
+to size `(*, g.num_nodes)`.
+"""
+function broadcast_nodes(g::GNNGraph, x)
+    @assert size(x)[end] == g.num_graphs
+    gi = graph_indicator(g)
+    return gather(x, gi)
+end
+
+"""
+    broadcast_edges(g, x)
+
+Graph-wise broadcast of the array `x` of size `(*, g.num_graphs)`
+to size `(*, g.num_edges)`.
+"""
+function broadcast_edges(g::GNNGraph, x)
+    @assert size(x)[end] == g.num_graphs
+    gi = graph_indicator(g, edges = true)
+    return gather(x, gi)
+end
+
+expand_srcdst(g::AbstractGNNGraph, x) = throw(ArgumentError("Invalid input type, expected matrix or tuple of matrices."))
+expand_srcdst(g::AbstractGNNGraph, x::AbstractMatrix) = (x, x)
+expand_srcdst(g::AbstractGNNGraph, x::Tuple{<:AbstractMatrix, <:AbstractMatrix}) = x
+
+# Replacement for Base.Fix1 to allow for multiple arguments
+struct Fix1{F,X}
+    f::F
+    x::X
+end
+
+(f::Fix1)(y...) = f.f(f.x, y...)
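+
+# Example: `Fix1(gat_message, l)` acts like `(xi, xj, e) -> gat_message(l, xi, xj, e)`;
+# this is how the conv layers bind the layer struct to their message functions.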